from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any

from sqlalchemy import and_, or_
from sqlalchemy.orm import Session

from app.core.models import (
    Customer,
    Expense,
    GstReturn,
    Invoice,
    InvoiceStatus,
    Payment,
    Product,
    Reminder,
    Subscription,
)


# Every ``query_type`` alias accepted by ``execute_ai_query``. Each tuple
# groups the spellings that are routed to the same query handler.
SUPPORTED_QUERY_TYPES = {
    alias
    for aliases in (
        ("revenue", "stats"),
        ("invoice", "invoices"),
        ("customer", "customers"),
        ("product", "products"),
        ("expense", "expenses"),
        ("subscription", "subscriptions"),
        ("reminder", "reminders"),
        ("gst", "gst_summary"),
    )
    for alias in aliases
}


def execute_ai_query(db: Session, query_data: dict[str, Any]) -> dict[str, Any]:
    """Execute an AI-requested query against real ERP data.

    Args:
        db: Active SQLAlchemy session, passed through to the query helpers.
        query_data: Request produced by the AI layer. ``query_type`` selects
            the query and must be one of ``SUPPORTED_QUERY_TYPES``. It is
            matched case-insensitively, surrounding whitespace is ignored,
            and spaces/hyphens are treated as underscores, so
            ``"GST summary"`` selects ``gst_summary``. ``filters`` is an
            optional dict of query-specific filters and date-range keys.

    Returns:
        A response dict with at least ``reply`` and ``action``:

        * ``action == "error"`` when the payload or query type is unsupported.
        * ``action == "clarify"`` when a revenue/stats query has no
          resolvable date range.
        * Otherwise, the selected query helper's response.
    """
    # The payload is model output, so do not assume it is well-formed.
    if not isinstance(query_data, dict):
        return _unsupported_query()

    raw_type = str(query_data.get("query_type") or "").strip().lower()
    query_type = "_".join(raw_type.replace("-", " ").split())
    if query_type not in SUPPORTED_QUERY_TYPES:
        return _unsupported_query()

    raw_filters = query_data.get("filters")
    filters = raw_filters if isinstance(raw_filters, dict) else {}
    date_range = _resolve_date_range(filters)

    if query_type in {"revenue", "stats"}:
        # Revenue totals are meaningless without a period, so ask for one
        # rather than silently picking a default.
        if date_range is None:
            return {
                "reply": 'Please provide a date range or say "this month", "last month", or "recent month."',
                "action": "clarify",
            }
        return _query_revenue(db, date_range[0], date_range[1])
    if query_type in {"invoice", "invoices"}:
        return _query_invoices(db, filters, date_range)
    if query_type in {"customer", "customers"}:
        return _query_customers(db, filters)
    if query_type in {"product", "products"}:
        return _query_products(db, filters)
    if query_type in {"expense", "expenses"}:
        return _query_expenses(db, filters, date_range)
    if query_type in {"subscription", "subscriptions"}:
        return _query_subscriptions(db, filters, date_range)
    if query_type in {"reminder", "reminders"}:
        return _query_reminders(db, filters)
    if query_type in {"gst", "gst_summary"}:
        return _query_gst_summary(db, filters, date_range)

    # Reached only if SUPPORTED_QUERY_TYPES gains an alias without a handler.
    return _unsupported_query()


def _unsupported_query() -> dict[str, Any]:
    return {
        "reply": (
            "I cannot run that query yet. I can currently query revenue, invoices, "
            "customers, expenses, subscriptions, reminders, and GST summaries."
        ),
        "action": "error",
    }


