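"""
MySQL access helpers built on pymysql, using thread-local persistent connections.

Each worker thread keeps ONE open connection for its entire lifetime, so the
connection setup (TCP handshake and MySQL login) is paid once per thread
instead of once per query. Before each use the cached connection is pinged
(one lightweight round trip), so a dropped connection is transparently
re-established.
"""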
import threading
import pymysql
from pymysql.cursors import DictCursor
from config import DB_CONFIG

# Per-thread storage: each thread sees its own ``conn`` attribute, which
# get_connection() fills lazily and close_thread_connection() clears.
_local: threading.local = threading.local()


def get_connection():
    """
    Return the current thread's persistent MySQL connection.

    The connection is created lazily on the first call from each thread and
    cached in thread-local storage. On later calls it is pinged with
    ``reconnect=True`` so a dropped connection is re-established in place.
    If that fails, the stale handle is closed (best effort) and replaced with
    a brand-new connection.

    Returns:
        A pymysql connection owned by the calling thread.

    Raises:
        Whatever ``_new_conn()`` raises if a fresh connection cannot be opened.
    """
    conn = getattr(_local, "conn", None)
    if conn is None:
        conn = _new_conn()
        _local.conn = conn
        return conn

    try:
        # One extra round trip per call, but it detects server-side idle
        # timeouts / restarts and reconnects transparently.
        conn.ping(reconnect=True)
    except Exception:
        # Reconnect-in-place failed: release the broken handle explicitly
        # rather than leaving it for the garbage collector.
        try:
            conn.close()
        except Exception:
            pass  # already closed / unusable; nothing more to release
        # Forget it before reconnecting so that, if _new_conn() raises, the
        # next call starts fresh instead of reusing the dead handle.
        _local.conn = None
        conn = _new_conn()
        _local.conn = conn
    return conn


def _new_conn():
    """
    Open a new pymysql connection configured from ``DB_CONFIG``.

    Rows come back as dicts (``DictCursor``), and every statement is committed
    immediately (``autocommit=True``).

    Required ``DB_CONFIG`` keys: host, port, user, password, database, charset.

    Optional ``DB_CONFIG`` keys:
        connect_timeout: seconds to wait while connecting (default 10, the
            previously hard-coded value).
        read_timeout / write_timeout: seconds to wait on socket reads/writes
            (default None, pymysql's "no timeout"). Setting these keeps a
            query or ping on a half-open connection from blocking forever.

    Raises:
        KeyError: if a required ``DB_CONFIG`` key is missing.
        pymysql errors if the server is unreachable or rejects the login.
    """
    # NOTE(review): assumes DB_CONFIG is a dict/Mapping (supports .get).
    return pymysql.connect(
        host=DB_CONFIG["host"],
        port=DB_CONFIG["port"],
        user=DB_CONFIG["user"],
        password=DB_CONFIG["password"],
        database=DB_CONFIG["database"],
        charset=DB_CONFIG["charset"],
        cursorclass=DictCursor,
        autocommit=True,
        connect_timeout=DB_CONFIG.get("connect_timeout", 10),
        read_timeout=DB_CONFIG.get("read_timeout"),
        write_timeout=DB_CONFIG.get("write_timeout"),
    )


def close_thread_connection():
    """
    Close and forget the calling thread's cached connection, if it has one.

    A no-op for threads that never opened a connection or already closed it.
    Errors raised while closing are ignored (best effort), and the cached
    handle is cleared afterwards, so the next get_connection() call from this
    thread opens a fresh connection.
    """
    existing = getattr(_local, "conn", None)
    if existing is None:
        return
    try:
        existing.close()
    except Exception:
        pass  # best effort: the connection may already be dead or closed
    _local.conn = None


# ── helpers ────────────────────────────────────────────────────────────────────

def execute_one(sql, args=None):
    """
    Run a single data-modifying statement (INSERT/UPDATE/DELETE).

    Args:
        sql: statement text with placeholders.
        args: values bound to the placeholders, or None.

    Returns:
        The affected-row count reported by ``cursor.execute``.
    """
    with get_connection().cursor() as cursor:
        affected = cursor.execute(sql, args)
    return affected


def query_one(sql, args=None):
    """
    Fetch a single row.

    Args:
        sql: SELECT statement with placeholders.
        args: values bound to the placeholders, or None.

    Returns:
        The first result row as a dict (connections use DictCursor), or None
        when the query matches nothing.
    """
    with get_connection().cursor() as cursor:
        cursor.execute(sql, args)
        first_row = cursor.fetchone()
    return first_row


def query_all(sql, args=None):
    """
    Return every matching row as a list of dicts.

    Args:
        sql: SELECT statement with placeholders.
        args: values bound to the placeholders, or None.

    Returns:
        A list with one dict per row (connections use DictCursor); an empty
        list when nothing matches.
    """
    conn = get_connection()
    with conn.cursor() as cur:
        cur.execute(sql, args)
        # pymysql's buffered cursors may return an empty *tuple* rather than
        # a list for an empty result set; normalise to the documented type.
        return list(cur.fetchall())


def insert_and_get_id(sql, args=None):
    """
    Run one INSERT and report the id it generated.

    Args:
        sql: INSERT statement with placeholders.
        args: values bound to the placeholders, or None.

    Returns:
        ``cursor.lastrowid`` read right after the statement runs.
    """
    with get_connection().cursor() as cursor:
        cursor.execute(sql, args)
        new_id = cursor.lastrowid
    return new_id


def insert_many(sql, rows):
    """
    Insert a batch of rows via ``cursor.executemany``.

    Args:
        sql: INSERT statement whose VALUES placeholders describe one row.
        rows: list of tuples; each tuple fills the placeholders for one row.

    Returns:
        The total row count reported by ``executemany``, or 0 when ``rows``
        is empty (no connection is fetched in that case).
    """
    if rows:
        with get_connection().cursor() as cursor:
            inserted = cursor.executemany(sql, rows)
        return inserted
    return 0


def update_many(sql, rows):
    """
    Apply one UPDATE statement once per parameter tuple via ``executemany``.

    Args:
        sql: UPDATE statement with placeholders.
        rows: list of tuples; each supplies the SET values followed by the
            WHERE arguments for one execution.

    Returns:
        The total affected-row count reported by ``executemany``, or 0 when
        ``rows`` is empty.

    NOTE(review): connections are opened with autocommit=True, so these
    updates are presumably not applied atomically. A failure part-way through
    may leave earlier rows committed; confirm callers can tolerate that.
    """
    if not rows:
        return 0  # nothing to do; skip fetching (and pinging) the connection
    with get_connection().cursor() as cursor:
        affected = cursor.executemany(sql, rows)
    return affected
