from calendar import monthrange
from datetime import date, timedelta
from decimal import Decimal
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.core.models import Expense, GstReturn, GstSetting, Invoice, Payment, PayrollRecord
from app.core.security import get_current_user
from app.services.gst_service import (
    build_gst_return_draft,
    create_gst_adjustment,
    delete_gst_adjustment,
    delete_gst_return,
    file_gst_return,
    get_gst_return,
    lock_gst_return,
    prepare_gst_return,
    save_gst_return_draft,
    serialize_gst_return,
)

# Router for every finance endpoint in this module; all routes are served under /api/finance.
router = APIRouter(prefix="/api/finance", tags=["finance"])

# Invoice statuses treated as revenue-bearing: the summary and year-option
# queries only count invoices (and payments on invoices) in one of these states.
REVENUE_STATUSES = ["issued", "sent", "paid", "partially_paid", "overdue"]


class FinanceSummaryResponse(BaseModel):
    """Response body of ``GET /api/finance/summary``.

    Monetary totals cover ``start_date``..``end_date``; either bound may be
    ``None`` for an open-ended range. The ``gst_*`` fields default to zero and
    are only filled in when both bounds are known.
    """

    revenue_total: Decimal  # cash-basis revenue; the summary endpoint sets it to cash_received_total
    cash_received_total: Decimal  # payments received on revenue-status invoices
    invoice_revenue_total: Decimal  # sum of invoice total_amount for revenue-status invoices
    expense_total: Decimal  # gross amount of confirmed expenses
    confirmed_expense_total: Decimal  # the summary endpoint fills this with the same value as expense_total
    payroll_total: Decimal  # gross pay plus employer KiwiSaver for confirmed payroll records
    gst_sales_and_income: Decimal = Decimal("0")
    gst_output: Decimal = Decimal("0")
    gst_purchases_expenses: Decimal = Decimal("0")
    gst_input: Decimal = Decimal("0")
    gst_payable: Decimal = Decimal("0")
    net_result: Decimal  # cash received minus expenses minus payroll
    start_date: date | None = None
    end_date: date | None = None
    mode: str  # range mode as requested (e.g. "calendar", "fiscal", "gst", "custom")


class GstSettingPayload(BaseModel):
    """GST reporting configuration; used as the PUT /gst-settings body and as the settings response."""

    cycle_months: int  # length of each GST period in months; must be >= 1 for the period helpers to work (not validated here)
    cycle_start_date: date  # anchor date from which GST periods are counted
    enabled_from: date | None = None  # stored and echoed back; not read by the period helpers in this module
    reporting_label: str | None = None  # presumably a display label — TODO confirm
    default_basis: str = "invoice_date"  # presumably the default GST accounting basis — TODO confirm allowed values


class GstReturnPeriodPayload(BaseModel):
    """Period bounds for the prepare and save-draft GST return endpoints.

    The model does not check ordering; the endpoints call ``_validate_gst_period``.
    """

    period_start: date
    period_end: date


class GstAdjustmentPayload(BaseModel):
    """Request body for creating a GST adjustment.

    The source period is optional, but the create endpoint requires its two
    bounds to be given together; the model itself does not enforce this.
    """

    adjustment_type: str
    reason: str
    amount: Decimal
    target_period_start: date  # period the adjustment is applied to
    target_period_end: date
    source_period_start: date | None = None  # optional originating period
    source_period_end: date | None = None
    linked_invoice_id: int | None = None  # optional id of a related invoice
    linked_payment_id: int | None = None  # optional id of a related payment
    linked_expense_id: int | None = None  # optional id of a related expense


def _add_months(anchor: date, months: int) -> date:
    zero_based = anchor.month - 1 + months
    year = anchor.year + zero_based // 12
    month = zero_based % 12 + 1
    day = min(anchor.day, monthrange(year, month)[1])
    return date(year, month, day)


def _subtract_months(anchor: date, months: int) -> date:
    """Shift ``anchor`` back by ``months`` calendar months, clamping the day like ``_add_months``."""
    backwards = -months
    return _add_months(anchor, backwards)