def _resolve_date_range(filters: dict[str, Any]) -> tuple[date, date] | None:
    date_from = _parse_date(filters.get("date_from"))
    date_to = _parse_date(filters.get("date_to"))
    if date_from and date_to:
        return date_from, date_to

    range_name = str(filters.get("date_range") or filters.get("period") or "").strip().lower()
    today = date.today()
    if range_name in {"recent_month", "recent month", "last_30_days", "last 30 days"}:
        return today - timedelta(days=30), today
    if range_name in {"this_month", "this month"}:
        return today.replace(day=1), today
    if range_name in {"last_month", "last month"}:
        first_this_month = today.replace(day=1)
        last_previous_month = first_this_month - timedelta(days=1)
        return last_previous_month.replace(day=1), last_previous_month

    return None


def _parse_date(value: Any) -> date | None:
    if isinstance(value, date) and not isinstance(value, datetime):
        return value
    if isinstance(value, datetime):
        return value.date()
    if not value:
        return None
    try:
        return datetime.strptime(str(value), "%Y-%m-%d").date()
    except ValueError:
        return None


def _query_revenue(db: Session, date_from: date, date_to: date) -> dict[str, Any]:
    """Report cash received and paid-invoice totals for an inclusive date range.

    Two independent figures are produced:

    * ``cash_received``: payments (inner-joined to their invoice) whose
      ``received_date`` lies in ``[date_from, date_to]``. Payments with a
      NULL ``received_date`` are included when their ``paid_at`` timestamp
      falls on one of those days.
    * ``paid_invoices``: invoices currently marked ``InvoiceStatus.paid``
      whose ``invoice_date`` lies in the range.

    Each section carries ``total``, ``currency`` and ``totals_by_currency``.
    Amounts in different currencies are never presented as one
    single-currency figure: when a section spans several currencies its
    ``currency`` is ``"MIXED"`` and the per-currency sums are in
    ``totals_by_currency``. Rows with no currency are counted as NZD, the
    default used elsewhere in this module.

    Args:
        db: Active SQLAlchemy session.
        date_from: First day of the period (inclusive).
        date_to: Last day of the period (inclusive).

    Returns:
        A ``query_result`` response with a text ``reply`` and the structured
        figures under ``result``.
    """
    # paid_at is a timestamp, so cover whole days with a half-open interval.
    period_start = datetime.combine(date_from, datetime.min.time())
    period_end_exclusive = datetime.combine(date_to + timedelta(days=1), datetime.min.time())

    payments = (
        db.query(Payment)
        .join(Invoice, Payment.invoice_id == Invoice.id)
        .filter(
            or_(
                and_(Payment.received_date >= date_from, Payment.received_date <= date_to),
                and_(
                    Payment.received_date.is_(None),
                    Payment.paid_at >= period_start,
                    Payment.paid_at < period_end_exclusive,
                ),
            )
        )
        .order_by(Payment.received_date.asc(), Payment.paid_at.asc(), Payment.id.asc())
        .all()
    )

    paid_invoices = (
        db.query(Invoice)
        .filter(
            Invoice.status == InvoiceStatus.paid,
            Invoice.invoice_date >= date_from,
            Invoice.invoice_date <= date_to,
        )
        .order_by(Invoice.invoice_date.asc(), Invoice.invoice_number.asc())
        .all()
    )

    cash_rows = [_payment_row(db, payment) for payment in payments]
    invoice_rows = [_invoice_row(db, invoice) for invoice in paid_invoices]

    result = {
        "period": {
            "date_from": date_from.isoformat(),
            "date_to": date_to.isoformat(),
        },
        "cash_received": {
            **_summarise_by_currency(cash_rows, "amount"),
            "payment_count": len(cash_rows),
            "rows": cash_rows,
        },
        "paid_invoices": {
            **_summarise_by_currency(invoice_rows, "total_amount"),
            "invoice_count": len(invoice_rows),
            "rows": invoice_rows,
        },
    }

    return {
        "reply": _format_revenue_reply(result),
        "action": "query_result",
        "source": "database",
        "tool_name": "query_revenue",
        "result": result,
    }


