import json
from datetime import date, datetime
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from pathlib import Path
from typing import Annotated

from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status
from fastapi.responses import FileResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.core.models import Expense, ReceiptAsset
from app.core.security import get_current_user
from app.services.ai_service import extract_receipt_data
from app.services import receipt_service

# All expense and receipt endpoints in this module are mounted under /api/expenses.
router = APIRouter(
    prefix="/api/expenses",
    tags=["expenses"],
)


class ExpenseBase(BaseModel):
    # Fields shared by ExpenseCreate and ExpenseResponse.
    # The create/update endpoints pass every expense through
    # _recalculate_expense_amounts before committing: gst_amount is always
    # recomputed there, and amount_net / amount_gross may be re-derived from
    # each other, so client-supplied values for these are not stored verbatim.
    expense_date: date
    vendor_name: str
    category: str
    description: str | None = None
    amount_net: Decimal = Decimal("0")
    gst_amount: Decimal = Decimal("0")
    # Modes handled by _resolve_expense_tax_rate: "gst_15", "zero_rated",
    # "no_gst", "custom_rate"; any other value is taxed at 15%.
    tax_mode: str = "gst_15"
    # Only used when tax_mode == "custom_rate"; values above 1 are read as a
    # percentage (15 -> 0.15), values from 0 to 1 as a fraction.
    custom_tax_rate: float | None = None
    # When True and gst_amount_override is set, the override replaces the
    # computed GST (ignored for the zero_rated / no_gst modes).
    gst_amount_overridden: bool = False
    gst_amount_override: float | None = None
    amount_gross: Decimal = Decimal("0")
    currency: str = "NZD"
    payment_method: str | None = None
    gst_claimable: bool = True
    source: str = "manual"
    notes: str | None = None


class ExpenseCreate(ExpenseBase):
    # Request body for POST /api/expenses (create_expense).
    # Optional id of an uploaded ReceiptAsset to link to this expense.
    receipt_asset_id: int | None = None
    # Stored as given. NOTE(review): not validated against a fixed set of
    # statuses here; delete_expense only permits deleting "draft" expenses.
    status: str = "draft"


class ExpenseUpdate(BaseModel):
    # Request body for PUT /api/expenses/{expense_id} (update_expense).
    # Every field is optional: update_expense applies only the fields that were
    # explicitly sent (model_dump(exclude_unset=True)), so an explicit null is
    # written through while an omitted field is left unchanged.
    expense_date: date | None = None
    vendor_name: str | None = None
    category: str | None = None
    description: str | None = None
    amount_net: Decimal | None = None
    gst_amount: Decimal | None = None
    tax_mode: str | None = None
    custom_tax_rate: float | None = None
    gst_amount_overridden: bool | None = None
    gst_amount_override: float | None = None
    amount_gross: Decimal | None = None
    currency: str | None = None
    payment_method: str | None = None
    gst_claimable: bool | None = None
    source: str | None = None
    notes: str | None = None
    receipt_asset_id: int | None = None
    status: str | None = None


class ExpenseResponse(ExpenseBase):
    # Serialized expense returned by the expense endpoints.
    id: int
    receipt_asset_id: int | None = None
    # The receipt_* fields below are filled in by _expense_to_response when the
    # expense has a linked receipt; they default to None otherwise.
    receipt_preview_url: str | None = None
    receipt_original_url: str | None = None
    receipt_processing_status: str | None = None
    receipt_ai_extraction: dict | None = None
    receipt_ocr_snapshot: dict | None = None
    status: str
    created_at: datetime | None = None
    updated_at: datetime | None = None

    # from_attributes lets the model be validated directly from ORM objects.
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)


class PaginatedExpenses(BaseModel):
    # One page of expenses as returned by list_expenses.
    items: list[ExpenseResponse]
    # Number of rows matching the filter across all pages.
    total: int
    # 1-based page number and the page size that were requested.
    page: int
    page_size: int


