"""
Statistics API for dashboard overview and charts.
"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from sqlalchemy import func, extract
from datetime import datetime, date
from decimal import Decimal
from app.core.database import get_db
from app.core.models import Invoice, InvoiceStatus, Customer, Subscription, SubscriptionStatus
from app.core.security import get_current_user
from pydantic import BaseModel
from typing import Annotated, Optional

# Every statistics endpoint in this module is mounted under /api/stats.
router = APIRouter(tags=["stats"], prefix="/api/stats")


class OverviewStats(BaseModel):
    """Response body for ``GET /api/stats/overview``."""

    total_revenue: float  # sum of Invoice.total_amount over paid invoices
    total_customers: int  # number of Customer rows
    total_invoices: int  # number of Invoice rows, any status
    paid_invoices: int  # invoices with status == paid
    overdue_invoices: int  # invoices with status == overdue
    draft_invoices: int  # invoices with status == draft
    active_subscriptions: int  # subscriptions with status == active


class MonthlyRevenue(BaseModel):
    """One month's paid-invoice revenue, as returned by ``/api/stats/monthly-revenue``."""

    month: str  # three-letter English abbreviation, e.g. "Jan"
    revenue: float  # summed Invoice.total_amount for that month
    # NOTE(review): the monthly-revenue endpoint never sets this, so it is always
    # "NZD" even though amounts are summed without conversion — confirm that all
    # invoices are NZD-denominated.
    currency: str = "NZD"


class InvoiceStatusCount(BaseModel):
    """Number of invoices sharing one status, as returned by ``/api/stats/invoice-status``."""

    status: str  # the invoice status rendered as a string
    count: int  # how many invoices currently have that status


@router.get("/overview", response_model=OverviewStats)
def get_overview(
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Get dashboard overview stats.

    Reports revenue summed over paid invoices together with counts of
    customers, invoices (overall and paid / overdue / draft) and active
    subscriptions.
    """
    # current_user is not read below; declaring it makes FastAPI run the
    # get_current_user dependency before this handler executes.

    def first_scalar(query, fallback):
        # Equivalent to ``query.scalar() or fallback``: a NULL/zero result
        # collapses to the fallback value.
        value = query.scalar()
        return value if value else fallback

    def count_invoices(*criteria):
        # COUNT(Invoice.id), optionally narrowed by the given filter criteria.
        query = db.query(func.count(Invoice.id))
        if criteria:
            query = query.filter(*criteria)
        return first_scalar(query, 0)

    paid_sum_query = db.query(
        func.coalesce(func.sum(Invoice.total_amount), 0)
    ).filter(Invoice.status == InvoiceStatus.paid)
    revenue = first_scalar(paid_sum_query, Decimal("0"))

    customer_count = first_scalar(db.query(func.count(Customer.id)), 0)
    invoice_count = count_invoices()
    paid_count = count_invoices(Invoice.status == InvoiceStatus.paid)
    overdue_count = count_invoices(Invoice.status == InvoiceStatus.overdue)
    draft_count = count_invoices(Invoice.status == InvoiceStatus.draft)

    active_subscription_query = db.query(func.count(Subscription.id)).filter(
        Subscription.status == SubscriptionStatus.active
    )
    subscription_count = first_scalar(active_subscription_query, 0)

    return OverviewStats(
        total_revenue=float(revenue),
        total_customers=int(customer_count),
        total_invoices=int(invoice_count),
        paid_invoices=int(paid_count),
        overdue_invoices=int(overdue_count),
        draft_invoices=int(draft_count),
        active_subscriptions=int(subscription_count),
    )


@router.get("/monthly-revenue", response_model=list[MonthlyRevenue])
def get_monthly_revenue(
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
    year: Optional[int] = None,
):
    """Get monthly revenue for a given year (default: current year).

    Revenue is the sum of paid invoices' totals, grouped by the month of the
    invoice date. Months with no paid invoice are omitted from the result.
    """
    # current_user is not read below; declaring it makes FastAPI run the
    # get_current_user dependency before this handler executes.

    # datetime.now() is naive local time, so "current year" is the server's.
    target_year = datetime.now().year if year is None else year

    # One expression object shared by SELECT, GROUP BY and ORDER BY.
    invoice_month = extract("month", Invoice.invoice_date)

    # NOTE(review): filtering on extract("year", ...) cannot use a plain index
    # on invoice_date; a [Jan 1, next Jan 1) range filter could, if this is slow.
    rows = (
        db.query(
            invoice_month.label("month"),
            func.sum(Invoice.total_amount).label("revenue"),
        )
        .filter(
            Invoice.status == InvoiceStatus.paid,
            extract("year", Invoice.invoice_date) == target_year,
        )
        .group_by(invoice_month)
        .order_by(invoice_month)
        .all()
    )

    abbreviations = (
        "Jan", "Feb", "Mar", "Apr", "May", "Jun",
        "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    )
    monthly = []
    for row in rows:
        # The extracted month may come back as a float/Decimal depending on
        # the database driver, hence int() before indexing (1-based -> 0-based).
        label = abbreviations[int(row.month) - 1]
        monthly.append(MonthlyRevenue(month=label, revenue=float(row.revenue or 0)))
    return monthly


@router.get("/invoice-status", response_model=list[InvoiceStatusCount])
def get_invoice_status_distribution(
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Get invoice count grouped by status.

    Returns one entry per status that occurs in the invoices table, with the
    status rendered as its plain value (e.g. ``"paid"``).
    """
    # current_user is not read below; declaring it makes FastAPI run the
    # get_current_user dependency before this handler executes.
    results = (
        db.query(Invoice.status, func.count(Invoice.id))
        .group_by(Invoice.status)
        .all()
    )
    distribution = []
    for status, invoice_count in results:
        # str() on an Enum member yields "ClassName.member" (also for
        # ``(str, Enum)`` mixins) rather than the member's value; unwrap
        # ``.value`` when present so plain-string columns pass through as-is.
        status_value = getattr(status, "value", status)
        distribution.append(
            InvoiceStatusCount(status=str(status_value), count=int(invoice_count))
        )
    return distribution


@router.get("/income-by-type")
def get_income_by_type(
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Get income breakdown by customer type (project/subscription/mixed).

    Returns a list of ``{"type": ..., "total": ...}`` objects, one per
    customer type that has at least one paid invoice.
    """
    # current_user is not read below; declaring it makes FastAPI run the
    # get_current_user dependency before this handler executes.
    results = (
        db.query(
            Customer.customer_type,
            func.sum(Invoice.total_amount).label("total"),
        )
        .join(Invoice, Invoice.customer_id == Customer.id)
        .filter(Invoice.status == InvoiceStatus.paid)
        .group_by(Customer.customer_type)
        .all()
    )
    income = []
    for customer_type, total in results:
        # str() on an Enum member yields "ClassName.member" (also for
        # ``(str, Enum)`` mixins) rather than the member's value; unwrap
        # ``.value`` when present so plain-string columns pass through as-is.
        type_value = getattr(customer_type, "value", customer_type)
        income.append({"type": str(type_value), "total": float(total or 0)})
    return income