def _default_gst_setting() -> GstSettingPayload:
    """Fallback GST setting: three-month cycles anchored on 1 January of the current year."""
    january_first = date.today().replace(month=1, day=1)
    return GstSettingPayload(cycle_start_date=january_first, cycle_months=3)


def _get_latest_gst_setting(db: Session) -> GstSettingPayload:
    """Return the GST setting row with the highest id, or the defaults when none is stored."""
    latest = db.query(GstSetting).order_by(GstSetting.id.desc()).first()
    if latest is None:
        return _default_gst_setting()

    copied_fields = (
        "cycle_months",
        "cycle_start_date",
        "enabled_from",
        "reporting_label",
        "default_basis",
    )
    return GstSettingPayload(**{name: getattr(latest, name) for name in copied_fields})


def _gst_period_index(cycle_start: date, cycle_months: int, anchor_date: date) -> int:
    """Return the index of the GST period that contains ``anchor_date``.

    Period ``k`` starts at ``_add_months(cycle_start, k * cycle_months)`` and ends
    the day before period ``k + 1`` starts. Every boundary is computed from
    ``cycle_start`` directly instead of chaining from the previous period. A
    month-end start day (e.g. the 31st) is therefore not permanently shortened
    after a short month. The index is negative when ``anchor_date`` precedes
    ``cycle_start``.

    Raises:
        ValueError: if ``cycle_months`` is less than 1, since the periods would
            then never advance.
    """
    if cycle_months < 1:
        raise ValueError("GST cycle_months must be a positive integer")
    month_gap = (anchor_date.year - cycle_start.year) * 12 + (anchor_date.month - cycle_start.month)
    index = month_gap // cycle_months
    # A candidate period starting in an earlier month is always <= anchor_date.
    # Only when it starts in the anchor's own month, on a later day, does the
    # anchor belong to the previous period.
    if _add_months(cycle_start, index * cycle_months) > anchor_date:
        index -= 1
    return index


def _gst_period_bounds(cycle_start: date, cycle_months: int, index: int) -> tuple[date, date]:
    """Return the (first day, last day) of GST period ``index``."""
    period_start = _add_months(cycle_start, index * cycle_months)
    period_end = _add_months(cycle_start, (index + 1) * cycle_months) - timedelta(days=1)
    return period_start, period_end


def _build_periods(setting: GstSettingPayload, anchor_date: date) -> list[dict]:
    """List up to the 12 most recent GST periods starting on or before ``anchor_date``.

    Only periods from ``setting.cycle_start_date`` onwards are listed, oldest
    first, each as ``{"start_date": ..., "end_date": ...}`` with ISO date strings.
    An empty list is returned when ``anchor_date`` precedes the cycle start.

    Raises:
        ValueError: if ``setting.cycle_months`` is less than 1.
    """
    last_index = _gst_period_index(setting.cycle_start_date, setting.cycle_months, anchor_date)
    first_index = max(0, last_index - 11)
    periods = []
    for index in range(first_index, last_index + 1):
        period_start, period_end = _gst_period_bounds(setting.cycle_start_date, setting.cycle_months, index)
        periods.append({
            "start_date": period_start.isoformat(),
            "end_date": period_end.isoformat(),
        })
    return periods


def _current_gst_range(setting: GstSettingPayload, anchor_date: date) -> tuple[date, date]:
    """Return the (first day, last day) of the GST period containing ``anchor_date``.

    When ``anchor_date`` falls before ``setting.cycle_start_date``, the cycle is
    extended backwards so the returned range still contains ``anchor_date``.

    Raises:
        ValueError: if ``setting.cycle_months`` is less than 1.
    """
    index = _gst_period_index(setting.cycle_start_date, setting.cycle_months, anchor_date)
    return _gst_period_bounds(setting.cycle_start_date, setting.cycle_months, index)


def _resolve_range(
    db: Session,
    mode: str,
    start_date: date | None,
    end_date: date | None,
) -> tuple[date | None, date | None]:
    if start_date or end_date:
        return start_date, end_date

    today = date.today()

    if mode == "calendar":
        return date(today.year, 1, 1), date(today.year, 12, 31)

    if mode == "fiscal":
        if today.month >= 4:
            return date(today.year, 4, 1), date(today.year + 1, 3, 31)
        return date(today.year - 1, 4, 1), date(today.year, 3, 31)

    if mode == "gst":
        return _current_gst_range(_get_latest_gst_setting(db), today)

    if mode == "custom":
        return None, None

    current_month_start = date(today.year, today.month, 1)
    current_month_end = date(today.year, today.month, monthrange(today.year, today.month)[1])
    return current_month_start, current_month_end


