from datetime import date, datetime
from decimal import Decimal
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.core.models import PayrollRecord
from app.core.security import get_current_user
from app.services.payroll_service import calculate_payroll

# All routes below are mounted under /api/payroll and grouped under the
# "payroll" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api/payroll", tags=["payroll"])


class PayrollRecordBase(BaseModel):
    """Fields shared by the payroll create payload and the response model.

    Every monetary amount is a ``Decimal`` that defaults to zero.
    ``calculation_mode`` and ``status`` are plain strings; this model does not
    restrict them to a fixed set of values. Within this module, ``status`` starts
    as ``"draft"`` and the confirm/void routes set it to ``"confirmed"`` /
    ``"void"``.
    """

    # NOTE(review): no check that period_start <= period_end <= pay_date (or any
    # ordering) is enforced here — confirm whether it should be.
    employee_id: int
    period_start: date
    period_end: date
    pay_date: date
    # Monetary components. The PAYE / KiwiSaver / ACC names presumably refer to
    # New Zealand payroll deductions and contributions — confirm with the domain owner.
    gross_pay: Decimal = Decimal("0")
    bonus_amount: Decimal = Decimal("0")
    allowance_amount: Decimal = Decimal("0")
    deduction_amount: Decimal = Decimal("0")
    paye_amount: Decimal = Decimal("0")
    kiwisaver_employee: Decimal = Decimal("0")
    kiwisaver_employer: Decimal = Decimal("0")
    acc_amount: Decimal = Decimal("0")
    other_tax_amount: Decimal = Decimal("0")
    # Defaults to 0; the create route treats a falsy value as "compute it".
    net_pay: Decimal = Decimal("0")
    calculation_mode: str = "manual"
    status: str = "draft"
    notes: str | None = None


class PayrollRecordCreate(PayrollRecordBase):
    """Request body for creating a payroll record; same fields as ``PayrollRecordBase``."""
    pass


class PayrollRecordUpdate(BaseModel):
    """Request body for partially updating a payroll record.

    Every field is optional and defaults to ``None``. ``employee_id`` is not
    declared, so it cannot be changed through this model. Consumers can tell
    "omitted" from "explicitly null" via ``model_dump(exclude_unset=True)``.
    """

    # NOTE(review): every field accepts an explicit null, including ones that
    # PayrollRecordBase declares as non-optional (e.g. gross_pay, pay_date); the
    # consumer of this model has to decide what a null means for those.
    period_start: date | None = None
    period_end: date | None = None
    pay_date: date | None = None
    gross_pay: Decimal | None = None
    bonus_amount: Decimal | None = None
    allowance_amount: Decimal | None = None
    deduction_amount: Decimal | None = None
    paye_amount: Decimal | None = None
    kiwisaver_employee: Decimal | None = None
    kiwisaver_employer: Decimal | None = None
    acc_amount: Decimal | None = None
    other_tax_amount: Decimal | None = None
    net_pay: Decimal | None = None
    calculation_mode: str | None = None
    status: str | None = None
    notes: str | None = None


class PayrollRecordResponse(PayrollRecordBase):
    """Payroll record as returned by the API.

    Adds the primary key and timestamps to ``PayrollRecordBase``.
    ``from_attributes=True`` lets ``model_validate`` read values from a
    ``PayrollRecord`` ORM instance instead of requiring a dict.
    """

    id: int
    created_at: datetime | None = None
    updated_at: datetime | None = None

    # NOTE(review): arbitrary_types_allowed does not appear necessary for the
    # visible field types (int, str, date, datetime, Decimal), which pydantic
    # supports natively — consider removing it.
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)


class PaginatedPayrollRecords(BaseModel):
    """One page of payroll records plus the paging parameters used to fetch it."""

    items: list[PayrollRecordResponse]
    # Number of records matching the filters across all pages, not len(items).
    total: int
    # 1-based page number.
    page: int
    page_size: int


class PayrollCalculationInput(BaseModel):
    """Amounts accepted by the calculation-preview route.

    Mirrors the monetary fields of ``PayrollRecordBase`` except ``net_pay``.
    Every amount defaults to zero. The whole model is dumped to a dict and
    handed to ``calculate_payroll``.
    """

    gross_pay: Decimal = Decimal("0")
    bonus_amount: Decimal = Decimal("0")
    allowance_amount: Decimal = Decimal("0")
    deduction_amount: Decimal = Decimal("0")
    paye_amount: Decimal = Decimal("0")
    kiwisaver_employee: Decimal = Decimal("0")
    kiwisaver_employer: Decimal = Decimal("0")
    acc_amount: Decimal = Decimal("0")
    other_tax_amount: Decimal = Decimal("0")


def _payroll_or_404(db: Session, record_id: int) -> PayrollRecord:
    """Load a payroll record by primary key, aborting the request with 404 if absent.

    Args:
        db: Active SQLAlchemy session.
        record_id: Primary key of the ``PayrollRecord`` to load.

    Returns:
        The matching ``PayrollRecord`` instance.

    Raises:
        HTTPException: 404 when no record has this primary key.
    """
    found = db.get(PayrollRecord, record_id)
    if found:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Payroll record not found")