def _summarise_by_currency(rows: list[dict[str, Any]], amount_key: str) -> dict[str, Any]:
    """Sum ``row[amount_key]`` overall and per currency for revenue sections.

    Returns a dict with these keys:

    * ``total``: float sum of all rows.
    * ``currency``: the single currency code; ``"NZD"`` when there are no
      rows, ``"MIXED"`` when rows span several currencies.
    * ``totals_by_currency``: code -> float, sorted by code.

    Sums are accumulated as ``Decimal`` to avoid float drift. A blank or
    missing row currency counts as ``"NZD"``.
    """
    by_currency: dict[str, Decimal] = {}
    for row in rows:
        code = str(row.get("currency") or "").strip() or "NZD"
        by_currency[code] = by_currency.get(code, Decimal("0")) + Decimal(str(row[amount_key]))

    if not by_currency:
        currency = "NZD"
    elif len(by_currency) == 1:
        (currency,) = by_currency
    else:
        currency = "MIXED"

    return {
        "total": float(sum(by_currency.values(), Decimal("0"))),
        "currency": currency,
        "totals_by_currency": {code: float(amount) for code, amount in sorted(by_currency.items())},
    }


def _query_invoices(db: Session, filters: dict[str, Any], date_range: tuple[date, date] | None) -> dict[str, Any]:
    """List up to 20 invoices matching ``filters``, most recent first.

    Supported filters:

    * ``status``: equality match.
    * ``invoice_number``: case-insensitive substring match.
    * ``customer_name`` or ``customer``: case-insensitive substring of
      ``Customer.name``.

    ``date_range``, when given, bounds ``invoice_date`` inclusively.
    """
    invoices = db.query(Invoice)
    if wanted_status := str(filters.get("status") or "").strip():
        invoices = invoices.filter(Invoice.status == wanted_status)
    if number_fragment := str(filters.get("invoice_number") or "").strip():
        invoices = invoices.filter(Invoice.invoice_number.ilike(f"%{number_fragment}%"))
    if name_fragment := str(filters.get("customer_name") or filters.get("customer") or "").strip():
        invoices = invoices.join(Customer, Invoice.customer_id == Customer.id).filter(
            Customer.name.ilike(f"%{name_fragment}%")
        )
    if date_range:
        first_day, last_day = date_range
        invoices = invoices.filter(Invoice.invoice_date >= first_day, Invoice.invoice_date <= last_day)

    newest_first = invoices.order_by(Invoice.invoice_date.desc(), Invoice.id.desc()).limit(20).all()
    return _simple_query_result(
        "query_invoices",
        "Invoice results",
        [_invoice_row(db, invoice) for invoice in newest_first],
    )


def _query_customers(db: Session, filters: dict[str, Any]) -> dict[str, Any]:
    """List up to 20 customers ordered by name.

    Supported filters:

    * ``keyword`` (or ``name``/``customer``): case-insensitive substring
      match against name, company name or email.
    * ``status``: equality match.
    """
    customers = db.query(Customer)
    if needle := str(filters.get("keyword") or filters.get("name") or filters.get("customer") or "").strip():
        pattern = f"%{needle}%"
        customers = customers.filter(
            or_(
                Customer.name.ilike(pattern),
                Customer.company_name.ilike(pattern),
                Customer.email.ilike(pattern),
            )
        )
    if wanted_status := str(filters.get("status") or "").strip():
        customers = customers.filter(Customer.status == wanted_status)

    matches = customers.order_by(Customer.name.asc()).limit(20).all()
    return _simple_query_result("query_customers", "Customer results", list(map(_customer_row, matches)))


