# ============================================================
# Security module — JWT / password hashing
# ============================================================
import os
import warnings
from datetime import datetime, timedelta, timezone
from typing import Any

import bcrypt
from jose import JWTError, jwt


# -------------------- Password hashing --------------------
def hash_password(password: str) -> str:
    """
    Produce a bcrypt hash for *password*.

    The password is UTF-8 encoded, hashed with a salt obtained from
    ``bcrypt.gensalt()`` on every call, and the resulting hash bytes are
    decoded back to ``str`` so they can be stored in a text column.
    """
    secret_bytes = password.encode("utf-8")
    hashed_bytes = bcrypt.hashpw(secret_bytes, bcrypt.gensalt())
    return hashed_bytes.decode("utf-8")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check whether *plain_password* matches a stored bcrypt hash.

    Args:
        plain_password: Candidate password supplied by the user.
        hashed_password: Stored hash, as produced by ``hash_password``.

    Returns:
        True when the password matches. False on a mismatch, and also when
        the stored hash is empty/None or is not a valid bcrypt hash, so that
        a corrupt or missing record yields a failed check instead of an
        unhandled exception.
    """
    # Accounts without a stored hash (None or "") can never authenticate.
    if not hashed_password:
        return False
    try:
        return bcrypt.checkpw(
            plain_password.encode("utf-8"),
            hashed_password.encode("utf-8"),
        )
    except ValueError:
        # bcrypt.checkpw raises ValueError (e.g. "Invalid salt") when the
        # stored value is not a well-formed bcrypt hash.
        return False


# -------------------- JWT helpers --------------------
def _env_or_default(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* if it is unset or blank.

    A variable that is set but empty/whitespace-only (common with .env files)
    is treated as unset. Otherwise the empty string would become the HMAC key,
    an empty algorithm name, or crash ``int()`` at import time. Non-blank
    values are returned unmodified (not stripped), so existing secrets keep
    working.
    """
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    return value


# Publicly known fallback secret: tokens signed with it can be forged by anyone
# who has seen this source, so its use is flagged below.
_DEFAULT_JWT_SECRET = "erp-super-secret-key-change-in-production"

JWT_SECRET = _env_or_default("JWT_SECRET", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = _env_or_default("JWT_ALGORITHM", "HS256")
JWT_EXPIRE_MINUTES = int(_env_or_default("JWT_EXPIRE_MINUTES", "1440"))  # 24 hours default

if JWT_SECRET == _DEFAULT_JWT_SECRET:
    warnings.warn(
        "JWT_SECRET is not set; using the insecure built-in default secret. "
        "Set the JWT_SECRET environment variable in production.",
        RuntimeWarning,
    )

# Short alias used by the encode/decode helpers.
ALGORITHM = JWT_ALGORITHM


def create_jwt_token(payload: dict[str, Any]) -> str:
    """
    Create a signed JWT token.

    Args:
        payload: Must contain at least {"sub": user_id, "username": ..., "role": ...}.
            The caller's dict is not modified. A non-string "sub" (e.g. an
            integer user id) is converted with ``str()``, because python-jose's
            ``jwt.decode`` rejects tokens whose subject is not a string.

    Returns:
        The encoded token, signed with JWT_SECRET using ALGORITHM. "exp" is set
        to JWT_EXPIRE_MINUTES after "iat", which is the current UTC time; any
        "exp"/"iat" already present in *payload* are overwritten.
    """
    to_encode = dict(payload)

    subject = to_encode.get("sub")
    if subject is not None and not isinstance(subject, str):
        # RFC 7519 defines "sub" as a StringOrURI; python-jose enforces this on
        # decode, so a non-string subject would make the token undecodable.
        to_encode["sub"] = str(subject)

    # Single timestamp so that exp - iat equals the configured lifetime exactly.
    issued_at = datetime.now(timezone.utc)
    to_encode["exp"] = issued_at + timedelta(minutes=JWT_EXPIRE_MINUTES)
    to_encode["iat"] = issued_at
    return jwt.encode(to_encode, JWT_SECRET, algorithm=ALGORITHM)


def decode_jwt_token(token: str) -> dict[str, Any]:
    """
    Verify *token* with JWT_SECRET / ALGORITHM and return its claims.

    Raises:
        fastapi.HTTPException: status 401, detail "Invalid or expired token",
            with a ``WWW-Authenticate: Bearer`` header, whenever
            ``jwt.decode`` raises ``JWTError``.
    """
    # Imported lazily, as before, so fastapi is only required when decoding.
    from fastapi import HTTPException

    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return claims