import json
from decimal import Decimal
from datetime import datetime, date
from sqlalchemy.orm import Session
from sqlalchemy import func

from app.core.models import (
    Invoice, InvoiceItem, InvoiceEvent, InvoiceStatus, Payment
)


def generate_invoice_number(db: Session) -> str:
    """Return the next invoice number for the current month: ``INV-YYYYMM-NNNN``.

    The sequence part is one greater than the highest existing sequence that
    shares this month's prefix (or 1 if there is none). It is zero-padded to
    at least 4 digits and keeps growing past 9999.

    NOTE: this is a plain read and takes no lock. Two concurrent callers can
    compute the same number, so uniqueness must be enforced elsewhere (e.g. a
    unique constraint on ``Invoice.invoice_number`` plus retry) — verify that
    such a constraint exists.
    """
    now = datetime.now()
    prefix = f"INV-{now.strftime('%Y%m')}-"

    # A plain string MAX ranks "...-9999" above "...-10000", which would
    # re-issue 10000 forever once four digits are exhausted. Ordering by length
    # first, then by value, yields the numerically largest zero-padded suffix.
    # NOTE(review): func.length renders as LENGTH(); fine for SQLite/PostgreSQL/
    # MySQL — confirm for the deployed dialect.
    last = (
        db.query(Invoice.invoice_number)
        .filter(Invoice.invoice_number.like(prefix + "%"))
        .order_by(
            func.length(Invoice.invoice_number).desc(),
            Invoice.invoice_number.desc(),
        )
        .limit(1)
        .scalar()
    )

    # Parse everything after the prefix (not just the last 4 characters) so
    # sequences above 9999 are read correctly.
    seq = int(last[len(prefix):]) + 1 if last else 1
    return f"{prefix}{seq:04d}"


def resolve_tax_rate(tax_mode: str | None, custom_tax_rate=None) -> Decimal:
    mode = tax_mode or "gst_15"
    if mode == "gst_15":
        return Decimal("0.15")
    if mode in ("zero_rated", "no_gst"):
        return Decimal("0")
    if mode == "custom_rate":
        rate = Decimal(str(custom_tax_rate or 0))
        if rate < 0:
            return Decimal("0")
        if rate > 1:
            return rate / Decimal("100")
        return rate
    return Decimal("0.15")


def calculate_item(row: dict) -> dict:
    """Compute the money totals for a single invoice line.

    Reads from ``row``:
      * ``quantity`` (default 1) and ``unit_price`` (default 0);
      * the tax rate. If either ``tax_mode`` or ``custom_tax_rate`` is present
        as a key, the rate comes from ``resolve_tax_rate``. Otherwise it is
        ``tax_rate`` (default 0.15).

    Numeric inputs are converted to ``Decimal`` through ``str``. Returns a dict
    with ``subtotal`` (quantity * unit_price), ``tax_amount`` (subtotal * rate)
    and ``total`` (subtotal + tax_amount). No rounding is applied.
    """
    quantity = Decimal(str(row.get('quantity', 1)))
    unit_price = Decimal(str(row.get('unit_price', 0)))

    uses_tax_mode = any(key in row for key in ("tax_mode", "custom_tax_rate"))
    if uses_tax_mode:
        rate = resolve_tax_rate(row.get("tax_mode"), row.get("custom_tax_rate"))
    else:
        rate = Decimal(str(row.get('tax_rate', 0.15)))

    line_subtotal = quantity * unit_price
    line_tax = line_subtotal * rate
    return {
        'subtotal': line_subtotal,
        'tax_amount': line_tax,
        'total': line_subtotal + line_tax,
    }


def recalculate_invoice(db: Session, invoice: Invoice):
    """Refresh an invoice's stored totals from its line items, then commit.

    Sums ``subtotal``, ``tax_amount`` and ``total`` over every ``InvoiceItem``
    whose ``invoice_id`` equals ``invoice.id``. Writes the results to
    ``invoice.subtotal``, ``invoice.total_tax`` and ``invoice.total_amount``.
    An invoice with no items ends up with all three set to ``Decimal("0")``.
    """
    line_items = (
        db.query(InvoiceItem)
        .filter(InvoiceItem.invoice_id == invoice.id)
        .all()
    )

    net = tax = gross = Decimal("0")
    for line in line_items:
        net += line.subtotal
        tax += line.tax_amount
        gross += line.total

    invoice.subtotal = net
    invoice.total_tax = tax
    invoice.total_amount = gross
    db.commit()


def issue_invoice(db: Session, invoice_id: int, actor_name: str = "system"):
    """Move a draft invoice to ``issued`` and record an ``issued`` audit event.

    Sets ``status`` to ``InvoiceStatus.issued`` and stamps ``issued_at``.

    Args:
        db: Active session.
        invoice_id: Primary key of the invoice to issue.
        actor_name: Recorded as the event's actor.

    Returns:
        The updated ``Invoice``.

    Raises:
        ValueError: if the invoice does not exist or is not in draft status.
    """
    invoice = db.query(Invoice).get(invoice_id)
    if not invoice:
        raise ValueError("Invoice not found")
    if invoice.status != InvoiceStatus.draft:
        raise ValueError(f"Cannot issue invoice in status: {invoice.status}")

    invoice.status = InvoiceStatus.issued
    invoice.issued_at = datetime.now()
    # No separate commit here: write_invoice_event adds the event and commits,
    # so the status change and its audit row are persisted in one transaction
    # (previously a failed event write left an issued invoice with no audit).
    write_invoice_event(db, invoice_id, "issued", "user", actor_name, "Invoice issued")
    return invoice