def _query_products(db: Session, filters: dict[str, Any]) -> dict[str, Any]:
    """List up to 20 products ordered by name.

    Supported filters:

    * ``keyword`` (or ``name``/``product``): case-insensitive substring of
      the product name.
    * ``status``: equality match.
    * ``product_type`` (or ``type``): equality match.
    """
    products = db.query(Product)
    if needle := str(filters.get("keyword") or filters.get("name") or filters.get("product") or "").strip():
        products = products.filter(Product.name.ilike(f"%{needle}%"))
    if wanted_status := str(filters.get("status") or "").strip():
        products = products.filter(Product.status == wanted_status)
    if wanted_type := str(filters.get("product_type") or filters.get("type") or "").strip():
        products = products.filter(Product.product_type == wanted_type)

    matches = products.order_by(Product.name.asc()).limit(20).all()
    return _simple_query_result("query_products", "Product results", [_product_row(item) for item in matches])


def _query_expenses(db: Session, filters: dict[str, Any], date_range: tuple[date, date] | None) -> dict[str, Any]:
    """List up to 20 expenses matching ``filters``, most recent first.

    Supported filters:

    * ``status``: equality match.
    * ``vendor`` (or ``vendor_name``): case-insensitive substring match.
    * ``category``: equality match.

    ``date_range``, when given, bounds ``expense_date`` inclusively.
    """
    expenses = db.query(Expense)
    if wanted_status := str(filters.get("status") or "").strip():
        expenses = expenses.filter(Expense.status == wanted_status)
    if vendor_fragment := str(filters.get("vendor") or filters.get("vendor_name") or "").strip():
        expenses = expenses.filter(Expense.vendor_name.ilike(f"%{vendor_fragment}%"))
    if wanted_category := str(filters.get("category") or "").strip():
        expenses = expenses.filter(Expense.category == wanted_category)
    if date_range:
        first_day, last_day = date_range
        expenses = expenses.filter(Expense.expense_date >= first_day, Expense.expense_date <= last_day)

    latest_first = expenses.order_by(Expense.expense_date.desc(), Expense.id.desc()).limit(20).all()
    return _simple_query_result("query_expenses", "Expense results", [_expense_row(item) for item in latest_first])


def _query_subscriptions(db: Session, filters: dict[str, Any], date_range: tuple[date, date] | None) -> dict[str, Any]:
    """List up to 20 subscriptions ordered by ``end_date`` ascending, then id descending.

    Supported filters:

    * ``status``: equality match.
    * ``customer`` (or ``customer_name``): case-insensitive substring of
      ``Customer.name``.

    With ``date_range``, only subscriptions whose ``end_date`` lies inside
    the range, or that have no ``end_date``, are kept. Where NULL end dates
    sort is left to the database.
    """
    subscriptions = db.query(Subscription)
    if wanted_status := str(filters.get("status") or "").strip():
        subscriptions = subscriptions.filter(Subscription.status == wanted_status)
    if name_fragment := str(filters.get("customer") or filters.get("customer_name") or "").strip():
        subscriptions = subscriptions.join(Customer, Subscription.customer_id == Customer.id).filter(
            Customer.name.ilike(f"%{name_fragment}%")
        )
    if date_range:
        first_day, last_day = date_range
        # NOTE(review): open-ended subscriptions (NULL end_date) always pass,
        # so this reads as "ending in range or never ending" rather than
        # "active during range" — confirm that is the intended meaning.
        ending_in_range = and_(Subscription.end_date >= first_day, Subscription.end_date <= last_day)
        subscriptions = subscriptions.filter(or_(Subscription.end_date.is_(None), ending_in_range))

    ordered = subscriptions.order_by(Subscription.end_date.asc(), Subscription.id.desc()).limit(20).all()
    return _simple_query_result(
        "query_subscriptions",
        "Subscription results",
        [_subscription_row(db, subscription) for subscription in ordered],
    )


def _query_reminders(db: Session, filters: dict[str, Any]) -> dict[str, Any]:
    """List up to 20 reminders, highest id first.

    Supported filters:

    * ``status``: equality match.
    * ``reminder_type`` (or ``type``): equality match.
    """
    reminders = db.query(Reminder)
    if wanted_status := str(filters.get("status") or "").strip():
        reminders = reminders.filter(Reminder.status == wanted_status)
    if wanted_type := str(filters.get("reminder_type") or filters.get("type") or "").strip():
        reminders = reminders.filter(Reminder.reminder_type == wanted_type)

    highest_id_first = reminders.order_by(Reminder.id.desc()).limit(20).all()
    return _simple_query_result(
        "query_reminders",
        "Reminder results",
        [_reminder_row(item) for item in highest_id_first],
    )