@router.get("", response_model=PaginatedPayrollRecords)
def list_payroll(
    current_user: Annotated[dict, Depends(get_current_user)],
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    employee_id: int | None = Query(None),
    status_filter: str = Query("", alias="status"),
    db: Session = Depends(get_db),
):
    """List payroll records one page at a time, newest pay date first.

    Optional filters: ``employee_id`` (exact match) and ``status`` (exact match;
    an empty string means no status filter). The ``status`` query key is bound
    to ``status_filter``, which avoids shadowing the imported ``status`` module.
    ``total`` counts every matching record, not only the returned page.
    """
    query = db.query(PayrollRecord)
    if employee_id is not None:
        query = query.filter(PayrollRecord.employee_id == employee_id)
    if status_filter:
        query = query.filter(PayrollRecord.status == status_filter)

    # Count before ordering/paging so the total covers all matches.
    matching_count = query.count()

    skip = (page - 1) * page_size
    # id is a tie-breaker so records sharing a pay_date page deterministically.
    newest_first = query.order_by(PayrollRecord.pay_date.desc(), PayrollRecord.id.desc())
    page_rows = newest_first.offset(skip).limit(page_size).all()

    serialized = [PayrollRecordResponse.model_validate(row) for row in page_rows]
    return PaginatedPayrollRecords(
        items=serialized,
        total=matching_count,
        page=page,
        page_size=page_size,
    )


@router.post("", response_model=PayrollRecordResponse, status_code=status.HTTP_201_CREATED)
def create_payroll_record(
    data: PayrollRecordCreate,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Create a payroll record and return it (HTTP 201).

    If the submitted ``net_pay`` is falsy, it is replaced with the ``net_pay``
    entry returned by ``calculate_payroll`` for the full submitted payload.
    Otherwise the submitted value is stored unchanged.
    """
    record_fields = data.model_dump()
    # Zero doubles as "not supplied": an explicit 0 (or 0.00) is also recomputed.
    supplied_net = record_fields.get("net_pay")
    if not supplied_net:
        record_fields["net_pay"] = calculate_payroll(record_fields)["net_pay"]

    new_record = PayrollRecord(**record_fields)
    db.add(new_record)
    db.commit()
    # Reload so server-generated columns (e.g. id) are populated for the response.
    db.refresh(new_record)
    return PayrollRecordResponse.model_validate(new_record)


@router.get("/{record_id}", response_model=PayrollRecordResponse)
def get_payroll_record(
    record_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Return one payroll record by id; responds 404 if it does not exist."""
    record = _payroll_or_404(db, record_id)
    return PayrollRecordResponse.model_validate(record)


@router.put("/{record_id}", response_model=PayrollRecordResponse)
def update_payroll_record(
    record_id: int,
    data: PayrollRecordUpdate,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Partially update a payroll record and return the stored result.

    Only fields present in the request body are written (``exclude_unset``);
    omitted fields keep their stored values. ``notes`` may be cleared by
    sending ``null``. Every other field is non-optional in
    ``PayrollRecordResponse``, so an explicit ``null`` for it is rejected
    before anything is written.

    Raises:
        HTTPException: 404 if the record does not exist; 422 if a field other
            than ``notes`` is sent as ``null``.
    """
    record = _payroll_or_404(db, record_id)
    changes = data.model_dump(exclude_unset=True)

    # Writing None into a field the response model requires would, depending on
    # the column's nullability (not visible here), either fail at commit or
    # persist a row that PayrollRecordResponse can no longer serialize, which
    # breaks this response and every later read of the record. Reject up front.
    nullable_fields = {"notes"}
    null_fields = sorted(
        field
        for field, value in changes.items()
        if value is None and field not in nullable_fields
    )
    if null_fields:
        raise HTTPException(
            status_code=422,
            detail=f"Fields cannot be null: {', '.join(null_fields)}",
        )

    for field, value in changes.items():
        setattr(record, field, value)
    db.commit()
    db.refresh(record)
    return PayrollRecordResponse.model_validate(record)


@router.post("/calculate")
def calculate_payroll_preview(data: PayrollCalculationInput):
    """Return ``calculate_payroll``'s result for the submitted amounts.

    This route takes no database session and writes nothing itself.
    """
    # NOTE(review): unlike every other route in this router, this one has no
    # get_current_user dependency, so it is reachable without authentication.
    # Confirm that is intentional.
    amounts = data.model_dump()
    preview = calculate_payroll(amounts)
    return preview


@router.post("/{record_id}/confirm")
def confirm_payroll_record(
    record_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Set a payroll record's status to ``"confirmed"`` and commit.

    The status is overwritten regardless of its current value (no transition
    check). Responds ``{"ok": True}``, or 404 if the record does not exist.
    """
    _payroll_or_404(db, record_id).status = "confirmed"
    db.commit()
    return {"ok": True}


@router.post("/{record_id}/void")
def void_payroll_record(
    record_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Set a payroll record's status to ``"void"`` and commit.

    The record is kept (not deleted) and its status is overwritten regardless
    of its current value. Responds ``{"ok": True}``, or 404 if it does not exist.
    """
    target = _payroll_or_404(db, record_id)
    target.status = "void"
    db.commit()
    return {"ok": True}