def _fiscal_year_start(value: date) -> int:
    return value.year if value.month >= 4 else value.year - 1


def _option(value: int, label: str) -> dict[str, int | str]:
    return {"value": value, "label": label}


def _collect_data_dates(db: Session) -> list[date]:
    """Gather the dates of all finance activity that feed the year pickers.

    Sources, in this order: invoice dates of revenue-status invoices; payment
    dates on those invoices (``received_date``, else the date part of
    ``paid_at``); confirmed expense dates; confirmed payroll pay dates.
    Duplicates are kept and falsy values are skipped.
    """
    effective_payment_date = func.coalesce(Payment.received_date, func.date(Payment.paid_at))

    invoice_rows = db.query(Invoice.invoice_date).filter(Invoice.status.in_(REVENUE_STATUSES)).all()
    found: list[date] = [row[0] for row in invoice_rows if row[0]]

    payment_rows = (
        db.query(effective_payment_date)
        .join(Invoice, Payment.invoice_id == Invoice.id)
        .filter(Invoice.status.in_(REVENUE_STATUSES))
        .filter(effective_payment_date.isnot(None))
        .all()
    )
    for row in payment_rows:
        # Depending on the database driver, the DATE() fallback may come back
        # as an ISO string instead of a date object.
        paid_on = date.fromisoformat(row[0]) if isinstance(row[0], str) else row[0]
        if paid_on:
            found.append(paid_on)

    expense_rows = db.query(Expense.expense_date).filter(Expense.status == "confirmed").all()
    found.extend(row[0] for row in expense_rows if row[0])

    payroll_rows = db.query(PayrollRecord.pay_date).filter(PayrollRecord.status == "confirmed").all()
    found.extend(row[0] for row in payroll_rows if row[0])

    return found