def _query_gst_summary(db: Session, filters: dict[str, Any], date_range: tuple[date, date] | None) -> dict[str, Any]:
    """List up to 20 GST returns, latest ``period_start`` first.

    Supported filters:

    * ``status``: equality match.

    With ``date_range``, any return whose period overlaps the range is kept.
    """
    gst_returns = db.query(GstReturn)
    if wanted_status := str(filters.get("status") or "").strip():
        gst_returns = gst_returns.filter(GstReturn.status == wanted_status)
    if date_range:
        first_day, last_day = date_range
        # Overlap test: the period starts before the range ends and ends after it starts.
        gst_returns = gst_returns.filter(GstReturn.period_start <= last_day, GstReturn.period_end >= first_day)

    latest_first = gst_returns.order_by(GstReturn.period_start.desc()).limit(20).all()
    return _simple_query_result(
        "query_gst_summary",
        "GST summary results",
        [_gst_return_row(item) for item in latest_first],
    )


def _payment_row(db: Session, payment: Payment) -> dict[str, Any]:
    """Flatten a payment into a JSON-friendly row for revenue reports.

    The related invoice and its customer are loaded via ``db.get``. A
    missing invoice or customer yields ``""`` for the number and name.

    The row date is ``received_date`` when set, otherwise the date part of
    ``paid_at``, otherwise ``""``.

    Currency falls back from the payment, to its invoice, to ``"NZD"``, so
    the row never carries an empty currency. Previously an invoice with no
    currency produced ``None``.
    """
    invoice = db.get(Invoice, payment.invoice_id)
    customer = db.get(Customer, invoice.customer_id) if invoice else None
    payment_date = payment.received_date or (payment.paid_at.date() if payment.paid_at else None)
    invoice_currency = invoice.currency if invoice else None
    return {
        "invoice_number": invoice.invoice_number if invoice else "",
        "customer": customer.name if customer else "",
        "received_date": payment_date.isoformat() if payment_date else "",
        "amount": float(payment.amount or 0),
        "currency": payment.currency or invoice_currency or "NZD",
        "method": payment.method or "",
        "reference": payment.reference or "",
    }


def _invoice_row(db: Session, invoice: Invoice) -> dict[str, Any]:
    """Flatten an invoice into a JSON-friendly row.

    The customer is loaded via ``db.get``. Fallbacks:

    * missing customer or invoice date -> ``""``
    * missing total -> ``0.0``
    * missing currency -> ``"NZD"``

    The status is reported by its ``value`` when it has one, otherwise
    stringified.
    """
    owner = db.get(Customer, invoice.customer_id)
    issued_on = invoice.invoice_date
    raw_status = invoice.status
    status_text = raw_status.value if hasattr(raw_status, "value") else str(raw_status)
    return {
        "invoice_number": invoice.invoice_number,
        "customer": owner.name if owner else "",
        "invoice_date": issued_on.isoformat() if issued_on else "",
        "total_amount": float(invoice.total_amount or 0),
        "currency": invoice.currency or "NZD",
        "status": status_text,
    }


def _customer_row(customer: Customer) -> dict[str, Any]:
    """Flatten a customer into a JSON-friendly row.

    Missing company name or email becomes ``""``. Type and status are
    reported by ``value`` when present, otherwise stringified.
    """

    def as_text(field: Any) -> Any:
        # Enum members expose ``value``; anything else is stringified.
        return field.value if hasattr(field, "value") else str(field)

    return {
        "id": customer.id,
        "name": customer.name,
        "company_name": customer.company_name or "",
        "email": customer.email or "",
        "customer_type": as_text(customer.customer_type),
        "status": as_text(customer.status),
    }