def _safe_json_loads(raw: str | None) -> dict | None:
    if not raw:
        return None
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {"raw": raw}
    return parsed if isinstance(parsed, dict) else {"value": parsed}


def _expense_to_response(expense: Expense) -> ExpenseResponse:
    """Serialize an ``Expense`` ORM row into an ``ExpenseResponse``.

    Rows with an empty ``tax_mode`` or a ``None`` ``gst_amount_overridden``
    are reported as ``"gst_15"`` / ``False``. Those defaults are applied to the
    outgoing payload only; the ORM instance is left untouched, so serializing
    a row (e.g. for a GET) never marks it dirty in the session.

    When ``receipt_asset_id`` is set, preview/original download URLs are
    added. If ``expense.receipt_asset`` is also present, its processing status
    and decoded AI/OCR JSON snapshots are included.
    """
    # Collect the same attributes that from_attributes validation would read;
    # attributes the ORM object does not have fall back to the model defaults.
    payload = {
        field_name: getattr(expense, field_name)
        for field_name in ExpenseResponse.model_fields
        if hasattr(expense, field_name)
    }
    payload["tax_mode"] = payload.get("tax_mode") or "gst_15"
    if payload.get("gst_amount_overridden") is None:
        payload["gst_amount_overridden"] = False

    receipt_id = expense.receipt_asset_id
    if receipt_id:
        payload["receipt_preview_url"] = f"/api/expenses/receipts/{receipt_id}/preview"
        payload["receipt_original_url"] = f"/api/expenses/receipts/{receipt_id}/original"
        receipt = expense.receipt_asset
        if receipt:
            payload["receipt_processing_status"] = receipt.processing_status
            payload["receipt_ai_extraction"] = _safe_json_loads(receipt.ai_extraction_snapshot_json)
            payload["receipt_ocr_snapshot"] = _safe_json_loads(receipt.ocr_snapshot_json)
    return ExpenseResponse.model_validate(payload)


def _current_username(current_user) -> str | None:
    if not current_user:
        return None
    if isinstance(current_user, dict):
        return current_user.get("username")
    return getattr(current_user, "username", None)


def _get_expense_or_404(db: Session, expense_id: int) -> Expense:
    """Load an expense by primary key, aborting the request with 404 if absent."""
    found = db.get(Expense, expense_id)
    if found:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Expense not found")


def _money(value) -> Decimal:
    try:
        return Decimal(str(value or 0)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    except InvalidOperation as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid expense amount") from exc


def _ensure_non_negative(value: Decimal, field_name: str) -> None:
    if value < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{field_name} cannot be negative",
        )


def _resolve_expense_tax_rate(tax_mode: str | None, custom_tax_rate=None) -> Decimal:
    mode = tax_mode or "gst_15"
    if mode == "gst_15":
        return Decimal("0.15")
    if mode in ("zero_rated", "no_gst"):
        return Decimal("0")
    if mode == "custom_rate":
        rate = Decimal(str(custom_tax_rate or 0))
        if rate < 0:
            return Decimal("0")
        if rate > 1:
            return rate / Decimal("100")
        return rate
    return Decimal("0.15")


