from __future__ import annotations

import json
from datetime import UTC, datetime
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from typing import Any


# Standard GST rate as a fraction (15%); rate_for_mode() returns it for "gst_15".
GST_15_RATE = Decimal("0.15")
# Shared Decimal zero used for comparisons and as the fallback amount/rate.
ZERO = Decimal("0")
# Quantum used by money(): amounts are rounded to 2 decimal places (cents).
MONEY_QUANT = Decimal("0.01")


def _decimal(value: Any) -> Decimal:
    """Coerce *value* to a finite Decimal, falling back to ZERO.

    ``None``, values whose ``str()`` is not a valid Decimal literal, and
    non-finite values (NaN, sNaN, +/-Infinity) all map to ZERO. Non-finite
    values are rejected because ``money()`` cannot quantize an infinity and
    ordering comparisons against NaN (``x <= ZERO`` etc.) raise
    ``InvalidOperation``. Letting them through would crash callers instead of
    treating the value as "no amount".

    Args:
        value: Anything; Decimal instances are used as-is, everything else is
            parsed via ``Decimal(str(value))`` (so floats keep their repr).

    Returns:
        A finite Decimal.
    """
    if value is None:
        return ZERO
    if isinstance(value, Decimal):
        candidate = value
    else:
        try:
            candidate = Decimal(str(value))
        except (InvalidOperation, ValueError):
            return ZERO
    return candidate if candidate.is_finite() else ZERO


def money(value: Any) -> Decimal:
    """Coerce *value* via ``_decimal`` and round it half-up to whole cents."""
    amount = _decimal(value)
    rounded = amount.quantize(MONEY_QUANT, rounding=ROUND_HALF_UP)
    return rounded


def rate_for_mode(tax_mode: str, custom_tax_rate: Decimal | None = None) -> Decimal:
    """Resolve the effective GST rate, as a fraction, for a tax mode.

    Args:
        tax_mode: ``"gst_15"``, ``"zero_rated"``, ``"no_gst"`` or
            ``"custom_rate"``. Objects exposing a ``.value`` attribute (e.g. Enum
            members) are unwrapped first. An Enum member whose value is
            ``"gst_15"`` therefore resolves like the plain string, instead of
            silently falling through to a zero rate.
        custom_tax_rate: Only consulted for ``"custom_rate"``. Values above 1 are
            read as percentages (15 -> 0.15). Values in (0, 1] are used as-is, so
            exactly 1 means 100%.

    Returns:
        The rate as a Decimal fraction. ZERO is returned for zero-rated/no-GST
        modes, unknown modes, and non-positive or unparseable custom rates.
    """
    # Plain strings have no ``value`` attribute and pass through unchanged.
    mode = getattr(tax_mode, "value", tax_mode)
    if mode == "gst_15":
        return GST_15_RATE
    if mode in {"zero_rated", "no_gst"}:
        return ZERO
    if mode == "custom_rate":
        rate = _decimal(custom_tax_rate)
        if rate <= ZERO:
            return ZERO
        if rate > Decimal("1"):
            return rate / Decimal("100")
        return rate
    return ZERO


def calculate_gst_from_gross(gross_amount: Decimal, tax_rate: Decimal) -> Decimal:
    """Return the GST contained in a GST-inclusive amount, rounded to cents.

    Computes ``gross * rate / (1 + rate)``; for example 115.00 at 0.15 gives
    15.00. A non-positive gross amount or rate yields 0.00.
    """
    gross = _decimal(gross_amount)
    rate = _decimal(tax_rate)
    if gross > ZERO and rate > ZERO:
        return money(gross * rate / (Decimal("1") + rate))
    return money(ZERO)


