﻿"""
Servicio de cache inteligente para enriquecimiento de productos.
Reduce consultas a BD y mejora tiempos de respuesta.
"""

import hashlib
import json
import logging
import os
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session

from app.utils.datetime_utils import utc_now

logger = logging.getLogger(__name__)

# Redis es opcional - no disponible en modo local (Windows desktop)
try:
    import redis.asyncio as aioredis
    from redis.exceptions import RedisError
    REDIS_AVAILABLE = True
except ImportError:
    aioredis = None  # type: ignore
    RedisError = Exception  # type: ignore
    REDIS_AVAILABLE = False
    logger.info("[CACHE] Redis not available - cache disabled (local mode)")


class EnrichmentCache:
    """
    Cache inteligente para enriquecimiento de productos usando Redis.

    Características:
    - TTL configurable (24 horas por defecto)
    - Invalidación selectiva por código nacional
    - Métricas de hit/miss
    - Fallback a BD si Redis no está disponible
    """

    def __init__(self, redis_url: str = "redis://redis:6379"):
        self.redis_url = redis_url
        self.redis = None  # aioredis.Redis or None
        self.ttl = 86400  # 24 horas por defecto
        # Desactivar cache si Redis no está disponible (modo local Windows)
        self.enabled = REDIS_AVAILABLE and not os.getenv("KAIFARMA_LOCAL", "").lower() == "true"

        # Métricas
        self.stats = {"hits": 0, "misses": 0, "errors": 0, "invalidations": 0}

    async def connect(self):
        """Conecta con Redis de forma asíncrona."""
        if not REDIS_AVAILABLE or aioredis is None:
            logger.info("[CACHE] Redis module not available - skipping connection")
            self.enabled = False
            return

        try:
            self.redis = await aioredis.from_url(
                self.redis_url,
                decode_responses=True,
                max_connections=10,
                socket_connect_timeout=5,
            )
            await self.redis.ping()
            self.enabled = True
            logger.info("[CACHE] Conectado a Redis para cache de enriquecimiento")
        except Exception as e:
            logger.warning(f"[CACHE] Redis no disponible, cache deshabilitado: {e}")
            self.enabled = False

    async def disconnect(self):
        """Desconecta de Redis."""
        if self.redis:
            await self.redis.close()

    def _get_cache_key(self, product_code: str) -> str:
        """Genera la clave de cache para un producto."""
        return f"enrichment:product:{product_code}"

    def _get_batch_cache_key(self, product_codes: List[str]) -> str:
        """Genera clave de cache para un batch de productos."""
        codes_hash = hashlib.md5(",".join(sorted(product_codes)).encode(), usedforsecurity=False).hexdigest()
        return f"enrichment:batch:{codes_hash}"

    async def get_enriched_product(self, product_code: str) -> Optional[Dict[str, Any]]:
        """
        Obtiene datos enriquecidos de un producto desde cache.

        Args:
            product_code: Código nacional del producto

        Returns:
            Datos enriquecidos o None si no está en cache
        """
        if not self.enabled or not self.redis:
            return None

        try:
            key = self._get_cache_key(product_code)
            cached_data = await self.redis.get(key)

            if cached_data:
                self.stats["hits"] += 1
                logger.debug(f"[CACHE] Hit para producto {product_code}")
                return json.loads(cached_data)
            else:
                self.stats["misses"] += 1
                logger.debug(f"[CACHE] Miss para producto {product_code}")
                return None

        except Exception as e:
            self.stats["errors"] += 1
            logger.error(f"[CACHE] Error obteniendo producto {product_code}: {e}")
            return None

    async def set_enriched_product(
        self,
        product_code: str,
        enriched_data: Dict[str, Any],
        ttl: Optional[int] = None,
    ):
        """
        Guarda datos enriquecidos en cache.

        Args:
            product_code: Código nacional del producto
            enriched_data: Datos enriquecidos para cachear
            ttl: Tiempo de vida en segundos (opcional)
        """
        if not self.enabled or not self.redis:
            return

        try:
            key = self._get_cache_key(product_code)
            ttl = ttl or self.ttl

            # Agregar timestamp de cache
            enriched_data["cached_at"] = utc_now().isoformat()

            await self.redis.setex(key, ttl, json.dumps(enriched_data, default=str))

            logger.debug(f"[CACHE] Producto {product_code} cacheado por {ttl}s")

        except Exception as e:
            self.stats["errors"] += 1
            logger.error(f"[CACHE] Error cacheando producto {product_code}: {e}")

    async def get_enriched_batch(self, product_codes: List[str]) -> Dict[str, Dict[str, Any]]:
        """
        Obtiene múltiples productos enriquecidos del cache.

        Args:
            product_codes: Lista de códigos nacionales

        Returns:
            Dict con códigos como keys y datos enriquecidos como values
        """
        if not self.enabled or not self.redis:
            return {}

        try:
            # Crear pipeline para obtener múltiples valores
            pipe = self.redis.pipeline()
            keys = [self._get_cache_key(code) for code in product_codes]

            for key in keys:
                pipe.get(key)

            results = await pipe.execute()

            # Mapear resultados
            cached_products = {}
            for code, result in zip(product_codes, results):
                if result:
                    self.stats["hits"] += 1
                    cached_products[code] = json.loads(result)
                else:
                    self.stats["misses"] += 1

            logger.info(f"[CACHE] Batch: {len(cached_products)}/{len(product_codes)} hits")
            return cached_products

        except Exception as e:
            self.stats["errors"] += 1
            logger.error(f"[CACHE] Error en batch get: {e}")
            return {}

    async def set_enriched_batch(self, products: Dict[str, Dict[str, Any]], ttl: Optional[int] = None):
        """
        Guarda múltiples productos enriquecidos en cache.

        Args:
            products: Dict con código como key y datos como value
            ttl: Tiempo de vida en segundos
        """
        if not self.enabled or not self.redis:
            return

        try:
            pipe = self.redis.pipeline()
            ttl = ttl or self.ttl
            timestamp = utc_now().isoformat()

            for code, data in products.items():
                key = self._get_cache_key(code)
                data["cached_at"] = timestamp
                pipe.setex(key, ttl, json.dumps(data, default=str))

            await pipe.execute()
            logger.info(f"[CACHE] Batch: {len(products)} productos cacheados")

        except Exception as e:
            self.stats["errors"] += 1
            logger.error(f"[CACHE] Error en batch set: {e}")

    async def invalidate_product(self, product_code: str):
        """
        Invalida el cache de un producto específico.

        Args:
            product_code: Código nacional del producto
        """
        if not self.enabled or not self.redis:
            return

        try:
            key = self._get_cache_key(product_code)
            deleted = await self.redis.delete(key)

            if deleted:
                self.stats["invalidations"] += 1
                logger.info(f"[CACHE] Invalidado producto {product_code}")

        except Exception as e:
            logger.error(f"[CACHE] Error invalidando {product_code}: {e}")

    async def get(self, key: str) -> Optional[Dict[str, Any]]:
        """
        Get generic cached value by key.

        Args:
            key: Cache key

        Returns:
            Cached data or None if not found/expired
        """
        if not self.enabled or not self.redis:
            return None

        try:
            cached_data = await self.redis.get(key)

            if cached_data:
                self.stats["hits"] += 1
                logger.debug(f"[CACHE] Hit for key {key}")
                return json.loads(cached_data)
            else:
                self.stats["misses"] += 1
                logger.debug(f"[CACHE] Miss for key {key}")
                return None

        except Exception as e:
            self.stats["errors"] += 1
            logger.error(f"[CACHE] Error getting key {key}: {e}")
            return None

    async def set(self, key: str, value: Dict[str, Any], ttl: Optional[int] = None):
        """
        Set generic cached value with TTL.

        Args:
            key: Cache key
            value: Data to cache
            ttl: Time to live in seconds (optional, defaults to self.ttl)
        """
        if not self.enabled or not self.redis:
            return

        try:
            ttl = ttl or self.ttl

            # Add timestamp for debugging
            value_with_meta = value.copy() if isinstance(value, dict) else value
            if isinstance(value_with_meta, dict):
                value_with_meta["cached_at"] = utc_now().isoformat()

            await self.redis.setex(key, ttl, json.dumps(value_with_meta, default=str))
            logger.debug(f"[CACHE] Key {key} cached for {ttl}s")

        except Exception as e:
            self.stats["errors"] += 1
            logger.error(f"[CACHE] Error setting key {key}: {e}")

    async def delete(self, key: str):
        """
        Delete specific cached key.

        Args:
            key: Cache key to delete
        """
        if not self.enabled or not self.redis:
            return

        try:
            deleted = await self.redis.delete(key)

            if deleted:
                self.stats["invalidations"] += 1
                logger.info(f"[CACHE] Key {key} invalidated")

        except Exception as e:
            logger.error(f"[CACHE] Error deleting key {key}: {e}")

    async def invalidate_all(self):
        """
        Invalida todo el cache de enriquecimiento.
        Útil después de actualizar el catálogo.
        """
        if not self.enabled or not self.redis:
            return

        try:
            # Buscar todas las claves de enriquecimiento
            pattern = "enrichment:*"
            cursor = 0
            deleted_count = 0

            while True:
                cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)

                if keys:
                    deleted_count += await self.redis.delete(*keys)

                if cursor == 0:
                    break

            self.stats["invalidations"] += deleted_count
            logger.info(f"[CACHE] Invalidadas {deleted_count} entradas")

        except Exception as e:
            logger.error(f"[CACHE] Error invalidando todo: {e}")

    async def warmup_cache(self, db: Session, limit: int = 1000):
        """
        Pre-calienta el cache con los productos más usados.

        Args:
            db: Sesión de base de datos
            limit: Número máximo de productos a cachear
        """
        if not self.enabled or not self.redis:
            return

        try:
            from sqlalchemy import func

            from app.models.product_catalog import ProductCatalog
            from app.models.sales_data import SalesData

            # Obtener los productos más vendidos
            top_products = (
                db.query(SalesData.codigo_nacional, func.count(SalesData.id).label("count"))
                .filter(SalesData.codigo_nacional.isnot(None))
                .group_by(SalesData.codigo_nacional)
                .order_by(func.count(SalesData.id).desc())
                .limit(limit)
                .all()
            )

            logger.info(f"[CACHE] Calentando cache con {len(top_products)} productos")

            # Cachear información del catálogo para estos productos
            for product_code, _ in top_products:
                catalog_data = db.query(ProductCatalog).filter(ProductCatalog.national_code == product_code).first()

                if catalog_data:
                    await self.set_enriched_product(product_code, catalog_data.to_dict())

            logger.info("[CACHE] Cache precalentado exitosamente")

        except Exception as e:
            logger.error(f"[CACHE] Error en warmup: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """
        Obtiene estadísticas del cache.

        Returns:
            Métricas de uso del cache
        """
        total = self.stats["hits"] + self.stats["misses"]
        hit_rate = (self.stats["hits"] / total * 100) if total > 0 else 0

        return {
            "enabled": self.enabled,
            "hits": self.stats["hits"],
            "misses": self.stats["misses"],
            "errors": self.stats["errors"],
            "invalidations": self.stats["invalidations"],
            "hit_rate": round(hit_rate, 2),
            "total_requests": total,
        }

    async def check_health(self) -> bool:
        """
        Verifica si el cache está funcionando correctamente.

        Returns:
            True si Redis está disponible y funcionando
        """
        if not self.enabled or not self.redis:
            return False

        try:
            await self.redis.ping()
            return True
        except (RedisError, AttributeError, ConnectionError, TimeoutError) as e:
            logger.debug(f"Redis ping failed: {e}")
            return False


# Instancia global del cache
enrichment_cache = EnrichmentCache()