def _recalculate_expense_amounts(expense: Expense) -> None:
    """Normalise an expense's tax settings and derive gross/GST/net in place.

    Amounts are rounded to cents via ``_money``. The value that drives the
    calculation depends on the tax mode and on which amounts are non-zero:

    * ``zero_rated`` / ``no_gst``: GST is 0 and net == gross (gross falls
      back to net when gross is 0). Any GST override is ignored.
    * GST override active (``gst_amount_overridden`` truthy AND an override
      value present): GST is the override and net = gross - GST. Gross falls
      back to net + GST when gross is 0.
    * otherwise, a non-zero gross wins: net = gross / (1 + rate), then
      GST = gross - net.
    * otherwise, a non-zero net: GST = net * rate, then gross = net + GST.
    * otherwise all three amounts are 0.

    ``gst_amount`` is always overwritten. When gross is non-zero, an incoming
    ``amount_net`` is also replaced by the value derived from gross.

    Tax settings are normalised as well:
    * ``tax_mode`` defaults to ``"gst_15"``.
    * ``custom_tax_rate`` stores the resolved fraction in ``custom_rate``
      mode and ``None`` otherwise.
    * ``gst_amount_overridden`` is coerced to ``bool``.
    * ``gst_amount_override`` stores the rounded override, or ``None``.

    Raises:
        HTTPException: 400 (via ``_money`` / ``_ensure_non_negative``) for
            unparsable or negative gross, net or override amounts, and 400 if
            the override exceeds gross.
    """
    mode = expense.tax_mode or "gst_15"
    rate = _resolve_expense_tax_rate(mode, expense.custom_tax_rate)
    gross = _money(expense.amount_gross)
    net = _money(expense.amount_net)
    # None means "no override value supplied"; only round when one exists.
    override_amount = (
        _money(expense.gst_amount_override)
        if expense.gst_amount_override is not None
        else None
    )

    _ensure_non_negative(gross, "amount_gross")
    _ensure_non_negative(net, "amount_net")
    if override_amount is not None:
        _ensure_non_negative(override_amount, "gst_amount_override")

    # Normalised tax settings are written back before the override check below,
    # which may still raise; callers do not commit when that happens.
    expense.tax_mode = mode
    expense.custom_tax_rate = rate if mode == "custom_rate" else None
    expense.gst_amount_overridden = bool(expense.gst_amount_overridden)
    expense.gst_amount_override = override_amount

    # Untaxed modes: net equals gross and GST is zero; the override is ignored.
    if mode in ("zero_rated", "no_gst"):
        if not gross and net:
            gross = net
        expense.amount_gross = gross
        expense.gst_amount = Decimal("0.00")
        expense.amount_net = gross
        return

    # The override needs both the flag and a value; a flag without a value
    # falls through to the rate-based calculation.
    if expense.gst_amount_overridden and override_amount is not None:
        gst = override_amount
        if not gross and net:
            gross = _money(net + gst)
        if gst > gross:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="gst_amount_override cannot exceed amount_gross",
            )
        net = _money(gross - gst)
    # Gross is treated as tax-inclusive: back the GST out of it.
    elif gross:
        if rate:
            net = _money(gross / (Decimal("1") + rate))
            gst = _money(gross - net)
        else:
            gst = Decimal("0.00")
            net = gross
    # Only net supplied: add GST on top of it.
    elif net:
        gst = _money(net * rate)
        gross = _money(net + gst)
    else:
        gst = Decimal("0.00")
        gross = Decimal("0.00")
        net = Decimal("0.00")

    expense.amount_gross = gross
    expense.gst_amount = gst
    expense.amount_net = net


def _parse_extracted_amount(value) -> Decimal:
    """Best-effort conversion of an AI-extracted amount to ``Decimal``.

    Missing, unparsable (e.g. ``"$12.50"``) or non-finite values yield
    ``Decimal("0")``, so a reviewable draft is still produced instead of
    raising ``decimal.InvalidOperation`` or storing NaN/Infinity.
    """
    try:
        amount = Decimal(str(value or 0))
    except InvalidOperation:
        return Decimal("0")
    return amount if amount.is_finite() else Decimal("0")