def allocate_payment_to_tax_groups(
    payment_id: Any,
    invoice_id: Any,
    received_date: Any,
    payment_amount: Decimal,
    tax_groups: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Split one payment across an invoice's tax groups, pro rata by gross amount.

    Groups with a non-positive ``gross_amount`` are ignored. The payment is
    capped at the invoice's total gross, so any overpayment is not reported.
    Shares are rounded on a *cumulative* basis: group i receives
    ``round(payment * running_gross_i / total) - round(payment * running_gross_{i-1} / total)``.
    Rounding each share independently and dumping the remainder on the last
    group could make that remainder negative (e.g. groups 1, 1, 1, 0.01 with a
    0.02 payment). The cumulative scheme keeps every row non-negative and makes
    the rows sum exactly to the capped payment.

    Args:
        payment_id: Stored as the row's ``source_id``.
        invoice_id: Recorded in the row metadata.
        received_date: Stored as the row's ``source_date``.
        payment_amount: Amount received; rounded to cents before use.
        tax_groups: Dicts with ``gross_amount`` and optionally ``tax_mode``
            (default ``"gst_15"``), ``custom_tax_rate`` and ``tax_rate``.
            ``custom_tax_rate`` wins when it is not None; otherwise ``tax_rate``
            is used.

    Returns:
        One ``"payment"`` row per positive group, in input order. Returns an
        empty list when the payment or the groups' total gross is not positive.
    """
    groups: list[dict[str, Any]] = []
    for group in tax_groups:
        gross = _decimal(group.get("gross_amount"))
        if gross > ZERO:
            groups.append({**group, "gross_amount": gross})

    requested_payment = money(payment_amount)
    total_gross = sum((group["gross_amount"] for group in groups), ZERO)
    if total_gross <= ZERO or requested_payment <= ZERO:
        return []
    invoice_gross = money(total_gross)
    payment = min(requested_payment, invoice_gross)

    rows: list[dict[str, Any]] = []
    running_gross = ZERO
    allocated_so_far = ZERO
    last_index = len(groups) - 1

    for index, group in enumerate(groups):
        running_gross += group["gross_amount"]
        # The final cumulative target is the payment itself, so the rows always
        # reconcile exactly; earlier targets are rounded proportional shares.
        if index == last_index:
            cumulative_target = payment
        else:
            cumulative_target = money(payment * running_gross / total_gross)
        allocated_gross = money(cumulative_target - allocated_so_far)
        allocated_so_far = cumulative_target

        tax_mode = group.get("tax_mode") or "gst_15"
        custom_rate = group.get("custom_tax_rate")
        if custom_rate is None:
            custom_rate = group.get("tax_rate")
        tax_rate = rate_for_mode(tax_mode, custom_rate)
        gst_amount = calculate_gst_from_gross(allocated_gross, tax_rate)
        net_amount = money(allocated_gross - gst_amount)

        rows.append(
            {
                "row_type": "payment",
                "source_id": payment_id,
                "source_date": received_date,
                "customer_or_vendor": None,
                "description": "",
                "gross_amount": allocated_gross,
                "net_amount": net_amount,
                "gst_amount": gst_amount,
                "tax_mode": tax_mode,
                "tax_rate": tax_rate,
                "metadata_json": json.dumps(
                    {
                        "invoice_id": invoice_id,
                        "invoice_gross_amount": str(invoice_gross),
                        "payment_amount": str(payment),
                        "allocated_from_tax_group_gross": str(group["gross_amount"]),
                    },
                    sort_keys=True,
                ),
            }
        )

    return rows


def calculate_expense_gst(
    gross_amount: Decimal,
    tax_mode: str,
    custom_tax_rate: Decimal | None,
    gst_claimable: bool,
    override_amount: Decimal | None = None,
) -> Decimal:
    """Return the claimable GST (rounded to cents) for one expense.

    Args:
        gross_amount: GST-inclusive expense amount.
        tax_mode: Passed to ``rate_for_mode`` together with ``custom_tax_rate``.
        custom_tax_rate: Rate used when ``tax_mode`` is ``"custom_rate"``.
        gst_claimable: When false the result is always 0.00.
        override_amount: Manually entered GST, clamped into ``[0, gross]``.
            When None, GST is derived from gross and rate.

    Returns:
        0.00 when the expense is not claimable, its rate is zero, or its gross
        amount is not positive. Otherwise returns the clamped override or the
        calculated GST.
    """
    gross = money(gross_amount)
    if not gst_claimable:
        return money(ZERO)

    tax_rate = rate_for_mode(tax_mode, custom_tax_rate)
    # A non-positive gross (e.g. a credit) carries no claimable GST. This also
    # stops the override clamp below from returning a negative gross as "GST".
    if tax_rate <= ZERO or gross <= ZERO:
        return money(ZERO)

    if override_amount is not None:
        override = money(override_amount)
        if override < ZERO:
            return money(ZERO)
        if override > gross:
            return gross
        return override

    return calculate_gst_from_gross(gross, tax_rate)


def _status_text(value: Any) -> str:
    if value is None:
        return ""
    return str(getattr(value, "value", value)).lower()


def _customer_name(invoice: Any) -> str | None:
    customer = getattr(invoice, "customer", None)
    if not customer:
        return None
    return getattr(customer, "company_name", None) or getattr(customer, "name", None)


def _group_invoice_items(items: list[Any]) -> list[dict[str, Any]]:
    """Bucket invoice line items by (tax mode, effective rate) and total their gross.

    Each line's gross is its ``total`` when truthy, otherwise
    ``subtotal + tax_amount``, rounded to cents. Lines whose gross is zero or
    negative are skipped. NOTE(review): credit/discount lines are therefore
    left out of the allocation base; confirm that is intended.

    Returns:
        Buckets in first-seen order. Each is a dict with ``tax_mode``,
        ``tax_rate`` (effective fractional rate) and ``gross_amount``. It also
        has ``custom_tax_rate`` when the bucket's first line had one.
    """
    buckets: dict[tuple[str, str], dict[str, Any]] = {}

    for line in items:
        mode = getattr(line, "tax_mode", None) or "gst_15"
        line_custom_rate = getattr(line, "custom_tax_rate", None)
        line_stored_rate = getattr(line, "tax_rate", None)
        resolved_rate = rate_for_mode(
            mode,
            line_stored_rate if line_custom_rate is None else line_custom_rate,
        )

        line_total = getattr(line, "total", None)
        if not line_total:
            line_total = _decimal(getattr(line, "subtotal", None)) + _decimal(getattr(line, "tax_amount", None))
        line_gross = money(line_total)
        if line_gross <= ZERO:
            continue

        bucket_key = (mode, str(resolved_rate))
        bucket = buckets.get(bucket_key)
        if bucket is None:
            bucket = {"tax_mode": mode, "tax_rate": resolved_rate, "gross_amount": ZERO}
            if line_custom_rate is not None:
                bucket["custom_tax_rate"] = line_custom_rate
            buckets[bucket_key] = bucket
        bucket["gross_amount"] = money(bucket["gross_amount"] + line_gross)

    return list(buckets.values())


def _with_metadata(row: dict[str, Any], extra: dict[str, Any]) -> dict[str, Any]:
    try:
        metadata = json.loads(row.get("metadata_json") or "{}")
    except (TypeError, json.JSONDecodeError):
        metadata = {}
    metadata.update(extra)
    return {
        **row,
        "metadata_json": json.dumps(metadata, sort_keys=True),
    }


def _date_iso(value: Any) -> str | None:
    return value.isoformat() if hasattr(value, "isoformat") else None


def _payment_received_date(payment: Any) -> Any:
    if getattr(payment, "received_date", None):
        return payment.received_date
    paid_at = getattr(payment, "paid_at", None)
    if hasattr(paid_at, "date"):
        return paid_at.date()
    return paid_at


def build_gst_return_draft(db: Any, period_start: Any, period_end: Any) -> dict[str, Any]:
    """Assemble, without persisting, a payments-basis GST return for a period.

    Three sources are read from ``db``:

    * Payments whose effective date lies within ``period_start``..``period_end``
      inclusive. The effective date is ``received_date``, falling back to the
      date part of ``paid_at``. Payments without an invoice, or whose invoice
      status is void/cancelled/canceled, are skipped. Each remaining payment is
      split across the invoice's tax groups by ``allocate_payment_to_tax_groups``.
    * Expenses with status ``"confirmed"`` and ``expense_date`` in the same
      inclusive range.
    * GST adjustments whose target period equals ``(period_start, period_end)``
      exactly.

    Returns:
        A dict containing:

        * a ``"totals"`` sub-dict;
        * the combined ``"rows"`` list (payment rows, then expense rows, then
          adjustment rows);
        * every total repeated at the top level;
        * the per-source ``payment_rows``, ``expense_rows`` and
          ``adjustment_rows`` lists.
    """
    from sqlalchemy import func
    from sqlalchemy.orm import joinedload

    from app.core.models import Expense, GstAdjustment, Invoice, Payment

    payment_rows: list[dict[str, Any]] = []
    # SQL-side equivalent of _payment_received_date(): received_date, else the
    # date part of paid_at. Used for both the period filter and the ordering.
    payment_period_date = func.coalesce(Payment.received_date, func.date(Payment.paid_at))
    payments = (
        db.query(Payment)
        .options(
            # Eager-load the invoice with its customer and line items, all of
            # which the loop below reads for every payment.
            joinedload(Payment.invoice).joinedload(Invoice.customer),
            joinedload(Payment.invoice).joinedload(Invoice.items),
        )
        .filter(payment_period_date >= period_start)
        .filter(payment_period_date <= period_end)
        .order_by(payment_period_date.asc(), Payment.id.asc())
        .all()
    )

    for payment in payments:
        invoice = getattr(payment, "invoice", None)
        # Payments with no invoice, or against voided/cancelled invoices,
        # contribute nothing to the return.
        if not invoice or _status_text(getattr(invoice, "status", None)) in {"void", "cancelled", "canceled"}:
            continue

        tax_groups = _group_invoice_items(list(getattr(invoice, "items", []) or []))
        effective_received_date = _payment_received_date(payment)
        rows = allocate_payment_to_tax_groups(
            payment_id=payment.id,
            invoice_id=getattr(invoice, "id", None),
            received_date=effective_received_date,
            payment_amount=_decimal(payment.amount),
            tax_groups=tax_groups,
        )
        invoice_number = getattr(invoice, "invoice_number", None)
        # Row description is "<invoice_number>: <item descriptions, ...>", or
        # whichever of the two parts is available on its own.
        descriptions = [
            getattr(item, "description", "")
            for item in (getattr(invoice, "items", []) or [])
            if getattr(item, "description", "")
        ]
        description = invoice_number or ""
        if descriptions:
            description = f"{description}: {', '.join(descriptions)}" if description else ", ".join(descriptions)

        for row in rows:
            payment_rows.append(
                _with_metadata(
                    {
                        **row,
                        "customer_or_vendor": _customer_name(invoice),
                        "description": description,
                    },
                    {
                        "invoice_number": invoice_number,
                        # NOTE(review): str() yields "" when the attribute is
                        # absent but "None" when it is None; confirm consumers
                        # expect that.
                        "invoice_date": str(getattr(invoice, "invoice_date", "")),
                    },
                )
            )

    expense_rows: list[dict[str, Any]] = []
    expenses = (
        db.query(Expense)
        .filter(Expense.expense_date >= period_start)
        .filter(Expense.expense_date <= period_end)
        .filter(Expense.status == "confirmed")
        .order_by(Expense.expense_date.asc(), Expense.id.asc())
        .all()
    )

    for expense in expenses:
        gross_amount = money(getattr(expense, "amount_gross", None))
        tax_mode = getattr(expense, "tax_mode", None) or "gst_15"
        custom_tax_rate = getattr(expense, "custom_tax_rate", None)
        # A stored override is only honoured when the expense is flagged as
        # overridden; otherwise GST is derived from gross amount and rate.
        override_amount = (
            getattr(expense, "gst_amount_override", None)
            if getattr(expense, "gst_amount_overridden", False)
            else None
        )
        gst_amount = calculate_expense_gst(
            gross_amount,
            tax_mode,
            custom_tax_rate,
            bool(getattr(expense, "gst_claimable", True)),
            override_amount,
        )
        category = getattr(expense, "category", None) or ""
        description = getattr(expense, "description", None) or category
        expense_rows.append(
            {
                "row_type": "expense",
                "source_id": expense.id,
                "source_date": expense.expense_date,
                "customer_or_vendor": getattr(expense, "vendor_name", None),
                "description": description,
                "gross_amount": gross_amount,
                "net_amount": money(gross_amount - gst_amount),
                "gst_amount": gst_amount,
                "tax_mode": tax_mode,
                # Records the mode's rate even for non-claimable expenses, whose
                # gst_amount is always 0.
                "tax_rate": rate_for_mode(tax_mode, custom_tax_rate),
                "metadata_json": json.dumps(
                    {
                        "category": category,
                        "gst_claimable": bool(getattr(expense, "gst_claimable", True)),
                        "gst_amount_overridden": bool(getattr(expense, "gst_amount_overridden", False)),
                    },
                    sort_keys=True,
                ),
            }
        )

    adjustment_rows: list[dict[str, Any]] = []
    debit_adjustments = ZERO
    credit_adjustments = ZERO
    adjustments = (
        db.query(GstAdjustment)
        .filter(GstAdjustment.target_period_start == period_start)
        .filter(GstAdjustment.target_period_end == period_end)
        .order_by(GstAdjustment.id.asc())
        .all()
    )
    for adjustment in adjustments:
        amount = money(getattr(adjustment, "amount", None))
        # ``direction`` takes precedence over ``adjustment_type``. Values in the
        # set below are credits (added to GST on purchases); any other value,
        # including an empty one, is treated as a debit (added to GST collected).
        direction = _status_text(getattr(adjustment, "direction", None) or getattr(adjustment, "adjustment_type", None))
        is_credit = direction in {"credit", "input", "purchase", "purchases", "expense", "expenses"}
        if is_credit:
            credit_adjustments = money(credit_adjustments + amount)
        else:
            debit_adjustments = money(debit_adjustments + amount)
        adjustment_date = getattr(adjustment, "adjustment_date", None)
        # Use the date part of the adjustment's own timestamp when present,
        # otherwise date the row at the end of the period.
        source_date = adjustment_date.date() if hasattr(adjustment_date, "date") else period_end
        adjustment_rows.append(
            {
                "row_type": "adjustment",
                "source_id": adjustment.id,
                "source_date": source_date,
                "customer_or_vendor": None,
                "description": getattr(adjustment, "reason", None) or getattr(adjustment, "adjustment_type", ""),
                "adjustment_type": direction,
                "reason": getattr(adjustment, "reason", None),
                "gross_amount": amount,
                "net_amount": money(ZERO),
                "gst_amount": amount,
                "tax_mode": "adjustment",
                "tax_rate": ZERO,
                "metadata_json": json.dumps(
                    {
                        "adjustment_type": getattr(adjustment, "adjustment_type", None),
                        "direction": getattr(adjustment, "direction", None),
                        "source_type": getattr(adjustment, "source_type", None),
                        "source_id": getattr(adjustment, "source_id", None),
                        "linked_invoice_id": getattr(adjustment, "linked_invoice_id", None),
                        "linked_payment_id": getattr(adjustment, "linked_payment_id", None),
                        "linked_expense_id": getattr(adjustment, "linked_expense_id", None),
                        "source_period_start": _date_iso(getattr(adjustment, "source_period_start", None)),
                        "source_period_end": _date_iso(getattr(adjustment, "source_period_end", None)),
                        "target_period_start": _date_iso(getattr(adjustment, "target_period_start", None)),
                        "target_period_end": _date_iso(getattr(adjustment, "target_period_end", None)),
                    },
                    sort_keys=True,
                ),
            }
        )

    # Aggregate totals. gst_payable = (output GST + debits) - (input GST +
    # credits), so it can be negative.
    total_sales_and_income = money(sum(row["gross_amount"] for row in payment_rows))
    gst_output = money(sum(row["gst_amount"] for row in payment_rows))
    total_purchases_expenses = money(sum(row["gross_amount"] for row in expense_rows))
    gst_input = money(sum(row["gst_amount"] for row in expense_rows))
    total_gst_collected = money(gst_output + debit_adjustments)
    total_gst_purchases = money(gst_input + credit_adjustments)
    gst_payable = money(total_gst_collected - total_gst_purchases)
    totals = {
        "total_sales_and_income": total_sales_and_income,
        "gst_output": gst_output,
        "debit_adjustments": debit_adjustments,
        "total_gst_collected": total_gst_collected,
        "total_purchases_expenses": total_purchases_expenses,
        "gst_input": gst_input,
        "credit_adjustments": credit_adjustments,
        "total_gst_purchases": total_gst_purchases,
        "gst_payable": gst_payable,
    }
    rows = payment_rows + expense_rows + adjustment_rows

    return {
        "totals": totals,
        "rows": rows,
        "total_sales_and_income": total_sales_and_income,
        "gst_output": gst_output,
        "debit_adjustments": debit_adjustments,
        "total_gst_collected": total_gst_collected,
        "total_purchases_expenses": total_purchases_expenses,
        "gst_input": gst_input,
        "credit_adjustments": credit_adjustments,
        "total_gst_purchases": total_gst_purchases,
        "gst_payable": gst_payable,
        "payment_rows": payment_rows,
        "expense_rows": expense_rows,
        "adjustment_rows": adjustment_rows,
    }


def create_gst_adjustment(
    db: Any,
    *,
    adjustment_type: str,
    reason: str,
    amount: Any,
    target_period_start: Any,
    target_period_end: Any,
    source_period_start: Any | None = None,
    source_period_end: Any | None = None,
    linked_invoice_id: int | None = None,
    linked_payment_id: int | None = None,
    linked_expense_id: int | None = None,
) -> Any:
    """Validate and persist a manual GST adjustment, returning the refreshed row.

    The adjustment is reported in the *target* period. The optional *source*
    period records the period being corrected and must be given as a pair.

    Raises:
        ValueError: Raised in any of these cases:
            * ``adjustment_type`` is not debit/credit;
            * the reason is empty;
            * the amount is not positive;
            * the target period is missing or inverted;
            * the source period is unpaired or inverted;
            * the target period's GST return is already filed (filed returns
              cannot be updated, so the adjustment would never be reported);
            * the linked payment belongs to a different invoice than the
              linked invoice.
        LookupError: A linked invoice, payment or expense id does not exist.
    """
    from app.core.models import Expense, GstAdjustment, Invoice, Payment

    direction = _status_text(adjustment_type)
    if direction not in {"debit", "credit"}:
        raise ValueError("adjustment_type must be debit or credit")

    reason_text = (reason or "").strip()
    if not reason_text:
        raise ValueError("reason is required")

    adjustment_amount = money(amount)
    if adjustment_amount <= ZERO:
        raise ValueError("amount must be greater than 0")

    # build_gst_return_draft only picks up adjustments whose target period
    # matches exactly, so a missing or inverted target period is never reported.
    if target_period_start is None or target_period_end is None:
        raise ValueError("target_period_start and target_period_end are required")
    if target_period_end < target_period_start:
        raise ValueError("target_period_end must be on or after target_period_start")

    if (source_period_start is None) != (source_period_end is None):
        raise ValueError("source_period_start and source_period_end must be provided together")
    if source_period_start is not None and source_period_end < source_period_start:
        raise ValueError("source_period_end must be on or after source_period_start")

    target_return = _existing_gst_return(db, target_period_start, target_period_end)
    if target_return and target_return.status == "filed":
        raise ValueError("target period already has a filed GST return; record the adjustment against a later period")

    linked_invoice = None
    linked_payment = None
    if linked_invoice_id is not None:
        linked_invoice = db.query(Invoice).filter(Invoice.id == linked_invoice_id).first()
        if not linked_invoice:
            raise LookupError("linked invoice not found")
    if linked_payment_id is not None:
        linked_payment = db.query(Payment).filter(Payment.id == linked_payment_id).first()
        if not linked_payment:
            raise LookupError("linked payment not found")
    if linked_expense_id is not None:
        if not db.query(Expense).filter(Expense.id == linked_expense_id).first():
            raise LookupError("linked expense not found")
    if linked_invoice is not None and linked_payment is not None and linked_payment.invoice_id != linked_invoice.id:
        raise ValueError("linked payment does not belong to linked invoice")

    adjustment = GstAdjustment(
        adjustment_type=direction,
        direction=direction,
        reason=reason_text,
        amount=adjustment_amount,
        target_period_start=target_period_start,
        target_period_end=target_period_end,
        source_period_start=source_period_start,
        source_period_end=source_period_end,
        linked_invoice_id=linked_invoice_id,
        linked_payment_id=linked_payment_id,
        linked_expense_id=linked_expense_id,
        adjustment_date=_utc_now(),
    )
    db.add(adjustment)
    db.commit()
    db.refresh(adjustment)
    return adjustment


def delete_gst_adjustment(db: Any, adjustment_id: int) -> dict[str, Any]:
    """Delete a GST adjustment by id, commit, and return a summary of it.

    Raises:
        LookupError: if no adjustment has that id.
    """
    from app.core.models import GstAdjustment

    target = db.query(GstAdjustment).filter(GstAdjustment.id == adjustment_id).first()
    if not target:
        raise LookupError("gst adjustment not found")

    summary_fields = ("id", "adjustment_type", "target_period_start", "target_period_end")
    summary = {field: getattr(target, field) for field in summary_fields}
    db.delete(target)
    db.commit()
    return summary


def _existing_gst_return(db: Any, period_start: Any, period_end: Any) -> Any | None:
    """Return the first GST return for exactly this period, or None.

    Only returns in draft, prepared, locked or filed status are considered.
    """
    from app.core.models import GstReturn

    tracked_statuses = ["draft", "prepared", "locked", "filed"]
    query = db.query(GstReturn)
    query = query.filter(GstReturn.period_start == period_start)
    query = query.filter(GstReturn.period_end == period_end)
    query = query.filter(GstReturn.status.in_(tracked_statuses))
    return query.first()


def _taxable_payment_total(rows: list[dict[str, Any]]) -> Decimal:
    """Sum ``gross_amount`` over rows with a positive ``tax_rate``, rounded to cents."""
    total = 0
    for row in rows:
        if _decimal(row.get("tax_rate")) > ZERO:
            total += row["gross_amount"]
    return money(total)


def _zero_rated_payment_total(rows: list[dict[str, Any]]) -> Decimal:
    """Sum ``gross_amount`` over rows whose ``tax_rate`` is zero or less, rounded to cents."""
    total = 0
    for row in rows:
        if _decimal(row.get("tax_rate")) <= ZERO:
            total += row["gross_amount"]
    return money(total)


def _utc_now() -> datetime:
    return datetime.now(UTC).replace(tzinfo=None)


def _apply_gst_draft_to_return(db: Any, gst_return: Any, draft: dict[str, Any]) -> None:
    """Replace a GST return's header totals and snapshot rows with *draft*.

    The steps are:

    * Existing snapshot rows are deleted and the session is flushed.
    * The header totals are overwritten. Zero-rated and taxable income are
      derived from ``draft["payment_rows"]``.
    * One ``GstReturnSnapshotRow`` is added per entry of ``draft["rows"]``.

    This function does not commit.
    """
    from app.core.models import GstReturnSnapshotRow

    previous_rows = list(getattr(gst_return, "snapshot_rows", []) or [])
    for previous in previous_rows:
        db.delete(previous)
    db.flush()

    def to_snapshot(entry: dict[str, Any]) -> Any:
        # Rows without a recorded tax_mode are stored as standard-rate.
        return GstReturnSnapshotRow(
            gst_return_id=gst_return.id,
            row_type=entry["row_type"],
            source_id=entry.get("source_id"),
            source_date=entry["source_date"],
            customer_or_vendor=entry.get("customer_or_vendor"),
            description=entry.get("description"),
            gross_amount=entry["gross_amount"],
            net_amount=entry["net_amount"],
            gst_amount=entry["gst_amount"],
            tax_mode=entry.get("tax_mode") or "gst_15",
            tax_rate=entry.get("tax_rate"),
            metadata_json=entry.get("metadata_json"),
        )

    received_rows = draft.get("payment_rows", [])
    gst_return.total_sales_income = draft["total_sales_and_income"]
    gst_return.zero_rated_supplies = _zero_rated_payment_total(received_rows)
    gst_return.gst_taxable_income = _taxable_payment_total(received_rows)
    gst_return.gst_output = draft["gst_output"]
    gst_return.debit_adjustments = draft["debit_adjustments"]
    gst_return.total_purchases_expenses = draft["total_purchases_expenses"]
    gst_return.gst_input = draft["gst_input"]
    gst_return.credit_adjustments = draft["credit_adjustments"]
    gst_return.gst_payable = draft["gst_payable"]

    for entry in draft["rows"]:
        db.add(to_snapshot(entry))


def _store_gst_return_snapshot(db: Any, period_start: Any, period_end: Any, status: str) -> Any:
    """Create or reuse the period's GST return and rebuild its snapshot.

    The existing return for the period is reused, or a new one is created and
    flushed. A fresh draft is then built. The return's status is set to
    *status* and its ``locked_at``/``filed_at`` are cleared. ``prepared_at`` is
    stamped only when *status* is ``"prepared"``. Finally the draft is applied
    and the session committed.

    Raises:
        ValueError: if the period's return has already been filed.
    """
    from app.core.models import GstReturn

    current = _existing_gst_return(db, period_start, period_end)
    if current and current.status == "filed":
        raise ValueError("Filed GST returns cannot be updated")

    if current:
        gst_return = current
    else:
        gst_return = GstReturn(period_start=period_start, period_end=period_end, status=status)
        db.add(gst_return)
        db.flush()

    draft = build_gst_return_draft(db, period_start, period_end)
    gst_return.status = status
    gst_return.prepared_at = _utc_now() if status == "prepared" else None
    gst_return.locked_at = None
    gst_return.filed_at = None
    _apply_gst_draft_to_return(db, gst_return, draft)
    db.commit()
    db.refresh(gst_return)
    return gst_return


def save_gst_return_draft(db: Any, period_start: Any, period_end: Any) -> Any:
    """Recompute and store the period's GST return in ``"draft"`` status.

    Raises:
        ValueError: if the period's return has already been filed.
    """
    return _store_gst_return_snapshot(db, period_start, period_end, "draft")


def prepare_gst_return(db: Any, period_start: Any, period_end: Any) -> Any:
    """Recompute and store the period's GST return in ``"prepared"`` status.

    ``prepared_at`` is stamped with the current naive-UTC time.

    Raises:
        ValueError: if the period's return has already been filed.
    """
    return _store_gst_return_snapshot(db, period_start, period_end, "prepared")


def _get_gst_return_or_raise(db: Any, return_id: int) -> Any:
    """Load a GST return by primary key.

    Raises:
        LookupError: if no GST return has that id.
    """
    from app.core.models import GstReturn

    query = db.query(GstReturn).filter(GstReturn.id == return_id)
    found = query.first()
    if not found:
        raise LookupError("GST return not found")
    return found


def lock_gst_return(db: Any, return_id: int) -> Any:
    """Move a prepared GST return to ``"locked"``, stamping ``locked_at``.

    Raises:
        LookupError: if the return does not exist.
        ValueError: if the return is not currently ``"prepared"``.
    """
    record = _get_gst_return_or_raise(db, return_id)
    if record.status == "prepared":
        record.status = "locked"
        record.locked_at = _utc_now()
        db.commit()
        db.refresh(record)
        return record
    raise ValueError("Only prepared GST returns can be locked")


def file_gst_return(db: Any, return_id: int) -> Any:
    """Mark a prepared or locked GST return as ``"filed"``.

    ``filed_at`` is set to now. ``locked_at`` is also set to now when it was
    not already set.

    Raises:
        LookupError: if the return does not exist.
        ValueError: if the return is neither prepared nor locked.
    """
    record = _get_gst_return_or_raise(db, return_id)
    fileable_statuses = {"prepared", "locked"}
    if record.status not in fileable_statuses:
        raise ValueError("Only prepared or locked GST returns can be filed")

    timestamp = _utc_now()
    record.status = "filed"
    if not record.locked_at:
        record.locked_at = timestamp
    record.filed_at = timestamp
    db.commit()
    db.refresh(record)
    return record


def get_gst_return(db: Any, return_id: int) -> Any:
    """Fetch a GST return by id.

    Raises:
        LookupError: if no GST return has that id.
    """
    gst_return = _get_gst_return_or_raise(db, return_id)
    return gst_return


def delete_gst_return(db: Any, return_id: int) -> dict[str, Any]:
    """Delete a non-filed GST return, commit, and return a summary of it.

    Raises:
        LookupError: if the return does not exist.
        ValueError: if the return has already been filed.
    """
    target = _get_gst_return_or_raise(db, return_id)
    if target.status == "filed":
        raise ValueError("Filed GST returns cannot be deleted")

    summary = {
        field: getattr(target, field)
        for field in ("id", "status", "period_start", "period_end")
    }
    db.delete(target)
    db.commit()
    return summary


def serialize_gst_return(gst_return: Any) -> dict[str, Any]:
    """Convert a GST return and its snapshot rows into a plain dict.

    Header attributes are copied by name, with a constant
    ``"basis": "payments"`` inserted after ``status``. Snapshot rows are
    emitted in ascending ``id`` order under the final ``"rows"`` key.
    """
    summary_fields = (
        "total_sales_income",
        "zero_rated_supplies",
        "gst_taxable_income",
        "gst_output",
        "debit_adjustments",
        "total_purchases_expenses",
        "gst_input",
        "credit_adjustments",
        "gst_payable",
        "prepared_at",
        "locked_at",
        "filed_at",
    )
    row_fields = (
        "id",
        "gst_return_id",
        "row_type",
        "source_id",
        "source_date",
        "customer_or_vendor",
        "description",
        "gross_amount",
        "net_amount",
        "gst_amount",
        "tax_mode",
        "tax_rate",
        "metadata_json",
    )

    payload: dict[str, Any] = {
        "id": gst_return.id,
        "period_start": gst_return.period_start,
        "period_end": gst_return.period_end,
        "status": gst_return.status,
        "basis": "payments",
    }
    payload.update((name, getattr(gst_return, name)) for name in summary_fields)
    ordered_rows = sorted(gst_return.snapshot_rows, key=lambda snapshot_row: snapshot_row.id)
    payload["rows"] = [
        {name: getattr(snapshot_row, name) for name in row_fields}
        for snapshot_row in ordered_rows
    ]
    return payload