def void_invoice(db: Session, invoice_id: int, actor_name: str = "system"):
    """Void an invoice (cannot be undone) and record a ``voided`` audit event.

    Sets ``status`` to ``InvoiceStatus.void`` and stamps ``voided_at``. Any
    non-void status may be voided.

    Args:
        db: Active session.
        invoice_id: Primary key of the invoice to void.
        actor_name: Recorded as the event's actor.

    Returns:
        The updated ``Invoice``.

    Raises:
        ValueError: if the invoice does not exist or is already void.
    """
    invoice = db.query(Invoice).get(invoice_id)
    if not invoice:
        raise ValueError("Invoice not found")
    if invoice.status == InvoiceStatus.void:
        raise ValueError("Invoice already voided")

    invoice.status = InvoiceStatus.void
    invoice.voided_at = datetime.now()
    # No separate commit here: write_invoice_event adds the event and commits,
    # so the void and its audit row are persisted in one transaction.
    write_invoice_event(db, invoice_id, "voided", "user", actor_name, "Invoice voided")
    return invoice


def record_payment(
    db: Session,
    invoice_id: int,
    amount: Decimal,
    method: str = "bank_transfer",
    reference: str = "",
    notes: str = "",
    received_date: date | None = None,
    actor_name: str = "system",
):
    """Record a payment toward an invoice. Updates status to partially_paid or paid.

    The invoice status is updated as follows:
      * If the sum of all its payments (including this one) reaches
        ``total_amount``, the invoice becomes ``paid``, ``paid_at`` is set,
        and a ``payment_completed`` event is written.
      * Otherwise it becomes ``partially_paid`` and a ``payment_recorded``
        event is written.

    The event metadata carries the amount, method and running totals.

    Args:
        db: Active session. The payment, the status change and the event are
            committed together by the single commit in ``write_invoice_event``.
        invoice_id: Primary key of the invoice being paid.
        amount: Amount paid. Converted to ``Decimal`` via ``str``, so ints,
            floats and numeric strings are accepted. Must be positive.
        method, reference, notes: Stored on the Payment and echoed in the event.
        received_date: Date the money arrived; defaults to today.
        actor_name: Recorded as the event's actor (default ``"system"``).

    Returns:
        The new ``Payment``.

    Raises:
        ValueError: if the invoice does not exist, is voided, or ``amount``
            is not positive.
    """
    invoice = db.query(Invoice).get(invoice_id)
    if not invoice:
        raise ValueError("Invoice not found")
    # Voiding is meant to be irreversible. Without this guard, a payment would
    # silently flip a void invoice back to paid/partially_paid.
    if invoice.status == InvoiceStatus.void:
        raise ValueError("Cannot record payment on a voided invoice")

    # Normalise so Decimal arithmetic below never meets a float.
    amount = Decimal(str(amount))
    if amount <= 0:
        raise ValueError("Payment amount must be positive")

    # Sum existing payments BEFORE adding the new one. If ``invoice.payments``
    # is lazy-loaded after ``db.add``, session autoflush can include the
    # pending Payment in the loaded collection, which would count it twice.
    previous_total_paid = sum((p.amount for p in invoice.payments), Decimal("0"))
    total_paid = previous_total_paid + amount
    remaining_after_payment = invoice.total_amount - total_paid

    payment = Payment(
        invoice_id=invoice_id,
        amount=amount,
        method=method,
        reference=reference,
        notes=notes,
        received_date=received_date or date.today(),
    )
    db.add(payment)

    # NOTE(review): draft invoices are not rejected and will move straight to
    # partially_paid/paid — confirm that is intended.
    if total_paid >= invoice.total_amount:
        invoice.status = InvoiceStatus.paid
        invoice.paid_at = datetime.now()
        event_type = "payment_completed"
        event_message = f"Payment recorded: {amount} {invoice.currency}. Invoice fully paid."
    else:
        invoice.status = InvoiceStatus.partially_paid
        event_type = "payment_recorded"
        event_message = (
            f"Payment recorded: {amount} {invoice.currency}. "
            f"Paid so far: {total_paid} {invoice.currency}. Remaining: {remaining_after_payment} {invoice.currency}."
        )

    # No separate commit: write_invoice_event commits, so the payment, the
    # status change and the audit event land in one transaction.
    write_invoice_event(
        db,
        invoice_id,
        event_type,
        "user",
        actor_name,
        event_message,
        metadata={
            "amount": str(amount),
            "currency": invoice.currency,
            "method": method,
            "reference": reference,
            "notes": notes,
            "previous_total_paid": str(previous_total_paid),
            "total_paid": str(total_paid),
            # Overpayments would make this negative; report 0 instead.
            "remaining_after_payment": str(max(remaining_after_payment, Decimal("0"))),
        },
    )
    return payment


def write_invoice_event(db: Session, invoice_id: int, event_type: str, actor_type: str = "user", actor_name: str = "", message: str = "", metadata: dict | None = None):
    """Write an audit log entry for an invoice event and commit the session.

    Because this commits, any other pending changes in ``db`` are committed in
    the same transaction as the event.

    Args:
        db: Active session.
        invoice_id: Invoice the event belongs to.
        event_type: Short event identifier (e.g. ``"issued"``).
        actor_type: Kind of actor (default ``"user"``).
        actor_name: Name of the actor.
        message: Human-readable description.
        metadata: Optional JSON-serialisable dict, stored as a JSON string in
            ``metadata_json``. ``None`` or an empty dict is stored as ``None``.

    Returns:
        The persisted ``InvoiceEvent``.
    """
    event = InvoiceEvent(
        invoice_id=invoice_id,
        event_type=event_type,
        actor_type=actor_type,
        actor_name=actor_name,
        message=message,
        metadata_json=json.dumps(metadata) if metadata else None,
    )
    db.add(event)
    db.commit()
    return event
