import json
import logging
import os

import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

from celery_app import app
from config import settings
from utils.image_utils import extract_color_histogram, load_equirectangular_image


_logger = logging.getLogger(__name__)

# Blend of the two distance components: colour dissimilarity dominates,
# capture-order separation acts as a tie-breaker.
_COLOR_WEIGHT = 0.7
_ORDER_WEIGHT = 0.3
# Clusters whose average-linkage distance exceeds this value are not merged.
_DISTANCE_THRESHOLD = 0.5
# Histogram length used only when no image at all could be loaded.
# NOTE(review): presumably matches extract_color_histogram's output size -- confirm.
_FALLBACK_HISTOGRAM_BINS = 512

# Fallback names for single-panorama groups, indexed by cluster label.
_ROOM_NAMES = (
    "Living Room", "Kitchen", "Bedroom 1", "Bedroom 2",
    "Bedroom 3", "Bathroom", "Hallway", "Balcony",
    "Dining Room", "Study", "Stairs", "Entry",
)


def _load_histograms(file_paths):
    """Return one flattened float32 colour histogram per path (zeros if the file is missing)."""
    loaded = []
    for file_path in file_paths:
        full_path = os.path.join(settings.upload_dir, file_path)
        if os.path.exists(full_path):
            image = load_equirectangular_image(full_path)
            hist = np.asarray(extract_color_histogram(image), dtype=np.float32)
            loaded.append(hist.ravel())
        else:
            _logger.warning("Panorama file not found, using empty histogram: %s", full_path)
            loaded.append(None)

    # Size the placeholder from a real histogram so the rows always stack
    # into a rectangular array, whatever bin count the extractor uses.
    bins = next((h.size for h in loaded if h is not None), _FALLBACK_HISTOGRAM_BINS)
    return np.array(
        [h if h is not None else np.zeros(bins, dtype=np.float32) for h in loaded]
    )


def _pairwise_distances(histograms, capture_orders):
    """Build the symmetric distance matrix blending colour correlation and capture-order gap."""
    if any(order is None for order in capture_orders):
        raise ValueError("capture_order is missing for one or more panoramas")

    hists = np.asarray(histograms, dtype=np.float64)
    hists = hists.reshape(len(hists), -1)

    # Pearson correlation between every pair of histograms.
    centered = hists - hists.mean(axis=1, keepdims=True)
    norms = np.sqrt(np.sum(centered ** 2, axis=1))
    denominator = np.outer(norms, norms)
    has_variance = denominator > 0

    correlation = np.zeros_like(denominator)
    np.divide(centered @ centered.T, denominator, out=correlation, where=has_variance)

    # Map correlation [-1, 1] onto similarity [0, 1]. Pairs involving a
    # constant histogram (e.g. a missing file) get similarity 0. Clipping at
    # both ends keeps rounding error from producing negative distances.
    color_sim = np.where(
        has_variance, np.clip((correlation + 1.0) / 2.0, 0.0, 1.0), 0.0
    )

    orders = np.asarray(capture_orders, dtype=np.float64)
    order_diff = np.abs(orders[:, None] - orders[None, :]) / len(orders)

    distances = _COLOR_WEIGHT * (1.0 - color_sim) + _ORDER_WEIGHT * order_diff
    # Precomputed-metric clustering expects an exactly symmetric matrix with
    # a zero diagonal.
    distances = (distances + distances.T) / 2.0
    np.fill_diagonal(distances, 0.0)
    return distances


def _build_groups(labels, pano_ids, capture_orders):
    """Turn cluster labels into suggestion dicts with a heuristic name and confidence."""
    groups = []
    for label in sorted(set(labels.tolist())):
        indices = np.flatnonzero(labels == label)
        size = len(indices)
        # Larger clusters are treated as more trustworthy, capped at 1.0.
        confidence = min(1.0, 0.5 + 0.1 * size)

        avg_order = np.mean([float(capture_orders[i]) for i in indices])
        if avg_order <= 1.0:
            # The earliest captures are assumed to be at the entrance.
            suggested_name = "Entry"
        elif size >= 3:
            suggested_name = "Living Room"
        elif size >= 2:
            suggested_name = "Hallway"
        else:
            suggested_name = _ROOM_NAMES[min(label, len(_ROOM_NAMES) - 1)]

        groups.append({
            "suggested_name": suggested_name,
            "pano_ids": [pano_ids[i] for i in indices],
            "confidence": float(confidence),
        })
    return groups


