from PIL import Image
import numpy as np
import io


def load_equirectangular_image(path: str) -> Image.Image:
    """
    Load an equirectangular image from a file path as an RGB image.

    The source file is opened in a context manager so its handle is released
    as soon as the pixels have been decoded, even for multi-frame formats
    where Pillow would otherwise keep the file open. ``convert`` decodes the
    pixels and returns a new image, so the result remains usable after the
    file is closed.

    Args:
        path: Path of the image file to read.

    Returns:
        A new image in mode ``"RGB"``.

    Raises:
        OSError: If the file cannot be opened or identified as an image
            (Pillow raises ``FileNotFoundError`` / ``UnidentifiedImageError``,
            both ``OSError`` subclasses).
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def validate_equirectangular(image: Image.Image) -> bool:
    """
    Validate that the image has the correct aspect ratio for equirectangular (2:1 +/- 5%).

    The tolerance is inclusive: width / height ratios from 1.9 to 2.1 pass.
    The comparison uses exact integer arithmetic. A float comparison would
    reject sizes that sit exactly on a bound (e.g. 190x100 or 210x100),
    because ``2.1 - 2.0`` rounds to slightly more than ``0.1``.

    Args:
        image: Image whose ``size`` (width, height) is checked.

    Returns:
        True if the aspect ratio is within tolerance, False otherwise
        (including images with zero height).
    """
    width, height = image.size
    if height <= 0:
        return False
    # |width / height - 2| <= 0.05 * 2   <=>   10 * |width - 2 * height| <= height
    return 10 * abs(width - 2 * height) <= height


def generate_thumbnail(image: Image.Image, size: tuple = (800, 400)) -> Image.Image:
    """
    Resize an equirectangular image to a fixed thumbnail size.

    The image is scaled to exactly ``size`` (width, height) using Lanczos
    resampling. No cropping or aspect-ratio preservation is applied, so the
    default (800, 400) is intended for 2:1 panoramas.

    Args:
        image: Source image.
        size: Target (width, height) in pixels.

    Returns:
        The resized image returned by ``Image.resize``.
    """
    resample_filter = Image.LANCZOS
    thumbnail = image.resize(size, resample=resample_filter)
    return thumbnail


def extract_color_histogram(image: Image.Image, bins: int = 8) -> np.ndarray:
    """
    Extract a simple 3D color histogram (RGB) for similarity comparison.

    Each channel is quantized into ``bins`` equal-width buckets over 0..255,
    giving ``bins ** 3`` joint RGB cells. Large images are subsampled on a
    regular grid (stride ``min(h, w) // 256``, at least 1) for performance.

    Args:
        image: Source image. Non-RGB modes are converted to RGB first so the
            three channels are always R, G and B.
        bins: Buckets per channel, in 1..256. Defaults to 8 (512 cells).

    Returns:
        A flattened 1D float32 array of length ``bins ** 3``. The cell for
        bin triple (r, g, b) is at index ``(r * bins + g) * bins + b``, the
        same layout as flattening a C-ordered ``(bins, bins, bins)`` array.
        The histogram is normalized to sum to 1. An image with no pixels
        yields an all-zero array.

    Raises:
        ValueError: If ``bins`` is outside 1..256.
    """
    if not 1 <= bins <= 256:
        raise ValueError(f"bins must be in 1..256, got {bins}")

    # Guarantee an (H, W, 3) uint8 array: grayscale/palette images would
    # otherwise be 2D, and modes such as CMYK do not carry RGB channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_array = np.asarray(image)

    # Downsample for performance
    h, w = img_array.shape[:2]
    step = max(1, min(h, w) // 256)
    sampled = img_array[::step, ::step]

    # Widen before multiplying: in uint8, value * bins wraps around at 256
    # and sends bright pixels to the wrong bucket. For values 0..255,
    # value * bins // 256 is always within 0..bins-1, so no clipping is needed.
    quantized = sampled.astype(np.int64) * bins // 256
    r_bins = quantized[:, :, 0]
    g_bins = quantized[:, :, 1]
    b_bins = quantized[:, :, 2]

    # Joint cell index in C order; bincount counts all cells in one pass.
    flat_index = (r_bins * bins + g_bins) * bins + b_bins
    histogram = np.bincount(flat_index.ravel(), minlength=bins ** 3).astype(np.float32)

    # Normalize in place so the result stays float32.
    total = histogram.sum()
    if total > 0:
        histogram /= total

    return histogram
