import cv2
import numpy as np
from PIL import Image
from typing import List, Tuple
from utils.image_utils import extract_color_histogram


def _orb_similarity(img1: np.ndarray, img2: np.ndarray) -> float:
    """
    Compute similarity score based on ORB feature matching.

    Each image is reduced to a single channel independently, so the two
    inputs need not share a layout:
    - 2-D arrays are used as-is.
    - A trailing single channel is dropped.
    - Other 3-D arrays are converted with ``cv2.COLOR_RGB2GRAY``.

    Up to 1000 ORB keypoints are detected per image. Descriptors are matched
    with a brute-force Hamming matcher (k=2) and filtered by Lowe's ratio
    test with a 0.75 threshold.

    Returns the number of surviving matches divided by the larger keypoint
    count, capped at 1.0. Returns 0.0 when:
    - either image yields no descriptors,
    - either image has fewer than 5 keypoints, or
    - no match survives the ratio test.
    """

    def to_gray(img: np.ndarray) -> np.ndarray:
        # Decided per image: checking only one input's shape would mis-handle
        # a colour/grayscale pair.
        if img.ndim != 3:
            return img
        if img.shape[2] == 1:
            return img[:, :, 0]
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    gray1 = to_gray(img1)
    gray2 = to_gray(img2)

    orb = cv2.ORB_create(nfeatures=1000)
    kp1, des1 = orb.detectAndCompute(gray1, None)
    kp2, des2 = orb.detectAndCompute(gray2, None)

    # With very few keypoints the match ratio is not meaningful.
    if des1 is None or des2 is None or len(kp1) < 5 or len(kp2) < 5:
        return 0.0

    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
    matches = bf.knnMatch(des1, des2, k=2)

    # Lowe's ratio test: keep a match only when it is clearly better than the
    # runner-up. knnMatch can return fewer than two neighbours per query.
    good_count = sum(
        1
        for pair in matches
        if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
    )
    if good_count == 0:
        return 0.0

    return min(good_count / max(len(kp1), len(kp2)), 1.0)


def _color_histogram_similarity(img1: np.ndarray, img2: np.ndarray) -> float:
    """
    Compute similarity based on color histogram correlation.

    Each array is wrapped with ``Image.fromarray`` and passed to
    ``extract_color_histogram``. The Pearson correlation of the two
    histograms is taken over all of their elements in float32. It is then
    mapped from [-1, 1] to [0, 1].

    Returns a builtin float in [0, 1]. Returns 0.0 when either histogram has
    zero variance, because the correlation is undefined in that case.
    """
    hist1 = extract_color_histogram(Image.fromarray(img1)).astype(np.float32)
    hist2 = extract_color_histogram(Image.fromarray(img2)).astype(np.float32)

    # Manual correlation to avoid OpenCV dependency quirks.
    centered1 = hist1 - np.mean(hist1)
    centered2 = hist2 - np.mean(hist2)

    denominator = np.sqrt(np.sum(centered1 ** 2)) * np.sqrt(np.sum(centered2 ** 2))
    if denominator == 0:
        return 0.0

    correlation = np.sum(centered1 * centered2) / denominator
    # float32 rounding can push the correlation marginally past +/-1 (e.g. for
    # identical histograms), so clamp both ends of the mapped score and return
    # a builtin float rather than a NumPy scalar.
    return float(np.clip((correlation + 1.0) / 2.0, 0.0, 1.0))


def compute_similarity(img1_path: str, img2_path: str) -> float:
    """
    Compute overall similarity between two images.

    Both files are opened with Pillow and converted to RGB arrays. The
    result combines color histogram correlation (50%) and ORB feature
    matching (50%).

    Returns a builtin float between 0 and 1.
    """
    # Context managers close the file handles deterministically, including
    # when decoding or conversion raises.
    with Image.open(img1_path) as src1, Image.open(img2_path) as src2:
        img1 = np.array(src1.convert("RGB"))
        img2 = np.array(src2.convert("RGB"))

    color_score = _color_histogram_similarity(img1, img2)
    orb_score = _orb_similarity(img1, img2)

    # Equal weighting of the two signals. float() guarantees the annotated
    # return type even if a component score is a NumPy scalar.
    return float(0.5 * color_score + 0.5 * orb_score)