@router.get("/year-options")
def get_finance_year_options(
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """List the calendar and April–March fiscal years that have finance data, newest first."""
    seen_dates = _collect_data_dates(db)

    calendar_options = [
        _option(year, f"{year} 自然年")
        for year in sorted({day.year for day in seen_dates}, reverse=True)
    ]
    fiscal_options = [
        _option(year, f"{year}-{year + 1} 财年")
        for year in sorted(set(map(_fiscal_year_start, seen_dates)), reverse=True)
    ]
    return {"calendar_years": calendar_options, "fiscal_years": fiscal_options}


def _as_decimal(value) -> Decimal:
    """Convert a scalar aggregate (``None``, int, float, Decimal or str) to ``Decimal``.

    Going through ``str`` keeps float results at their printed precision rather
    than their binary expansion. ``None`` and other falsy values become zero.
    """
    return Decimal(str(value or 0))


@router.get("/summary", response_model=FinanceSummaryResponse)
def get_finance_summary(
    current_user: Annotated[dict, Depends(get_current_user)],
    mode: str = Query("fiscal"),
    start_date: date | None = Query(None),
    end_date: date | None = Query(None),
    db: Session = Depends(get_db),
):
    """Summarise revenue, cash received, expenses, payroll and GST for a date range.

    Explicit ``start_date``/``end_date`` take precedence; otherwise the range is
    derived from ``mode`` via ``_resolve_range``. A ``None`` bound leaves that
    side of the range open. GST figures are only computed when both bounds are
    known.

    ``revenue_total`` and ``net_result`` use cash received (payments), while
    ``invoice_revenue_total`` reports invoiced amounts.

    Raises:
        HTTPException: 400 if both dates are supplied and ``end_date`` precedes
            ``start_date`` (consistent with the GST period endpoints).
    """
    if start_date is not None and end_date is not None and end_date < start_date:
        raise HTTPException(status_code=400, detail="end_date must be on or after start_date")

    resolved_start, resolved_end = _resolve_range(db, mode, start_date, end_date)
    # Payments are dated by received_date, falling back to the date part of paid_at.
    payment_period_date = func.coalesce(Payment.received_date, func.date(Payment.paid_at))

    invoice_revenue_query = db.query(func.coalesce(func.sum(Invoice.total_amount), 0)).filter(
        Invoice.status.in_(REVENUE_STATUSES)
    )
    cash_received_query = (
        db.query(func.coalesce(func.sum(Payment.amount), 0))
        .join(Invoice, Payment.invoice_id == Invoice.id)
        .filter(Invoice.status.in_(REVENUE_STATUSES))
        .filter(payment_period_date.isnot(None))
    )
    expense_query = db.query(func.coalesce(func.sum(Expense.amount_gross), 0)).filter(Expense.status == "confirmed")
    # Payroll cost includes the employer KiwiSaver contribution on top of gross pay.
    payroll_query = db.query(
        func.coalesce(func.sum(PayrollRecord.gross_pay + PayrollRecord.kiwisaver_employer), 0)
    ).filter(PayrollRecord.status == "confirmed")

    if resolved_start is not None:
        invoice_revenue_query = invoice_revenue_query.filter(Invoice.invoice_date >= resolved_start)
        cash_received_query = cash_received_query.filter(payment_period_date >= resolved_start)
        expense_query = expense_query.filter(Expense.expense_date >= resolved_start)
        payroll_query = payroll_query.filter(PayrollRecord.pay_date >= resolved_start)
    if resolved_end is not None:
        invoice_revenue_query = invoice_revenue_query.filter(Invoice.invoice_date <= resolved_end)
        cash_received_query = cash_received_query.filter(payment_period_date <= resolved_end)
        expense_query = expense_query.filter(Expense.expense_date <= resolved_end)
        payroll_query = payroll_query.filter(PayrollRecord.pay_date <= resolved_end)

    cash_received_total = _as_decimal(cash_received_query.scalar())
    invoice_revenue_total = _as_decimal(invoice_revenue_query.scalar())
    expense_total = _as_decimal(expense_query.scalar())
    payroll_total = _as_decimal(payroll_query.scalar())
    gst_sales_and_income = Decimal("0")
    gst_output = Decimal("0")
    gst_purchases_expenses = Decimal("0")
    gst_input = Decimal("0")
    gst_payable = Decimal("0")

    # A GST draft needs a closed period, so it is skipped for open-ended ranges.
    if resolved_start is not None and resolved_end is not None:
        gst_draft = build_gst_return_draft(db, resolved_start, resolved_end)
        gst_sales_and_income = _as_decimal(gst_draft.get("total_sales_and_income"))
        gst_output = _as_decimal(gst_draft.get("gst_output"))
        gst_purchases_expenses = _as_decimal(gst_draft.get("total_purchases_expenses"))
        gst_input = _as_decimal(gst_draft.get("gst_input"))
        gst_payable = _as_decimal(gst_draft.get("gst_payable"))

    return FinanceSummaryResponse(
        revenue_total=cash_received_total,
        cash_received_total=cash_received_total,
        invoice_revenue_total=invoice_revenue_total,
        expense_total=expense_total,
        confirmed_expense_total=expense_total,
        payroll_total=payroll_total,
        gst_sales_and_income=gst_sales_and_income,
        gst_output=gst_output,
        gst_purchases_expenses=gst_purchases_expenses,
        gst_input=gst_input,
        gst_payable=gst_payable,
        net_result=cash_received_total - expense_total - payroll_total,
        start_date=resolved_start,
        end_date=resolved_end,
        mode=mode,
    )


@router.get("/gst-periods")
def get_gst_periods(
    current_user: Annotated[dict, Depends(get_current_user)],
    anchor_date: date | None = Query(None),
    db: Session = Depends(get_db),
):
    """Return the active GST settings together with the recent periods up to ``anchor_date`` (default: today)."""
    setting = _get_latest_gst_setting(db)
    anchor = anchor_date if anchor_date is not None else date.today()

    response = {
        "cycle_months": setting.cycle_months,
        "cycle_start_date": setting.cycle_start_date,
        "enabled_from": setting.enabled_from,
        "reporting_label": setting.reporting_label,
        "default_basis": setting.default_basis,
    }
    response["periods"] = _build_periods(setting, anchor)
    return response


@router.get("/gst-return/draft")
def get_gst_return_draft(
    current_user: Annotated[dict, Depends(get_current_user)],
    period_start: date = Query(...),
    period_end: date = Query(...),
    db: Session = Depends(get_db),
):
    """Return the JSON-encoded output of ``build_gst_return_draft`` for the period; 400 if it is reversed."""
    _validate_gst_period(period_start, period_end)
    draft = build_gst_return_draft(db, period_start, period_end)
    return jsonable_encoder(draft)


@router.get("/gst-returns")
def list_gst_returns(
    current_user: Annotated[dict, Depends(get_current_user)],
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
):
    """Page through saved GST returns, newest period first (ties broken by newest id).

    List items carry an empty ``rows`` list; ``total`` counts all stored returns.
    """
    base_query = db.query(GstReturn)
    total_count = base_query.count()

    ordering = (GstReturn.period_start.desc(), GstReturn.id.desc())
    page = base_query.order_by(*ordering).offset(offset).limit(limit).all()

    items = []
    for gst_return in page:
        summary = serialize_gst_return(gst_return)
        items.append(summary | {"rows": []})

    return {"items": items, "total": total_count, "limit": limit, "offset": offset}


def _validate_gst_period(period_start: date, period_end: date) -> None:
    if period_end < period_start:
        raise HTTPException(status_code=400, detail="period_end must be on or after period_start")


def _serialize_or_raise(action, *args):
    """Run a GST-return service call and JSON-encode the serialized result.

    ``LookupError`` becomes HTTP 404. ``ValueError`` becomes 409 when its message
    contains "already exists" or starts with "Only ", and 400 otherwise.
    """
    try:
        gst_return = action(*args)
        serialized = serialize_gst_return(gst_return)
        return jsonable_encoder(serialized)
    except LookupError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        detail = str(exc)
        is_conflict = "already exists" in detail or detail.startswith("Only ")
        raise HTTPException(status_code=409 if is_conflict else 400, detail=detail) from exc


def _json_or_raise(action, *args):
    """Run a service call and JSON-encode its raw result.

    ``LookupError`` becomes HTTP 404. ``ValueError`` becomes 409 when its message
    contains "cannot be deleted" or starts with "Only ", and 400 otherwise.
    """
    try:
        outcome = action(*args)
        return jsonable_encoder(outcome)
    except LookupError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        detail = str(exc)
        is_conflict = "cannot be deleted" in detail or detail.startswith("Only ")
        raise HTTPException(status_code=409 if is_conflict else 400, detail=detail) from exc


def _serialize_gst_adjustment(adjustment):
    """JSON-encode the public attributes of a GST adjustment into a flat dict."""
    exported_attributes = (
        "id",
        "adjustment_type",
        "direction",
        "reason",
        "amount",
        "source_period_start",
        "source_period_end",
        "target_period_start",
        "target_period_end",
        "linked_invoice_id",
        "linked_payment_id",
        "linked_expense_id",
        "created_at",
    )
    return jsonable_encoder({name: getattr(adjustment, name) for name in exported_attributes})


@router.post("/gst-return/prepare")
def prepare_gst_return_endpoint(
    payload: GstReturnPeriodPayload,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Validate the period, delegate to ``prepare_gst_return`` and return the serialized GST return."""
    start, end = payload.period_start, payload.period_end
    _validate_gst_period(start, end)
    return _serialize_or_raise(prepare_gst_return, db, start, end)


@router.post("/gst-return/drafts")
def save_gst_return_draft_endpoint(
    payload: GstReturnPeriodPayload,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Validate the period, delegate to ``save_gst_return_draft`` and return the serialized GST return."""
    start, end = payload.period_start, payload.period_end
    _validate_gst_period(start, end)
    return _serialize_or_raise(save_gst_return_draft, db, start, end)


@router.post("/gst-return/adjustments")
def create_gst_adjustment_endpoint(
    payload: GstAdjustmentPayload,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Create a GST adjustment against a target period.

    The target period must be ordered. The source period is optional, but its
    two bounds must be supplied together and, when present, be ordered too.
    Service errors map to HTTP 400 (``ValueError``) or 404 (``LookupError``).
    """
    _validate_gst_period(payload.target_period_start, payload.target_period_end)

    source_start = payload.source_period_start
    source_end = payload.source_period_end
    if (source_start is None) != (source_end is None):
        raise HTTPException(
            status_code=400,
            detail="source_period_start and source_period_end must be provided together",
        )
    # After the check above, both source bounds are either set or both None.
    if source_start is not None and source_end < source_start:
        raise HTTPException(status_code=400, detail="source_period_end must be on or after source_period_start")

    try:
        adjustment = create_gst_adjustment(
            db,
            adjustment_type=payload.adjustment_type,
            reason=payload.reason,
            amount=payload.amount,
            target_period_start=payload.target_period_start,
            target_period_end=payload.target_period_end,
            source_period_start=source_start,
            source_period_end=source_end,
            linked_invoice_id=payload.linked_invoice_id,
            linked_payment_id=payload.linked_payment_id,
            linked_expense_id=payload.linked_expense_id,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except LookupError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    else:
        return _serialize_gst_adjustment(adjustment)


@router.delete("/gst-return/adjustments/{adjustment_id}")
def delete_gst_adjustment_endpoint(
    adjustment_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Delete a GST adjustment by id and return the JSON-encoded service result.

    Errors are mapped by ``_json_or_raise``, like the GST return delete endpoint.
    ``LookupError`` becomes 404. ``ValueError`` becomes 409 for "cannot be
    deleted" / "Only ..." messages and 400 otherwise. Previously a
    ``ValueError`` surfaced as an unhandled 500.
    """
    return _json_or_raise(delete_gst_adjustment, db, adjustment_id)


@router.get("/gst-return/{return_id}")
def get_gst_return_endpoint(
    return_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Fetch one GST return by id via ``get_gst_return`` and return it serialized (404 if missing)."""
    serialized = _serialize_or_raise(get_gst_return, db, return_id)
    return serialized


@router.delete("/gst-return/{return_id}")
def delete_gst_return_endpoint(
    return_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Delete a GST return by id via ``delete_gst_return`` and return the JSON-encoded result."""
    outcome = _json_or_raise(delete_gst_return, db, return_id)
    return outcome


@router.post("/gst-return/{return_id}/lock")
def lock_gst_return_endpoint(
    return_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Lock a GST return via ``lock_gst_return`` and return it serialized."""
    locked = _serialize_or_raise(lock_gst_return, db, return_id)
    return locked


@router.post("/gst-return/{return_id}/file")
def file_gst_return_endpoint(
    return_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Mark a GST return as filed via ``file_gst_return`` and return it serialized."""
    filed = _serialize_or_raise(file_gst_return, db, return_id)
    return filed


@router.get("/gst-settings")
def get_gst_settings(
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Return the active GST settings, or the built-in defaults when none are stored."""
    active_setting = _get_latest_gst_setting(db)
    return active_setting


@router.put("/gst-settings")
def update_gst_settings(
    data: GstSettingPayload,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Create or update the GST settings and return the persisted values.

    The GST setting row with the highest id is overwritten in place. If none
    exists, a new row is inserted.

    Raises:
        HTTPException: 400 if ``cycle_months`` is less than 1. GST period
            calculations step through the cycle by this many months, so they
            need a positive cycle length.
    """
    if data.cycle_months < 1:
        raise HTTPException(status_code=400, detail="cycle_months must be a positive integer")

    values = data.model_dump()
    setting = db.query(GstSetting).order_by(GstSetting.id.desc()).first()
    if setting is None:
        setting = GstSetting(**values)
        db.add(setting)
    else:
        for field, value in values.items():
            setattr(setting, field, value)
    db.commit()
    db.refresh(setting)
    return GstSettingPayload(
        cycle_months=setting.cycle_months,
        cycle_start_date=setting.cycle_start_date,
        enabled_from=setting.enabled_from,
        reporting_label=setting.reporting_label,
        default_basis=setting.default_basis,
    )