def _cluster_panoramas(panoramas):
    """Cluster (id, file_path, capture_order) rows into room groups; needs at least two rows."""
    pano_ids = [row[0] for row in panoramas]
    file_paths = [row[1] for row in panoramas]
    capture_orders = [row[2] for row in panoramas]

    dist_matrix = _pairwise_distances(_load_histograms(file_paths), capture_orders)

    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=_DISTANCE_THRESHOLD,
        metric="precomputed",
        linkage="average",
    )
    labels = clustering.fit_predict(dist_matrix)
    return _build_groups(labels, pano_ids, capture_orders)


def _save_room_suggestions(session, house_id, groups):
    """Replace the house's 'room_group' rows in ai_suggestions with *groups* and commit."""
    session.execute(
        text(
            "DELETE FROM ai_suggestions WHERE house_id = :house_id "
            "AND suggestion_type = 'room_group'"
        ),
        {"house_id": house_id},
    )

    insert_stmt = text(
        "INSERT INTO ai_suggestions (house_id, suggestion_type, "
        "data, confidence, status) "
        "VALUES (:house_id, 'room_group', :data, :confidence, 'pending')"
    )
    for group in groups:
        session.execute(
            insert_stmt,
            {
                "house_id": house_id,
                # Store real JSON rather than a Python repr. default=str is
                # deliberate: panorama ids may be non-JSON-native (e.g. UUID).
                "data": json.dumps(group, default=str),
                "confidence": group["confidence"],
            },
        )

    session.commit()


@app.task(bind=True)
def group_rooms(self, house_id: str):
    """Group a house's panoramas into rooms using color similarity + capture order.

    Loads every panorama of the house ordered by ``capture_order``, clusters
    them with average-linkage agglomerative clustering over a distance that
    mixes colour-histogram correlation and capture-order separation, then
    replaces the house's pending ``room_group`` rows in ``ai_suggestions``
    with the result. A house with a single panorama skips clustering but is
    persisted the same way, so stale suggestions never outlive a regroup.

    Args:
        self: The bound Celery task instance (unused).
        house_id: Identifier of the house whose panoramas are grouped.

    Returns:
        ``{"success": True, "groups": [...]}`` on success, where each group
        has ``suggested_name``, ``pano_ids`` and ``confidence``; otherwise
        ``{"success": False, "groups": [], "error": <message>}``. Failures
        are reported through the return value and never raised.
    """
    engine = create_engine(settings.database_url)

    try:
        with Session(engine) as session:
            panoramas = session.execute(
                text(
                    "SELECT id, file_path, capture_order FROM panoramas "
                    "WHERE house_id = :house_id ORDER BY capture_order ASC"
                ),
                {"house_id": house_id},
            ).fetchall()

            if not panoramas:
                return {
                    "success": False,
                    "groups": [],
                    "error": "No panoramas found for this house",
                }

            if len(panoramas) == 1:
                # Clustering needs at least two samples; one panorama is
                # trivially one room.
                groups = [
                    {
                        "suggested_name": "Room 1",
                        "pano_ids": [panoramas[0][0]],
                        "confidence": 1.0,
                    }
                ]
            else:
                groups = _cluster_panoramas(panoramas)

            _save_room_suggestions(session, house_id, groups)

    except Exception as e:
        # Task boundary: keep the traceback in the logs, report via result.
        _logger.exception("Room grouping failed for house %s", house_id)
        return {"success": False, "groups": [], "error": str(e)}
    finally:
        # The engine is created per invocation; release its connection pool.
        engine.dispose()

    return {
        "success": True,
        "groups": groups,
    }