def _apply_ai_extraction_to_expense(expense: Expense, extraction: dict) -> None:
    """Overwrite ``expense`` fields from an AI receipt-extraction result.

    Missing values get placeholders so a reviewable draft can always be built:
    * vendor: ``"Pending review"``
    * category: ``"uncategorized"``
    * date: today, when ``invoice_date`` is missing, non-string or not ISO
    * amounts: zero, when missing or unparseable

    When ``amount_net`` is absent it is derived as gross minus GST. Tax
    settings are reset to the default ``"gst_15"`` mode with no GST override.

    Args:
        expense: ORM row mutated in place (not flushed or committed here).
        extraction: Mapping of extracted fields. Keys read: ``vendor_name``,
            ``category_suggestion``, ``invoice_date``, ``notes``,
            ``invoice_number``, ``amount_gross``, ``gst_amount``,
            ``amount_net`` and ``currency``.

    NOTE(review): the amounts are not passed through
    ``_recalculate_expense_amounts`` here, so they are neither rounded nor
    checked for consistency.
    """
    invoice_date = extraction.get("invoice_date") or date.today().isoformat()
    try:
        expense.expense_date = date.fromisoformat(invoice_date)
    except (TypeError, ValueError):
        # TypeError: the extractor returned a non-string (e.g. a number);
        # ValueError: a string that is not an ISO date.
        expense.expense_date = date.today()

    expense.vendor_name = extraction.get("vendor_name") or "Pending review"
    expense.category = extraction.get("category_suggestion") or "uncategorized"
    # The extracted notes double as the description; otherwise fall back to "Invoice <number>".
    expense.description = extraction.get("notes") or f"Invoice {extraction.get('invoice_number') or ''}".strip()
    expense.amount_gross = _parse_extracted_amount(extraction.get("amount_gross"))
    expense.gst_amount = _parse_extracted_amount(extraction.get("gst_amount"))
    amount_net = extraction.get("amount_net")
    if amount_net in (None, ""):
        amount_net = expense.amount_gross - expense.gst_amount
    expense.amount_net = _parse_extracted_amount(amount_net)
    expense.currency = extraction.get("currency") or "NZD"
    expense.notes = extraction.get("notes") or None
    expense.tax_mode = "gst_15"
    expense.custom_tax_rate = None
    expense.gst_amount_overridden = False
    expense.gst_amount_override = None


@router.get("", response_model=PaginatedExpenses)
def list_expenses(
    current_user: Annotated[dict, Depends(get_current_user)],
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    status_filter: str = Query("", alias="status"),
    db: Session = Depends(get_db),
):
    # Return one page of expenses, newest expense_date first (id desc breaks
    # ties). ?status=<value> restricts the listing to an exact status match;
    # an empty value lists every status.
    filtered = db.query(Expense)
    if status_filter:
        filtered = filtered.filter(Expense.status == status_filter)

    # Count before paging so "total" reflects every matching row.
    matching_count = filtered.count()
    newest_first = filtered.order_by(Expense.expense_date.desc(), Expense.id.desc())
    skip = (page - 1) * page_size
    page_rows = newest_first.offset(skip).limit(page_size).all()

    # NOTE(review): for rows with a receipt, _expense_to_response reads
    # expense.receipt_asset; if that is a lazy relationship, each such row
    # costs an extra query.
    return PaginatedExpenses(
        items=list(map(_expense_to_response, page_rows)),
        total=matching_count,
        page=page,
        page_size=page_size,
    )