def _expense_row(expense: Expense) -> dict[str, Any]:
    """Flatten an expense into a JSON-friendly row.

    Fallbacks:

    * missing amounts -> ``0.0``
    * missing currency -> ``"NZD"``
    * missing date, status or source -> ``""``

    ``vendor_name`` and ``category`` are passed through unchanged.
    """
    spent_on = expense.expense_date
    row: dict[str, Any] = {"id": expense.id, "vendor_name": expense.vendor_name}
    row["expense_date"] = spent_on.isoformat() if spent_on else ""
    row["category"] = expense.category
    row["amount_gross"] = float(expense.amount_gross or 0)
    row["gst_amount"] = float(expense.gst_amount or 0)
    row["currency"] = expense.currency or "NZD"
    row["status"] = expense.status or ""
    row["source"] = expense.source or ""
    return row


def _product_row(product: Product) -> dict[str, Any]:
    """Flatten a product into a JSON-friendly row.

    Missing price or tax rate becomes ``0.0`` and a missing unit ``""``.
    Type and status are reported by ``value`` when present, otherwise
    stringified.
    """

    def as_text(field: Any) -> Any:
        # Enum members expose ``value``; anything else is stringified.
        return field.value if hasattr(field, "value") else str(field)

    return {
        "id": product.id,
        "name": product.name,
        "product_type": as_text(product.product_type),
        "unit_price": float(product.unit_price or 0),
        "tax_rate": float(product.tax_rate or 0),
        "unit": product.unit or "",
        "status": as_text(product.status),
    }


def _subscription_row(db: Session, sub: Subscription) -> dict[str, Any]:
    """Flatten a subscription into a JSON-friendly row.

    Customer and product names are loaded via ``db.get``. Missing related
    records or dates become ``""``. Status and billing cycle are reported by
    ``value`` when present, otherwise stringified.
    """

    def as_text(field: Any) -> Any:
        return field.value if hasattr(field, "value") else str(field)

    def iso_or_blank(day: Any) -> str:
        return day.isoformat() if day else ""

    owner = db.get(Customer, sub.customer_id)
    item = db.get(Product, sub.product_id)
    return {
        "id": sub.id,
        "customer": owner.name if owner else "",
        "product": item.name if item else "",
        "start_date": iso_or_blank(sub.start_date),
        "end_date": iso_or_blank(sub.end_date),
        "status": as_text(sub.status),
        "billing_cycle": as_text(sub.billing_cycle),
        "auto_renew": bool(sub.auto_renew),
        "next_invoice_date": iso_or_blank(sub.next_invoice_date),
    }


def _reminder_row(reminder: Reminder) -> dict[str, Any]:
    """Flatten a reminder into a JSON-friendly row.

    Type and status are reported by ``value`` when present, otherwise
    stringified. The channel flags are coerced to ``bool``.
    """
    kind = reminder.reminder_type
    state = reminder.status
    return {
        "id": reminder.id,
        "reminder_type": kind.value if hasattr(kind, "value") else str(kind),
        "trigger_days": reminder.trigger_days,
        "send_email": bool(reminder.send_email),
        "send_telegram": bool(reminder.send_telegram),
        "status": state.value if hasattr(state, "value") else str(state),
    }


def _gst_return_row(gst_return: GstReturn) -> dict[str, Any]:
    """Flatten a GST return into a JSON-friendly row.

    Missing period dates or status become ``""``. Every monetary field is
    coerced to ``float``, with a missing value counted as ``0``.
    """
    starts = gst_return.period_start
    ends = gst_return.period_end
    row: dict[str, Any] = {
        "id": gst_return.id,
        "period_start": starts.isoformat() if starts else "",
        "period_end": ends.isoformat() if ends else "",
        "status": gst_return.status or "",
    }
    for money_field in ("gst_output", "gst_input", "gst_payable", "total_sales_income", "total_purchases_expenses"):
        row[money_field] = float(getattr(gst_return, money_field) or 0)
    return row


