from datetime import datetime
from decimal import Decimal
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, ConfigDict

from sqlalchemy.orm import Session

from app.core.database import get_db
from app.core.models import Product, ProductType, ProductStatus
from app.core.security import get_current_user
# Router for all product endpoints: every route below is mounted under
# /api/products and tagged "products" (FastAPI uses tags to group operations
# in the generated OpenAPI schema).
router = APIRouter(prefix="/api/products", tags=["products"])


# ── Pydantic Models ────────────────────────────────────────────────────────────

class ProductBase(BaseModel):
    # Fields shared by the create payload (ProductCreate) and the response
    # body (ProductResponse).
    # product_type and status are plain strings at the API boundary; the
    # create/update endpoints convert them to ProductType / ProductStatus and
    # reject unknown values with HTTP 400.
    name: str
    description: str | None = None
    product_type: str = "project"
    unit_price: Decimal = Decimal("0")
    # NOTE(review): default 0.1500 presumably means a 15% rate — confirm the
    # intended unit (fraction vs percent) with the billing/invoice code.
    tax_rate: Decimal = Decimal("0.1500")
    unit: str = "unit"
    status: str = "active"


class ProductCreate(ProductBase):
    # Request body for POST /api/products; currently identical to ProductBase
    # (kept as its own type so the create payload can diverge later).
    pass


class ProductUpdate(BaseModel):
    # Request body for PUT /api/products/{product_id}. Every field is optional;
    # update_product dumps it with exclude_unset=True, so only the fields the
    # client actually sent are considered for the update.
    name: str | None = None
    description: str | None = None
    product_type: str | None = None
    unit_price: Decimal | None = None
    tax_rate: Decimal | None = None
    unit: str | None = None
    status: str | None = None


class ProductResponse(ProductBase):
    # Response body for a single product: the shared ProductBase fields plus
    # the database-managed id and timestamps.
    id: int
    created_at: datetime
    updated_at: datetime | None = None

    # from_attributes allows validating directly from ORM objects.
    # arbitrary_types_allowed is intentionally NOT set: every field type here
    # (int, str, Decimal, datetime) is natively supported by pydantic, and the
    # misspelled "arbitrary_allowed_types" key that used to sit here was
    # silently ignored by pydantic anyway.
    model_config = ConfigDict(from_attributes=True)


class PaginatedProducts(BaseModel):
    # One page of results returned by list_products.
    items: list[ProductResponse]
    total: int  # number of matching rows before offset/limit are applied
    page: int  # 1-based page number that was requested
    page_size: int  # maximum number of items per page


# ── Helpers ────────────────────────────────────────────────────────────────────

def _product_to_response(product: Product) -> ProductResponse:
    """Build the API representation of a ``Product`` ORM row.

    Enum-typed columns are flattened to their ``.value`` (anything without a
    ``.value`` attribute is passed through untouched). Monetary columns are
    rebuilt as ``Decimal`` via ``str()`` so a float coming back from the
    driver does not carry binary-float artefacts into the Decimal.
    """

    def unwrap(member):
        # Enum member -> underlying value; plain values are returned as-is.
        return member.value if hasattr(member, "value") else member

    payload = {
        "id": product.id,
        "name": product.name,
        "description": product.description,
        "product_type": unwrap(product.product_type),
        "unit_price": Decimal(str(product.unit_price)),
        "tax_rate": Decimal(str(product.tax_rate)),
        "unit": product.unit,
        "status": unwrap(product.status),
        "created_at": product.created_at,
        # Tolerate models/rows that have no updated_at attribute at all.
        "updated_at": getattr(product, "updated_at", None),
    }
    return ProductResponse(**payload)


def _parse_enum(enum_cls, value: str, field: str):
    """Look up *value* in *enum_cls* by value, or fail with HTTP 400.

    Args:
        enum_cls: Enum class whose members are looked up by value.
        value: Raw string supplied by the client.
        field: Field name used in the error message.

    Returns:
        The matching member of *enum_cls*.

    Raises:
        HTTPException: 400 listing every allowed value when *value* does not
            match any member.
    """
    try:
        return enum_cls(value)
    except ValueError:
        allowed = ", ".join(item.value for item in enum_cls)
        # ``from None``: the ValueError is an expected lookup miss, not a bug,
        # so don't chain it onto the HTTP error as implicit context.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid {field}. Allowed values: {allowed}",
        ) from None


def _parse_product_type(value: str) -> ProductType:
    """Convert a client-supplied product type string; HTTP 400 if unknown."""
    return _parse_enum(ProductType, value, "product_type")


def _parse_product_status(value: str) -> ProductStatus:
    """Convert a client-supplied status string; HTTP 400 if unknown."""
    return _parse_enum(ProductStatus, value, "status")