@router.post("", response_model=ExpenseResponse, status_code=status.HTTP_201_CREATED)
def create_expense(
    data: ExpenseCreate,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Build the row from the validated payload, normalise its amounts and tax
    # settings, then persist it. The refresh reloads the committed row
    # (including its new id) before it is serialized.
    fields = data.model_dump()
    new_expense = Expense(**fields)
    _recalculate_expense_amounts(new_expense)

    db.add(new_expense)
    db.commit()
    db.refresh(new_expense)

    return _expense_to_response(new_expense)


@router.get("/{expense_id}", response_model=ExpenseResponse)
def get_expense(
    expense_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Single expense by id; 404 when it does not exist.
    return _expense_to_response(_get_expense_or_404(db, expense_id))


@router.put("/{expense_id}", response_model=ExpenseResponse)
def update_expense(
    expense_id: int,
    data: ExpenseUpdate,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Partial update: only fields present in the request body are applied
    # (explicit nulls included), then amounts/tax settings are re-derived.
    # NOTE(review): an explicit null for a field the create schema requires
    # (e.g. vendor_name) is written through as-is. Confirm the DB constraints,
    # or whether such a request should be rejected with 400.
    target = _get_expense_or_404(db, expense_id)
    changes = data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    _recalculate_expense_amounts(target)
    db.commit()
    db.refresh(target)

    return _expense_to_response(target)


@router.delete("/{expense_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_expense(
    expense_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Hard-delete an expense; only rows still in "draft" may be removed.
    doomed = _get_expense_or_404(db, expense_id)
    if doomed.status == "draft":
        db.delete(doomed)
        db.commit()
        return
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only draft expenses can be deleted")


@router.post("/upload-receipt", status_code=status.HTTP_201_CREATED)
async def upload_receipt(
    file: UploadFile = File(...),
    current_user: Annotated[dict, Depends(get_current_user)] = None,
    db: Session = Depends(get_db),
):
    # Hand an uploaded receipt to receipt_service. A RuntimeError from the
    # service is reported to the client as 503.
    # NOTE: the "= None" on current_user is only there because it follows the
    # defaulted `file` parameter; the Depends in Annotated supplies the value.
    # NOTE(review): the whole upload is buffered in memory with no size cap here.
    raw_bytes = await file.read()
    try:
        upload_kwargs = dict(
            db=db,
            filename=file.filename or "receipt.bin",
            content_type=file.content_type or "application/octet-stream",
            file_bytes=raw_bytes,
            upload_user=_current_username(current_user),
            source_channel="web",
        )
        # process_receipt_upload preserves the existing normalize_receipt_upload pipeline.
        return await receipt_service.process_receipt_upload(**upload_kwargs)
    except RuntimeError as exc:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(exc)) from exc


@router.post("/{expense_id}/confirm")
def confirm_expense(
    expense_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Set the status to "confirmed" unconditionally.
    # NOTE(review): there is no transition check, so a "void" expense can be
    # re-confirmed here. Confirm this is intended.
    _get_expense_or_404(db, expense_id).status = "confirmed"
    db.commit()
    return dict(ok=True)


@router.post("/{expense_id}/void")
def void_expense(
    expense_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Soft-cancel: the row is kept but its status becomes "void", whatever
    # status it had before.
    _get_expense_or_404(db, expense_id).status = "void"
    db.commit()
    return dict(ok=True)


def _serve_receipt_file(db: Session, receipt_id: int, field_name: str) -> FileResponse:
    """Return a stored rendition of a receipt as a file response.

    Args:
        db: Active SQLAlchemy session.
        receipt_id: Primary key of the ``ReceiptAsset``.
        field_name: ``ReceiptAsset`` attribute holding the file path.
            * ``"storage_path_preview"`` is served as ``image/webp`` under the
              file's own name.
            * Any other field (``"storage_path_original"`` here) is served with
              the receipt's ``mime_type`` and ``original_filename``.

    Raises:
        HTTPException: 404 if the receipt row does not exist, if the requested
            path is empty, or if no file exists at that path. Without the path
            checks these cases fail inside ``Path``/``FileResponse`` and reach
            the client as a server error.
    """
    receipt = db.get(ReceiptAsset, receipt_id)
    if not receipt:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Receipt not found")

    file_path = getattr(receipt, field_name)
    # E.g. no preview generated yet, or the file was removed from storage.
    if not file_path or not Path(file_path).is_file():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Receipt file not found")

    is_preview = field_name == "storage_path_preview"
    media_type = "image/webp" if is_preview else receipt.mime_type
    filename = Path(file_path).name if is_preview else receipt.original_filename
    return FileResponse(path=file_path, media_type=media_type, filename=filename)


@router.get("/receipts/{receipt_id}/preview")
def get_receipt_preview(
    receipt_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Serve the WebP preview rendition stored for this receipt.
    return _serve_receipt_file(db=db, receipt_id=receipt_id, field_name="storage_path_preview")


@router.get("/receipts/{receipt_id}/original")
def get_receipt_original(
    receipt_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    # Serve the originally uploaded file under its recorded name and MIME type.
    return _serve_receipt_file(db=db, receipt_id=receipt_id, field_name="storage_path_original")
