"""
Servicio de backfill de códigos ATC en background.

Este servicio extrae códigos ATC de la API de CIMA para productos que no los tienen,
procesando en batches pequeños para evitar timeouts y respetar rate limits.

Issue relacionado: #400 - Dashboard de Análisis de Ventas de Prescripción
Prerequisito: Códigos ATC necesarios para análisis terapéutico (treemap, drill-down)
"""

import asyncio
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID

import httpx
import structlog
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session

from app.database import SessionLocal
from app.models.product_catalog import ProductCatalog
from app.utils.datetime_utils import utc_now

logger = structlog.get_logger(__name__)


class ATCBackfillService:
    """
    Servicio de backfill de códigos ATC en background.

    Estrategia:
    1. Query productos CIMA sin ATC en batches
    2. Deduplicar por nregistro (múltiples CN pueden compartir mismo medicamento)
    3. Fetch ATCs con concurrencia controlada (rate limiting friendly)
    4. Bulk update en BD

    Performance:
    - ~10-20k medicamentos únicos (nregistro)
    - 5 requests concurrentes = 2 requests/segundo (rate limit friendly)
    - Estimado: 30-40 minutos para completar
    """

    def __init__(self):
        """Inicializa el servicio de backfill ATC"""
        self.api_base_url = "https://cima.aemps.es/cima/rest/medicamento"
        self.timeout = 15.0  # Timeout por request (agresivo)
        self.max_retries = 2  # Solo 2 reintentos por medicamento

        # Métricas de ejecución (para monitoring)
        self.metrics = {
            "total_processed": 0,
            "successful": 0,
            "failed": 0,
            "skipped": 0,
            "start_time": None,
            "end_time": None
        }

    async def backfill_atc_codes(
        self,
        batch_size: int = 100,
        concurrent_requests: int = 5,
        incremental: bool = False,
        max_products: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Extrae códigos ATC para productos sin ellos, en batches iterativos.

        Args:
            batch_size: Número de productos a procesar por batch
            concurrent_requests: Número de requests HTTP concurrentes (rate limiting)
            incremental: Si True, solo procesa productos agregados recientemente
            max_products: Límite máximo de productos a procesar (para testing)

        Returns:
            Dict con estadísticas de la ejecución
        """
        self.metrics["start_time"] = utc_now()
        logger.info(
            "[ATC Backfill] Iniciando backfill de códigos ATC",
            batch_size=batch_size,
            concurrent_requests=concurrent_requests,
            incremental=incremental,
            max_products=max_products
        )

        db = SessionLocal()
        try:
            total_products_processed = 0
            batch_number = 0

            # Loop sobre múltiples batches hasta procesar todos los productos
            while True:
                batch_number += 1

                # 1. Obtener productos sin ATC
                products_without_atc = self._get_products_without_atc(
                    db,
                    batch_size=batch_size,
                    incremental=incremental
                )

                if not products_without_atc:
                    logger.info("[ATC Backfill] No hay más productos sin ATC para procesar")
                    break

                # Verificar límite max_products
                if max_products and total_products_processed >= max_products:
                    logger.info(
                        f"[ATC Backfill] Límite max_products alcanzado: {max_products}"
                    )
                    break

                logger.info(
                    f"[ATC Backfill] Batch #{batch_number}: Encontrados {len(products_without_atc)} productos sin ATC"
                )

                # 2. Deduplicar por nregistro (múltiples CNs → 1 nregistro)
                nregistro_to_products = self._group_by_nregistro(products_without_atc)
                unique_nregistros = list(nregistro_to_products.keys())

                logger.info(
                    f"[ATC Backfill] Batch #{batch_number}: {len(unique_nregistros)} medicamentos únicos (nregistro) a procesar"
                )

                # 3. Fetch ATCs con concurrencia controlada
                async with httpx.AsyncClient(timeout=self.timeout) as client:
                    # Procesar en batches de concurrent_requests
                    for i in range(0, len(unique_nregistros), concurrent_requests):
                        batch = unique_nregistros[i:i + concurrent_requests]

                        # Crear tasks para el batch
                        tasks = [
                            self._fetch_atc_for_nregistro(client, nregistro)
                            for nregistro in batch
                        ]

                        # Ejecutar batch concurrentemente
                        results = await asyncio.gather(*tasks, return_exceptions=True)

                        # 4. Procesar resultados y actualizar BD
                        self._process_batch_results(
                            db, results, batch, nregistro_to_products
                        )

                        # Commit intermedio cada batch (evita locks largos)
                        db.commit()

                        # Log de progreso dentro del batch
                        processed_in_batch = min(i + concurrent_requests, len(unique_nregistros))
                        logger.info(
                            f"[ATC Backfill] Batch #{batch_number} progreso: {processed_in_batch}/{len(unique_nregistros)} medicamentos",
                            percentage=round(processed_in_batch / len(unique_nregistros) * 100, 1)
                        )

                # Actualizar contador global
                total_products_processed += len(products_without_atc)

                logger.info(
                    f"[ATC Backfill] Batch #{batch_number} completado. Total procesado: {total_products_processed} productos"
                )

            self.metrics["end_time"] = utc_now()
            result = self._build_result(db)

            logger.info(
                "[ATC Backfill] Completado exitosamente",
                **result
            )

            return result

        except Exception as e:
            logger.error(
                "[ATC Backfill] Error durante backfill",
                error=str(e),
                error_type=type(e).__name__
            )
            db.rollback()
            raise
        finally:
            db.close()

    def _get_products_without_atc(
        self,
        db: Session,
        batch_size: int = 100,
        incremental: bool = False
    ) -> List[ProductCatalog]:
        """
        Query productos CIMA sin código ATC.

        Args:
            db: Sesión de BD
            batch_size: Límite de productos a retornar por batch
            incremental: Si True, solo productos recientes (última semana)

        Returns:
            Lista de productos sin ATC
        """
        query = db.query(ProductCatalog).filter(
            and_(
                ProductCatalog.cima_nombre_comercial.isnot(None),  # Es producto CIMA
                or_(
                    ProductCatalog.cima_atc_code.is_(None),  # Sin ATC principal
                    ProductCatalog.cima_atc_code == ""  # O ATC vacío
                )
            )
        )

        # Modo incremental: solo productos agregados recientemente
        if incremental:
            from datetime import timedelta
            cutoff_date = utc_now() - timedelta(days=7)
            query = query.filter(ProductCatalog.created_at >= cutoff_date)

        # Ordenar por created_at DESC (más recientes primero)
        query = query.order_by(ProductCatalog.created_at.desc())

        # Aplicar límite de batch
        query = query.limit(batch_size)

        return query.all()

    def _group_by_nregistro(
        self, products: List[ProductCatalog]
    ) -> Dict[str, List[ProductCatalog]]:
        """
        Agrupa productos por nregistro (número de registro del medicamento).

        Múltiples códigos nacionales (CN) pueden corresponder al mismo medicamento
        (diferentes presentaciones). Necesitamos deduplicar para evitar llamadas
        redundantes a la API de CIMA.

        Args:
            products: Lista de productos a agrupar

        Returns:
            Dict {nregistro: [productos]} donde cada nregistro mapea a sus CNs
        """
        nregistro_map = {}

        for product in products:
            # Usar cima_nregistro que se pobla durante la sincronización CIMA
            # Este es el número de registro correcto del medicamento
            nregistro = product.cima_nregistro

            if not nregistro:
                # Si no tiene nregistro, skip (requiere re-sync CIMA primero)
                self.metrics["skipped"] += 1
                continue

            if nregistro not in nregistro_map:
                nregistro_map[nregistro] = []

            nregistro_map[nregistro].append(product)

        return nregistro_map

    async def _fetch_atc_for_nregistro(
        self, client: httpx.AsyncClient, nregistro: str
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch códigos ATC para un medicamento específico (por nregistro).

        Args:
            client: Cliente HTTP asíncrono
            nregistro: Número de registro del medicamento

        Returns:
            Dict con códigos ATC o None si falla
        """
        for attempt in range(self.max_retries + 1):
            try:
                response = await client.get(
                    self.api_base_url,
                    params={"nregistro": nregistro}
                )

                if response.status_code == 200:
                    data = response.json()
                    atcs = data.get("atcs", [])

                    if atcs:
                        # Extraer y limpiar códigos ATC
                        clean_atcs = []
                        atc_principal = None

                        for atc in atcs:
                            if isinstance(atc, dict) and atc.get("codigo"):
                                clean_atc = {
                                    "codigo": atc.get("codigo", ""),
                                    "nombre": atc.get("nombre", ""),
                                    "nivel": atc.get("nivel", 0)
                                }
                                clean_atcs.append(clean_atc)

                                # ATC principal: nivel 5 (más específico)
                                if atc.get("nivel") == 5 and not atc_principal:
                                    atc_principal = clean_atc["codigo"]

                        # Si no hay nivel 5, usar el último (más específico disponible)
                        if not atc_principal and clean_atcs:
                            atc_principal = clean_atcs[-1]["codigo"]

                        return {
                            "nregistro": nregistro,
                            "atc_codes": clean_atcs,
                            "atc_principal": atc_principal,
                            "success": True
                        }
                    else:
                        # Medicamento sin ATCs en CIMA
                        logger.debug(
                            f"[ATC Backfill] Medicamento {nregistro} sin ATCs en CIMA"
                        )
                        return {
                            "nregistro": nregistro,
                            "atc_codes": [],
                            "atc_principal": None,
                            "success": True  # No es error, simplemente no tiene ATC
                        }

                elif response.status_code == 404:
                    # Medicamento no encontrado en CIMA
                    logger.warning(
                        f"[ATC Backfill] Medicamento {nregistro} no encontrado (404)"
                    )
                    return None

                elif response.status_code == 429:
                    # Rate limit excedido - esperar y reintentar
                    retry_after = int(response.headers.get("Retry-After", 30))
                    logger.warning(
                        f"[ATC Backfill] Rate limit excedido, esperando {retry_after}s"
                    )
                    await asyncio.sleep(retry_after)
                    continue

                else:
                    logger.warning(
                        f"[ATC Backfill] Error HTTP {response.status_code} para nregistro {nregistro}"
                    )

            except httpx.TimeoutException:
                logger.warning(
                    f"[ATC Backfill] Timeout para nregistro {nregistro} (intento {attempt + 1}/{self.max_retries + 1})"
                )
                if attempt < self.max_retries:
                    await asyncio.sleep(2 ** attempt)  # Exponential backoff
                    continue

            except Exception as e:
                logger.error(
                    f"[ATC Backfill] Error inesperado para nregistro {nregistro}",
                    error=str(e),
                    error_type=type(e).__name__
                )

        # Falló después de reintentos
        return None

    def _process_batch_results(
        self,
        db: Session,
        results: List[Optional[Dict[str, Any]]],
        nregistros: List[str],
        nregistro_to_products: Dict[str, List[ProductCatalog]]
    ):
        """
        Procesa resultados de un batch y actualiza BD.

        Args:
            db: Sesión de BD
            results: Lista de resultados de fetch ATCs
            nregistros: Lista de nregistros del batch
            nregistro_to_products: Mapeo de nregistro a productos
        """
        for nregistro, result in zip(nregistros, results):
            products = nregistro_to_products.get(nregistro, [])

            if isinstance(result, Exception):
                # Error durante fetch
                logger.error(
                    f"[ATC Backfill] Error procesando nregistro {nregistro}",
                    error=str(result)
                )
                self.metrics["failed"] += len(products)
                continue

            if result is None:
                # Fetch falló después de reintentos
                self.metrics["failed"] += len(products)
                continue

            if not result.get("success"):
                # Resultado inválido
                self.metrics["failed"] += len(products)
                continue

            # Actualizar todos los productos con el mismo nregistro
            atc_codes = result.get("atc_codes", [])
            atc_principal = result.get("atc_principal")

            for product in products:
                try:
                    product.cima_atc_codes = atc_codes  # JSONB
                    product.cima_atc_code = atc_principal  # String
                    product.updated_at = utc_now()

                    self.metrics["successful"] += 1
                    self.metrics["total_processed"] += 1

                except Exception as e:
                    logger.error(
                        f"[ATC Backfill] Error actualizando producto {product.id}",
                        error=str(e)
                    )
                    self.metrics["failed"] += 1

    def _build_result(self, db: Session) -> Dict[str, Any]:
        """
        Construye resultado final con estadísticas.

        Args:
            db: Sesión de BD

        Returns:
            Dict con métricas de ejecución y cobertura ATC
        """
        # Calcular cobertura ATC actual
        total_cima = db.query(ProductCatalog).filter(
            ProductCatalog.cima_nombre_comercial.isnot(None)
        ).count()

        with_atc = db.query(ProductCatalog).filter(
            ProductCatalog.cima_atc_code.isnot(None),
            ProductCatalog.cima_atc_code != ""
        ).count()

        coverage = (with_atc / total_cima * 100) if total_cima > 0 else 0

        # Calcular duración
        duration_seconds = None
        if self.metrics["start_time"] and self.metrics["end_time"]:
            delta = self.metrics["end_time"] - self.metrics["start_time"]
            duration_seconds = delta.total_seconds()

        return {
            "status": "completed",
            "processed": self.metrics["total_processed"],
            "successful": self.metrics["successful"],
            "failed": self.metrics["failed"],
            "skipped": self.metrics["skipped"],
            "duration_seconds": duration_seconds,
            "coverage": {
                "total_cima_products": total_cima,
                "products_with_atc": with_atc,
                "coverage_percentage": round(coverage, 2),
                "target_percentage": 60.0,
                "target_reached": coverage >= 60.0
            }
        }


# Singleton instance
atc_backfill_service = ATCBackfillService()
