import io
import json
import re
from datetime import date
from decimal import Decimal
from mimetypes import guess_type
from pathlib import Path
from uuid import uuid4

from PIL import Image, UnidentifiedImageError


# Root directory for stored receipts: "storage/receipts" under the directory
# two levels above the directory that contains this module.
UPLOAD_ROOT = Path(__file__).resolve().parents[2] / "storage" / "receipts"
# Created at import time so later writes can assume the root exists.
UPLOAD_ROOT.mkdir(parents=True, exist_ok=True)

# Cached OCR engine instance; populated on first use by _get_ocr_engine().
_ocr_engine = None


def _import_pymupdf():
    try:
        import fitz
    except ImportError as exc:  # pragma: no cover - environment-specific
        raise RuntimeError("PyMuPDF is required for PDF receipt uploads.") from exc
    return fitz


def _get_ocr_engine():
    """Return the cached OCR engine, constructing it on the first call.

    The engine is stored in the module-level ``_ocr_engine`` variable and
    reused by every later call.

    Raises:
        RuntimeError: If ``rapidocr_onnxruntime`` cannot be imported.
    """
    global _ocr_engine

    # Fast path: an engine was already built by an earlier call.
    if _ocr_engine is not None:
        return _ocr_engine

    try:
        from rapidocr_onnxruntime import RapidOCR
    except ImportError as import_error:  # pragma: no cover - environment-specific
        raise RuntimeError("rapidocr-onnxruntime is required for receipt OCR.") from import_error

    _ocr_engine = RapidOCR()
    return _ocr_engine


def _sanitize_name_component(value: str | None, fallback: str) -> str:
    normalized = (value or "").strip().lower()
    if not normalized:
        return fallback
    normalized = re.sub(r"[^\w\-]+", "-", normalized)
    normalized = re.sub(r"-{2,}", "-", normalized).strip("-")
    return normalized or fallback