# ── Endpoints ─────────────────────────────────────────────────────────────────

@router.get("", response_model=PaginatedProducts)
def list_products(
    current_user: Annotated[dict, Depends(get_current_user)],
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=200),
    search: str = Query("", max_length=200),
    product_type: str = Query(""),
    status: str = Query("", alias="status"),
    db: Session = Depends(get_db),
):
    """List products with pagination and optional search/filters.

    ``search`` is a case-insensitive substring match on name or description
    (wildcard characters are matched literally). ``product_type`` and
    ``status`` filter on exact values; an unknown value is rejected with 400.
    When ``status`` is empty, archived products are excluded.
    """
    # NOTE: the ``status`` parameter shadows the imported ``fastapi.status``
    # module inside this function; HTTP codes are produced by the module-level
    # parse helpers instead.
    query = db.query(Product)

    if search:
        # Escape LIKE metacharacters so user input is matched literally
        # (e.g. "50%" or "a_b" must not behave as wildcards). "/" is used as
        # the escape character to avoid backslash-quoting differences
        # between database dialects.
        escaped = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        term = f"%{escaped}%"
        query = query.filter(
            Product.name.ilike(term, escape="/")
            | Product.description.ilike(term, escape="/")
        )

    if product_type:
        # Unknown values raise 400 rather than silently returning an
        # unfiltered list.
        query = query.filter(
            Product.product_type == _parse_product_type(product_type)
        )

    if status:
        # Validated explicitly: silently ignoring a bad value here would also
        # drop the archived exclusion below and leak archived products.
        query = query.filter(Product.status == _parse_product_status(status))
    else:
        # Default: hide archived (soft-deleted) products.
        query = query.filter(Product.status != ProductStatus.archived)

    # Count before offset/limit so ``total`` reflects all matching rows.
    total = query.count()
    items = (
        query.order_by(Product.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )

    return PaginatedProducts(
        items=[_product_to_response(p) for p in items],
        total=total,
        page=page,
        page_size=page_size,
    )


@router.post("", response_model=ProductResponse, status_code=status.HTTP_201_CREATED)
def create_product(
    data: ProductCreate,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Create a new product or service and return the stored record."""
    # Validate the enum strings first (product type before status); either
    # helper raises HTTP 400 on an unknown value.
    kind = _parse_product_type(data.product_type)
    lifecycle = _parse_product_status(data.status)

    new_product = Product(
        name=data.name,
        description=data.description,
        product_type=kind,
        unit_price=data.unit_price,
        tax_rate=data.tax_rate,
        unit=data.unit,
        status=lifecycle,
    )
    db.add(new_product)
    db.commit()
    # Reload from the database so DB-populated columns such as id are set.
    db.refresh(new_product)
    return _product_to_response(new_product)


@router.get("/{product_id}", response_model=ProductResponse)
def get_product(
    product_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Fetch a single product by its ID (archived products included)."""
    found = db.get(Product, product_id)
    if found:
        return _product_to_response(found)
    raise HTTPException(status_code=404, detail="Product not found")


@router.put("/{product_id}", response_model=ProductResponse)
def update_product(
    product_id: int,
    data: ProductUpdate,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Update an existing product with the fields supplied in the body.

    Only fields the client sent are applied. An explicit ``null`` is honoured
    for ``description`` only; for every other field it is ignored.
    """
    product = db.get(Product, product_id)
    if not product:
        raise HTTPException(status_code=404, detail="Product not found")

    # Every other field is non-optional in ProductResponse; writing None there
    # would store a row that can no longer be serialized (e.g.
    # Decimal(str(None)) raises), breaking this and every later read of it.
    nullable_fields = frozenset({"description"})
    update_data = {
        field: value
        for field, value in data.model_dump(exclude_unset=True).items()
        if value is not None or field in nullable_fields
    }

    # Enum fields arrive as strings: validate/convert them BEFORE touching the
    # ORM object so an invalid value (HTTP 400) leaves it unmodified.
    if "product_type" in update_data:
        update_data["product_type"] = _parse_product_type(
            update_data["product_type"]
        )
    if "status" in update_data:
        update_data["status"] = _parse_product_status(update_data["status"])

    for field, value in update_data.items():
        setattr(product, field, value)

    db.commit()
    db.refresh(product)
    return _product_to_response(product)


@router.delete("/{product_id}")
def delete_product(
    product_id: int,
    current_user: Annotated[dict, Depends(get_current_user)],
    db: Session = Depends(get_db),
):
    """Soft-delete a product by marking it archived; the row is kept."""
    target = db.get(Product, product_id)
    if target:
        # Archiving an already-archived product is a harmless no-op.
        target.status = ProductStatus.archived
        db.commit()
        return {"message": "Product archived"}
    raise HTTPException(status_code=404, detail="Product not found")