def _currency_for_rows(rows: list[dict[str, Any]]) -> str | None:
    currencies = {str(row.get("currency") or "").strip() for row in rows if row.get("currency")}
    return currencies.pop() if len(currencies) == 1 else None


def _format_money(currency: str | None, amount: float) -> str:
    return f"{currency or 'NZD'} {amount:,.2f}"


def _simple_query_result(tool_name: str, title: str, rows: list[dict[str, Any]]) -> dict[str, Any]:
    result = {"count": len(rows), "rows": rows}
    lines = [
        "The following figures come from database queries, not AI-generated data.",
        "",
        f"{title}: {len(rows)}",
    ]
    for row in rows:
        lines.append("- " + " | ".join(f"{key}: {value}" for key, value in row.items() if value not in ("", None)))
    if not rows:
        lines.append("- No matching records.")
    return {
        "reply": "\n".join(lines),
        "action": "query_result",
        "source": "database",
        "tool_name": tool_name,
        "result": result,
    }


def _format_revenue_reply(result: dict[str, Any]) -> str:
    """Render a structured revenue ``result`` as a plain-text reply.

    Expects these keys in ``result``:

    * ``period``: ``date_from`` / ``date_to``.
    * ``cash_received``: ``total``, ``currency``, ``payment_count`` and
      payment ``rows``.
    * ``paid_invoices``: ``total``, ``currency``, ``invoice_count`` and
      invoice ``rows``.

    When a section also carries a ``totals_by_currency`` mapping with more
    than one currency, a per-currency breakdown is listed under its total,
    so a mixed-currency sum is not read as a single-currency figure.
    Sections without that key render as before. Empty payment ``method`` and
    ``reference`` fields are shown as ``-``.
    """
    period = result["period"]
    cash = result["cash_received"]
    invoices = result["paid_invoices"]
    lines = [
        "The following figures come from database queries, not AI-generated data.",
        "",
        f"Period: {period['date_from']} to {period['date_to']}",
        "",
        "Cash received:",
        f"- Total: {_format_money(cash['currency'], cash['total'])}",
        *_currency_breakdown_lines(cash),
        f"- Payment count: {cash['payment_count']}",
        "",
        "Paid invoice totals:",
        f"- Total: {_format_money(invoices['currency'], invoices['total'])}",
        *_currency_breakdown_lines(invoices),
        f"- Invoice count: {invoices['invoice_count']}",
        "",
        "Why they may differ: cash received uses payment received dates. Paid invoice totals use invoice dates for invoices currently marked paid.",
        "",
        "Cash received details:",
    ]

    if cash["rows"]:
        for row in cash["rows"]:
            lines.append(
                f"- {row['invoice_number']} | {row['customer']} | {row['received_date']} | "
                f"{_format_money(row['currency'], row['amount'])} | {row['method'] or '-'} | {row['reference'] or '-'}"
            )
    else:
        lines.append("- No matching payments.")

    lines.extend(["", "Paid invoice details:"])
    if invoices["rows"]:
        for row in invoices["rows"]:
            lines.append(
                f"- {row['invoice_number']} | {row['customer']} | {row['invoice_date']} | "
                f"{_format_money(row['currency'], row['total_amount'])} | {row['status']}"
            )
    else:
        lines.append("- No matching paid invoices.")

    return "\n".join(lines)


def _currency_breakdown_lines(section: dict[str, Any]) -> list[str]:
    """Return indented per-currency total lines when a section spans several currencies."""
    by_currency = section.get("totals_by_currency") or {}
    if len(by_currency) < 2:
        return []
    return [f"  - {_format_money(code, amount)}" for code, amount in by_currency.items()]