def _ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode.

    An image that is already RGB is returned as-is (same object); any other
    mode is converted with ``Image.convert("RGB")``.
    """
    return image if image.mode == "RGB" else image.convert("RGB")


def _save_page_webp(image: Image.Image, path: Path) -> dict:
    """Write ``image`` to ``path`` as WEBP and describe the written file.

    The image is converted to RGB first when necessary.

    Returns:
        A dict with ``path`` (as a string), ``width``, ``height`` and ``size``
        (bytes on disk, read back from the file after saving).
    """
    converted = _ensure_rgb(image)
    converted.save(path, format="WEBP", quality=82, method=6)

    page_info = {"path": str(path)}
    page_info["width"] = converted.width
    page_info["height"] = converted.height
    # Size is taken from the filesystem, so it reflects the encoded WEBP.
    page_info["size"] = path.stat().st_size
    return page_info


def _render_image_pages(file_bytes: bytes) -> list[Image.Image]:
    """Decode an uploaded image into a one-element list of RGB pages.

    ``Image.open`` only identifies the file and defers pixel decoding, so the
    image is loaded explicitly here. That way a truncated or corrupt upload
    is reported from this function as a ``RuntimeError`` instead of failing
    later (during conversion or saving) with a raw ``OSError``.

    Args:
        file_bytes: Raw bytes of the uploaded image file.

    Returns:
        A list containing the single decoded image in RGB mode.

    Raises:
        RuntimeError: If the format is not recognised or the pixel data
            cannot be decoded.
    """
    try:
        image = Image.open(io.BytesIO(file_bytes))
        # Force decoding now so damaged files are caught inside this try block.
        image.load()
    except UnidentifiedImageError as exc:
        raise RuntimeError("Unsupported image format. Please upload JPG, PNG, WEBP, or PDF.") from exc
    except OSError as exc:
        # Listed after UnidentifiedImageError so the more specific
        # "unsupported format" message above wins when both could apply.
        raise RuntimeError("The uploaded image is corrupted or truncated and could not be decoded.") from exc
    return [_ensure_rgb(image)]


def _render_pdf_pages(file_bytes: bytes) -> list[Image.Image]:
    """Rasterise every page of a PDF into an RGB image.

    Each page is rendered through a 2x2 scaling matrix without an alpha
    channel, and the document is closed even if rendering fails part-way.

    Args:
        file_bytes: Raw bytes of the PDF document.

    Returns:
        One image per page, in document order.
    """
    pymupdf = _import_pymupdf()
    pdf = pymupdf.open(stream=file_bytes, filetype="pdf")
    try:
        rendered = []
        for pdf_page in pdf:
            pixmap = pdf_page.get_pixmap(matrix=pymupdf.Matrix(2, 2), alpha=False)
            rendered.append(
                Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
            )
        return rendered
    finally:
        pdf.close()


def _write_original_file(directory: Path, upload_filename: str, file_bytes: bytes) -> Path:
    suffix = Path(upload_filename or "receipt.bin").suffix or ".bin"
    original_path = directory / f"original{suffix.lower()}"
    original_path.write_bytes(file_bytes)
    return original_path


def normalize_receipt_upload(upload_filename: str, content_type: str, file_bytes: bytes) -> dict:
    """Persist an uploaded receipt and render it into per-page WEBP files.

    A new directory named after a random hex key is created under
    ``UPLOAD_ROOT``. The raw upload is written to its ``original/``
    sub-directory and each rendered page to ``pages/``. Uploads identified as
    PDF (by media type or ``.pdf`` extension) are rendered page by page;
    everything else is decoded as a single image.

    If any step fails, the partially populated receipt directory is removed
    before the exception propagates, so failed uploads leave no orphaned
    files behind.

    Args:
        upload_filename: Client-supplied file name; may be empty or ``None``.
        content_type: Client-supplied media type; may be empty or ``None`` and
            may carry parameters (e.g. ``"application/pdf; charset=binary"``).
        file_bytes: Raw content of the upload.

    Returns:
        A dict with the receipt key, original file name, MIME type, storage
        paths and sizes of the original and of the first page (used as the
        preview), the preview dimensions, and a ``pages`` list holding one
        metadata dict per rendered page.

    Raises:
        RuntimeError: If rendering yields no pages. Errors raised by the
            rendering and saving helpers propagate unchanged.
    """
    import shutil  # local import: only used on the failure path

    safe_name = upload_filename or ""
    # Compare only the media type itself; parameters after ";" are ignored.
    media_type = (content_type or "").split(";", 1)[0].strip().lower()
    is_pdf = media_type == "application/pdf" or safe_name.lower().endswith(".pdf")

    receipt_stem = uuid4().hex
    receipt_dir = UPLOAD_ROOT / receipt_stem
    original_dir = receipt_dir / "original"
    pages_dir = receipt_dir / "pages"

    try:
        original_dir.mkdir(parents=True, exist_ok=True)
        pages_dir.mkdir(parents=True, exist_ok=True)

        original_path = _write_original_file(original_dir, upload_filename, file_bytes)
        page_images = _render_pdf_pages(file_bytes) if is_pdf else _render_image_pages(file_bytes)
        if not page_images:
            raise RuntimeError("No renderable pages were found in the uploaded receipt.")

        pages = []
        for index, image in enumerate(page_images, start=1):
            page_path = pages_dir / f"upload-{receipt_stem}-page{index:02d}.webp"
            page_meta = _save_page_webp(image, page_path)
            page_meta["page_number"] = index
            pages.append(page_meta)

        original_size = original_path.stat().st_size
    except Exception:
        # Nothing references this directory yet, so drop it rather than
        # leaving an unreachable upload on disk.
        shutil.rmtree(receipt_dir, ignore_errors=True)
        raise

    # The first rendered page doubles as the preview image.
    preview_meta = pages[0]
    return {
        "receipt_key": receipt_stem,
        "original_filename": upload_filename,
        "mime_type": content_type or guess_type(safe_name)[0] or "application/octet-stream",
        "storage_path_original": str(original_path),
        "storage_path_preview": preview_meta["path"],
        "file_size_original": original_size,
        "file_size_preview": preview_meta["size"],
        "image_width": preview_meta["width"],
        "image_height": preview_meta["height"],
        "pages": pages,
    }


def run_ocr_on_pages(page_paths: list[str]) -> dict:
    """Run OCR over each page image and collect the recognised text.

    For every path the shared OCR engine is called; the second element of
    each result entry is treated as the text, stripped, and kept when
    non-empty. Pages are numbered from 1 in input order.

    Returns:
        A dict with ``pages`` (one dict per page holding ``page_number``,
        ``path`` and newline-joined ``text``) and ``combined_text`` (the
        non-empty page texts joined by blank lines).
    """
    ocr = _get_ocr_engine()
    page_records = []

    for page_number, page_path in enumerate(page_paths, start=1):
        detections, _ = ocr(page_path)
        # A falsy result (no detections) simply yields an empty page text.
        stripped = ((detection[1] or "").strip() for detection in detections or ())
        page_text = "\n".join(text for text in stripped if text).strip()
        page_records.append(
            {
                "page_number": page_number,
                "path": page_path,
                "text": page_text,
            }
        )

    non_empty_texts = [record["text"] for record in page_records if record["text"]]
    return {
        "pages": page_records,
        "combined_text": "\n\n".join(non_empty_texts).strip(),
    }


def rename_archive_pages(
    page_paths: list[str],
    supplier: str | None,
    invoice_number: str | None,
    invoice_date: str | None,
) -> list[str]:
    """Rename page files to ``<supplier>-<invoice>-<date>-pageNN.webp``.

    Each name component is sanitised, with a placeholder used for missing
    values. Files stay in their current directory, and a file whose path
    already equals its target is left untouched.

    Returns:
        The new paths as strings, in the same order as ``page_paths``.
    """
    name_prefix = "-".join(
        (
            _sanitize_name_component(supplier, "unknown-supplier"),
            _sanitize_name_component(invoice_number, "no-invoice"),
            _sanitize_name_component(invoice_date, "unknown-date"),
        )
    )

    archived = []
    for page_no, current in enumerate(map(Path, page_paths), start=1):
        destination = current.with_name(f"{name_prefix}-page{page_no:02d}.webp")
        if destination != current:
            current.replace(destination)
        archived.append(str(destination))
    return archived


def snapshot_json(data: dict) -> str:
    """Serialise ``data`` to a JSON string, leaving non-ASCII characters unescaped."""
    serialized = json.dumps(data, ensure_ascii=False)
    return serialized


def _apply_ai_extraction_to_expense(expense, extraction: dict) -> None:
    vendor_name = extraction.get("vendor_name") or "Pending review"
    category = extraction.get("category_suggestion") or "uncategorized"
    invoice_date = extraction.get("invoice_date") or date.today().isoformat()
    try:
        expense.expense_date = date.fromisoformat(invoice_date)
    except ValueError:
        expense.expense_date = date.today()

    expense.vendor_name = vendor_name
    expense.category = category
    expense.description = extraction.get("notes") or f"Invoice {extraction.get('invoice_number') or ''}".strip()
    expense.amount_gross = Decimal(str(extraction.get("amount_gross") or 0))
    expense.gst_amount = Decimal(str(extraction.get("gst_amount") or 0))
    amount_net = extraction.get("amount_net")
    if amount_net in (None, ""):
        amount_net = expense.amount_gross - expense.gst_amount
    expense.amount_net = Decimal(str(amount_net or 0))
    expense.currency = extraction.get("currency") or "NZD"
    expense.notes = extraction.get("notes") or None
    expense.tax_mode = "gst_15"
    expense.custom_tax_rate = None
    expense.gst_amount_overridden = False
    expense.gst_amount_override = None


async def process_receipt_upload(
    db,
    filename: str,
    content_type: str,
    file_bytes: bytes,
    upload_user: str | None = None,
    source_channel: str = "web",
    extract_func=None,
) -> dict:
    """Store an uploaded receipt, run OCR and AI extraction, and save a draft expense.

    The upload is normalised to disk first. A ``ReceiptAsset`` and a draft
    ``Expense`` are then created and flushed. OCR and extraction run inside a
    single ``try`` block: on success the page files are renamed after the
    extracted supplier, invoice number and date, and the extraction is applied
    to the expense. On any failure the receipt is marked ``"ocr_failed"`` (no
    OCR result yet) or ``"ai_failed"``, and the error text is stored on both
    the receipt and the expense notes. The session is committed either way.

    Args:
        db: Database session providing ``add``, ``flush``, ``commit`` and
            ``refresh``.
        filename: Client-supplied file name; ``"receipt.bin"`` is used if empty.
        content_type: Client-supplied media type; ``"application/octet-stream"``
            is used if empty.
        file_bytes: Raw content of the upload.
        upload_user: Optional identifier stored on the receipt.
        source_channel: Label recorded in the error payload when processing fails.
        extract_func: Optional awaitable replacing the default extractor. It is
            called with the combined OCR text, the original file name and the
            page count.

    Returns:
        A dict with ``receipt_id``, ``expense_id``, ``processing_status`` and
        ``storage_root``.

    Raises:
        Exception: Anything raised while normalising the upload, since that
            step runs before the ``try`` block.
    """
    from app.core.models import Expense, ReceiptAsset
    from app.services.ai_service import extract_receipt_data

    normalized = normalize_receipt_upload(
        filename or "receipt.bin",
        content_type or "application/octet-stream",
        file_bytes,
    )

    receipt = ReceiptAsset(
        original_filename=normalized["original_filename"],
        storage_path_original=normalized["storage_path_original"],
        storage_path_preview=normalized["storage_path_preview"],
        mime_type=normalized["mime_type"],
        file_size_original=normalized["file_size_original"],
        file_size_preview=normalized["file_size_preview"],
        image_width=normalized["image_width"],
        image_height=normalized["image_height"],
        upload_user=upload_user,
        processing_status="normalizing",
    )
    db.add(receipt)
    # Flush so receipt.id is available for the expense below.
    db.flush()

    expense = Expense(
        receipt_asset_id=receipt.id,
        expense_date=date.today(),
        vendor_name="Pending review",
        category="uncategorized",
        source="ai_receipt",
        status="draft",
    )
    db.add(expense)
    db.flush()

    page_paths = [page["path"] for page in normalized["pages"]]
    ocr_result = None
    extractor = extract_func or extract_receipt_data
    try:
        ocr_result = run_ocr_on_pages(page_paths)
        receipt.processing_status = "ocr_complete"
        receipt.ocr_snapshot_json = snapshot_json(ocr_result)

        extraction = await extractor(
            ocr_result.get("combined_text", ""),
            normalized["original_filename"],
            len(normalized["pages"]),
        )
        renamed_paths = rename_archive_pages(
            page_paths,
            extraction.get("vendor_name"),
            extraction.get("invoice_number"),
            extraction.get("invoice_date"),
        )
        for page, renamed_path in zip(normalized["pages"], renamed_paths, strict=False):
            page["path"] = renamed_path

        # Rebind ocr_result to a copy carrying the post-rename paths. If a
        # later step fails, the except handler then stores paths that still
        # exist on disk instead of the pre-rename ones.
        ocr_result = {
            **ocr_result,
            "pages": [
                {**page, "path": renamed_path}
                for page, renamed_path in zip(ocr_result.get("pages", []), renamed_paths, strict=False)
            ],
        }

        receipt.storage_path_preview = renamed_paths[0]
        receipt.file_size_preview = Path(renamed_paths[0]).stat().st_size
        receipt.processing_status = "ai_complete"
        receipt.ai_extraction_snapshot_json = snapshot_json(extraction)
        receipt.ocr_snapshot_json = snapshot_json(ocr_result)
        _apply_ai_extraction_to_expense(expense, extraction)
        if extraction.get("warnings"):
            warning_text = "; ".join(str(item) for item in extraction["warnings"])
            expense.notes = f"{expense.notes or ''}\nAI warnings: {warning_text}".strip()
    except Exception as exc:  # noqa: BLE001
        # Deliberately broad: a processing failure must not lose the upload,
        # so it is recorded on the records and the draft is still committed.
        receipt.processing_status = "ai_failed" if ocr_result else "ocr_failed"
        warnings = {"error": str(exc), "source_channel": source_channel}
        if ocr_result:
            receipt.ocr_snapshot_json = snapshot_json(ocr_result)
        receipt.ai_extraction_snapshot_json = snapshot_json(warnings)
        expense.notes = f"Receipt processing warning: {exc}"

    db.commit()
    db.refresh(receipt)
    db.refresh(expense)
    return {
        "receipt_id": receipt.id,
        "expense_id": expense.id,
        "processing_status": receipt.processing_status,
        "storage_root": str(UPLOAD_ROOT),
    }