def find_best_matches(
    target_path: str, candidate_paths: List[str], top_k: int = 5
) -> List[Tuple[str, float]]:
    """
    Rank candidate images by their similarity to a target image.

    Any candidate whose path string equals ``target_path`` is skipped. Each
    remaining candidate is scored with ``compute_similarity`` in input order.

    Returns ``ranked[:top_k]`` of the ``(path, score)`` pairs, highest score
    first. Equal scores keep their original candidate order, because the
    sort is stable.
    """
    ranked = sorted(
        (
            (path, compute_similarity(target_path, path))
            for path in candidate_paths
            if path != target_path
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:top_k]


def _batch_centered_histogram(hist: np.ndarray) -> Tuple[np.ndarray, float]:
    """Return ``hist`` as mean-centred float32 data plus its L2 norm."""
    hist_f = hist.astype(np.float32)
    centered = hist_f - np.mean(hist_f)
    return centered, np.sqrt(np.sum(centered ** 2))


def _batch_orb_score(features1, features2, matcher) -> float:
    """Score two precomputed ``(keypoints, descriptors)`` pairs via Lowe's ratio test."""
    kp1, des1 = features1
    kp2, des2 = features2
    if des1 is None or des2 is None or len(kp1) < 5 or len(kp2) < 5:
        return 0.0
    good_count = 0
    for pair in matcher.knnMatch(des1, des2, k=2):
        # knnMatch can return fewer than two neighbours for a query.
        if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance:
            good_count += 1
    if good_count == 0:
        return 0.0
    return min(good_count / max(len(kp1), len(kp2)), 1.0)


def batch_similarity_matrix(image_paths: List[str]) -> np.ndarray:
    """
    Compute a similarity matrix for a list of images.

    Entry [i][j] is 0.5 * colour-histogram score + 0.5 * ORB score for images
    i and j:
    - Colour score: Pearson correlation of the histograms, mapped to [0, 1]
      and clamped.
    - ORB score: ratio-test matches divided by the larger keypoint count,
      capped at 1.0.

    Each image is loaded, histogrammed and feature-extracted exactly once.

    Returns a symmetric float32 array of shape (n, n) whose diagonal is 1.0.
    """
    n_images = len(image_paths)
    matrix = np.eye(n_images, dtype=np.float32)

    # Pre-load all images as RGB arrays, closing each file promptly.
    images = []
    for path in image_paths:
        with Image.open(path) as pil_img:
            images.append(np.array(pil_img.convert("RGB")))

    # Centre and normalise each histogram once instead of once per pair.
    hist_stats = [
        _batch_centered_histogram(extract_color_histogram(Image.fromarray(img)))
        for img in images
    ]

    orb = cv2.ORB_create(nfeatures=1000)
    orb_features = []
    for img in images:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if img.ndim == 3 else img
        orb_features.append(orb.detectAndCompute(gray, None))

    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)

    for i in range(n_images):
        centered_i, norm_i = hist_stats[i]
        for j in range(i + 1, n_images):
            centered_j, norm_j = hist_stats[j]
            denominator = norm_i * norm_j
            if denominator == 0:
                color_score = 0.0
            else:
                correlation = np.sum(centered_i * centered_j) / denominator
                # float32 rounding can push the mapped score slightly past 1.
                color_score = float(np.clip((correlation + 1.0) / 2.0, 0.0, 1.0))

            orb_score = _batch_orb_score(orb_features[i], orb_features[j], bf)

            score = 0.5 * color_score + 0.5 * orb_score
            matrix[i, j] = score
            matrix[j, i] = score

    return matrix
