# backend/app/services/llm_enrichment_service.py
"""
LLM Enrichment Service - Issue #456

Servicio para enriquecer productos de venta libre usando Groq (Llama 3.3 70B por defecto).

Características:
- Groq API: ~800 tokens/s, gratis en beta
- Few-Shot prompting para alta precisión
- Tenacity para reintentos automáticos ante respuestas JSON inválidas
- Fallback a Ollama local si Groq no está disponible

Uso:
    service = LLMEnrichmentService()
    result = service.enrich_product("ISDIN FOTOPROTECTOR FUSION WATER SPF50")
    print(result.necesidad)  # "proteccion_solar"

Requisitos:
    - GROQ_API_KEY en variables de entorno
    - Obtener key gratis en: https://console.groq.com/
"""

import json
import logging
import os
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from pydantic import BaseModel, Field, ValidationError
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from ..schemas.llm_enrichment import (
    BatchEnrichmentProgress,
    BatchEnrichmentResult,
    CheckpointData,
    LLMEnrichmentResult,
    NecesidadCategory,
    ProductEnrichmentInput,
    ProductEnrichmentSchema,
)

logger = logging.getLogger(__name__)


# === PYDANTIC MODEL PARA LLM OUTPUT ===

class ProductEnrichment(BaseModel):
    """Shape of the JSON object the LLM must return for a single product.

    Only ``necesidad`` is required. Every descriptive field is optional, and
    ``confianza`` is constrained to [0.0, 1.0], falling back to 0.85 when the
    model leaves it out.
    """

    necesidad: str = Field(description="Categoria principal del producto")
    subcategoria: Optional[str] = Field(
        default=None,
        description="Subcategoria opcional",
    )
    indicaciones: Optional[str] = Field(
        default=None,
        description="Para que sirve el producto",
    )
    composicion_principal: Optional[str] = Field(
        default=None,
        description="Ingredientes activos principales",
    )
    modo_empleo: Optional[str] = Field(
        default=None,
        description="Como usar el producto",
    )
    confianza: float = Field(
        default=0.85,
        ge=0.0,
        le=1.0,
        description="Confianza en la clasificacion",
    )


# === FEW-SHOT EXAMPLES ===
# Diverse examples so the model sees several different categories.
# Each entry pairs an "input" product name with the "output" object the
# assistant is expected to return. LLMEnrichmentService._build_few_shot_messages
# sends "Producto: <input>" as a user turn and json.dumps(output) as the
# assistant turn, so the key order and exact string values below are exactly
# what the model sees.
# NOTE(review): the outputs omit Spanish accents/ñ (e.g. "manana",
# "Proteccion") while the system prompt asks for Spanish text; the model may
# imitate this — confirm it is intended.

FEW_SHOT_EXAMPLES = [
    # Sun protection
    {
        "input": "ISDIN FOTOPROTECTOR FUSION WATER SPF50+ 50ML",
        "output": {
            "necesidad": "proteccion_solar",
            "subcategoria": "facial",
            "indicaciones": "Proteccion solar muy alta para rostro",
            "composicion_principal": "Filtros UVA/UVB",
            "modo_empleo": "Aplicar antes de exposicion solar, reaplicar cada 2h",
            "confianza": 0.95
        }
    },
    # Dental care
    {
        "input": "LACER COLUTORIO FLUOR 500ML",
        "output": {
            "necesidad": "higiene_dental",
            "subcategoria": "colutorio",
            "indicaciones": "Enjuague bucal con fluor para prevencion de caries",
            "composicion_principal": "Fluoruro sodico",
            "modo_empleo": "Enjuagar 1 minuto despues del cepillado",
            "confianza": 0.95
        }
    },
    # Flu / common cold
    {
        "input": "FRENADOL COMPLEX GRANULADO 10 SOBRES",
        "output": {
            "necesidad": "gripe_resfriado",
            "subcategoria": "antigripal",
            "indicaciones": "Alivio de sintomas de gripe y resfriado",
            "composicion_principal": "Paracetamol, Dextrometorfano, Clorfenamina",
            "modo_empleo": "1 sobre cada 6-8 horas disuelto en agua",
            "confianza": 0.95
        }
    },
    # Facial anti-ageing
    {
        "input": "VICHY LIFTACTIV SUPREME CREMA ANTIARRUGAS 50ML",
        "output": {
            "necesidad": "arrugas_antiedad",
            "subcategoria": "crema_facial",
            "indicaciones": "Tratamiento antiarrugas y firmeza",
            "composicion_principal": "Acido hialuronico, Rhamnosa",
            "modo_empleo": "Aplicar manana y noche sobre rostro limpio",
            "confianza": 0.95
        }
    },
    # Pain relief
    {
        "input": "NUROFEN 400MG 12 COMPRIMIDOS",
        "output": {
            "necesidad": "dolor",
            "subcategoria": "analgesico",
            "indicaciones": "Alivio del dolor leve a moderado",
            "composicion_principal": "Ibuprofeno 400mg",
            "modo_empleo": "1 comprimido cada 6-8 horas con comida",
            "confianza": 0.95
        }
    },
]


class LLMEnrichmentService:
    """
    Product-enrichment service backed by an LLM.

    Uses Groq (``GROQ_MODELS[0]``, currently ``llama-3.3-70b-versatile``) when
    an API key is available and ``use_groq`` is True; otherwise it falls back
    to a local Ollama model (``OLLAMA_MODEL``). Backend clients are created
    lazily on first use.
    """

    # Models available on Groq (updated Dec 2024). Only the first entry is
    # used by this service (see ``groq_client``); the rest are alternatives.
    GROQ_MODELS = [
        "llama-3.3-70b-versatile",   # Best quality (new)
        "llama-3.1-8b-instant",      # Fast
        "mixtral-8x7b-32768",        # Alternative
    ]

    OLLAMA_MODEL = "llama3.2:3b"  # Local fallback

    # Patterns used by _extract_json, compiled once for the whole class.
    _JSON_FENCE_RE = re.compile(r'```json\s*(.*?)\s*```', re.DOTALL)
    _JSON_OBJECT_RE = re.compile(r'\{.*\}', re.DOTALL)

    def __init__(
        self,
        groq_api_key: Optional[str] = None,
        use_groq: bool = True,
    ):
        """
        Initialise the service.

        Args:
            groq_api_key: Groq API key; falls back to the GROQ_API_KEY env var.
            use_groq: Prefer Groq (True) or force local Ollama (False). Groq is
                only used when an API key is also available.
        """
        self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
        self.use_groq = use_groq and bool(self.groq_api_key)

        # Clients (and the active model name) are created lazily by the
        # properties below.
        self._groq_client = None
        self._ollama_client = None
        self._model: Optional[str] = None

        # Prompts do not change per call, so build them once.
        self._system_prompt = self._build_system_prompt()
        self._few_shot_messages = self._build_few_shot_messages()

        if self.use_groq:
            logger.info("LLMEnrichmentService initialized with Groq API")
        elif use_groq:
            # The caller asked for Groq but no key was supplied or found.
            logger.warning("GROQ_API_KEY not set, falling back to Ollama local")
        else:
            logger.info("Groq disabled (use_groq=False), using Ollama local")

    @property
    def groq_client(self):
        """Lazily create the Groq client; return None when Groq is not usable.

        If the ``groq`` package is missing, ``use_groq`` is switched off so
        subsequent calls go to Ollama instead.
        """
        if self._groq_client is None and self.use_groq:
            try:
                from groq import Groq
            except ImportError:
                logger.error("groq package not installed. Run: pip install groq")
                self.use_groq = False
            else:
                self._groq_client = Groq(api_key=self.groq_api_key)
                self._model = self.GROQ_MODELS[0]
                logger.info("Groq client initialized: model=%s", self._model)
        return self._groq_client

    @property
    def ollama_client(self):
        """Lazily create the local Ollama client; return None if unavailable."""
        if self._ollama_client is None:
            try:
                from ollama import Client
            except ImportError:
                logger.error("ollama package not installed")
            else:
                self._ollama_client = Client(host="http://localhost:11434")
                self._model = self.OLLAMA_MODEL
                logger.info("Ollama client initialized: model=%s", self._model)
        return self._ollama_client

    @property
    def model(self) -> str:
        """Name of the active model, initialising the backend if needed.

        Returns "unknown" when no backend client could be created.
        """
        if self._model is None:
            if self.use_groq:
                _ = self.groq_client  # Initialize
            else:
                _ = self.ollama_client  # Initialize
        return self._model or "unknown"

    def _require_ollama_client(self):
        """Return the Ollama client, raising RuntimeError if it cannot be created.

        RuntimeError is deliberately not in enrich_product's retry list: a
        missing package will not fix itself between attempts.
        """
        client = self.ollama_client
        if client is None:
            raise RuntimeError(
                "No LLM backend available: Groq is disabled and the "
                "ollama package is not installed"
            )
        return client

    def _build_system_prompt(self) -> str:
        """Build the system prompt listing the categories the model may use."""
        # The 50 most frequently used categories.
        # NOTE(review): this list is independent of NecesidadCategory, which
        # _validate_necesidad checks against; keep both in sync.
        top_categories = [
            "dolor", "dolor_muscular", "proteccion_solar", "hidratacion_corporal",
            "hidratacion_facial", "arrugas_antiedad", "contorno_ojos", "limpieza_facial",
            "acne", "manchas", "caida_cabello", "caspa", "higiene_dental",
            "blanqueamiento_dental", "sensibilidad_dental", "encias", "gripe_resfriado",
            "mucosidad_respiratoria", "congestion_nasal", "dolor_garganta", "vitaminas_general",
            "probioticos", "flora_intestinal", "estrenimiento", "diarrea",
            "acidez_reflujo", "gases_digestion", "sueno_insomnio", "estres_ansiedad",
            "energia_cansancio", "memoria_concentracion", "defensas", "colesterol",
            "circulacion", "articulaciones", "heridas_apositos", "quemaduras",
            "picaduras", "hemorroides", "cistitis", "higiene_intima",
            "menopausia", "embarazo_prenatal", "alimentacion_bebe", "panales",
            "ojo_seco", "alergia", "veterinaria", "material_sanitario", "otros"
        ]

        return f"""Eres un clasificador de productos de farmacia española.

TAREA: Extraer informacion estructurada en JSON.

CATEGORIAS VALIDAS para "necesidad" (DEBES usar una de estas):
{', '.join(top_categories)}

REGLAS:
1. Responde SOLO con JSON valido
2. "necesidad" DEBE ser una de las categorias listadas
3. "confianza" entre 0.0 y 1.0
4. Todos los textos en español"""

    def _build_few_shot_messages(self) -> List[Dict[str, str]]:
        """Turn FEW_SHOT_EXAMPLES into alternating user/assistant chat messages.

        Each example's ``output`` dict is serialised with ``ensure_ascii=False``
        so non-ASCII characters reach the model unescaped.
        """
        messages: List[Dict[str, str]] = []
        for example in FEW_SHOT_EXAMPLES:
            messages.append({"role": "user", "content": f"Producto: {example['input']}"})
            messages.append({
                "role": "assistant",
                "content": json.dumps(example["output"], ensure_ascii=False),
            })
        return messages

    def _extract_json(self, text: str) -> Dict[str, Any]:
        """Extract a JSON object from an LLM reply.

        Tries, in order: the whole text, the contents of a ```json fenced
        block, and the outermost ``{...}`` span. Only JSON *objects* are
        accepted; a list or scalar falls through to the next strategy.

        Raises:
            json.JSONDecodeError: if no strategy yields a JSON object.
        """
        candidates = [text]
        fenced = self._JSON_FENCE_RE.search(text)
        if fenced:
            candidates.append(fenced.group(1))
        braces = self._JSON_OBJECT_RE.search(text)
        if braces:
            candidates.append(braces.group(0))

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed

        raise json.JSONDecodeError("No valid JSON found", text, 0)

    def _validate_necesidad(self, necesidad: str) -> str:
        """Normalise ``necesidad`` and map it onto a NecesidadCategory value.

        Exact matches win. Otherwise the closest substring match is chosen:
        smallest length difference first, then alphabetical order, so the
        result does not depend on set iteration order. Empty or unmatched
        input maps to "otros".
        """
        normalized = necesidad.lower().strip().replace(" ", "_").replace("-", "_")

        valid_categories = {cat.value for cat in NecesidadCategory}

        if normalized in valid_categories:
            return normalized

        # Simple fuzzy match. Skipped for "", which is a substring of every
        # category and would otherwise match arbitrarily.
        if normalized:
            candidates = [
                cat for cat in valid_categories
                if normalized in cat or cat in normalized
            ]
            if candidates:
                return min(
                    candidates,
                    key=lambda cat: (abs(len(cat) - len(normalized)), cat),
                )

        # No match: fall back to "otros".
        logger.warning("Categoría '%s' no válida, usando 'otros'", normalized)
        return "otros"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type((ValidationError, ValueError, json.JSONDecodeError)),
        reraise=True
    )
    def enrich_product(
        self,
        product_name: str,
        existing_brand: Optional[str] = None,
        existing_necesidad: Optional[str] = None,
    ) -> ProductEnrichment:
        """
        Enrich one product using Groq (or the Ollama fallback).

        Malformed model output (empty reply, no JSON object, invalid or
        missing fields) raises ValueError/JSONDecodeError/ValidationError.
        @retry reattempts these up to 3 times in total before re-raising.

        Args:
            product_name: Product name.
            existing_brand: Detected brand, added to the prompt when given.
            existing_necesidad: Suggested category, added to the prompt as a
                hint when given. The model must still pick a valid category.

        Returns:
            ProductEnrichment with the extracted data.

        Raises:
            RuntimeError: no backend is available (not retried).
        """
        user_input = f"Producto: {product_name}"
        if existing_brand:
            user_input += f"\nMarca: {existing_brand}"
        if existing_necesidad:
            user_input += f"\nCategoria sugerida: {existing_necesidad}"

        # System prompt, then the few-shot turns, then the real query.
        messages = [
            {"role": "system", "content": self._system_prompt},
            *self._few_shot_messages,
            {"role": "user", "content": user_input},
        ]

        if self.use_groq and self.groq_client:
            response = self.groq_client.chat.completions.create(
                model=self._model,
                messages=messages,
                temperature=0.1,
                max_tokens=300,
                response_format={"type": "json_object"},
            )
            content = response.choices[0].message.content
        else:
            # Ollama fallback. Resolve the client first: that also sets _model.
            client = self._require_ollama_client()
            response = client.chat(
                model=self._model,
                messages=messages,
                format="json",
                options={"temperature": 0.1, "num_predict": 300},
            )
            content = response.get("message", {}).get("content", "")

        if not content:
            raise ValueError("Empty response from LLM")

        data = self._extract_json(content)

        if "necesidad" in data:
            necesidad = data["necesidad"]
            if not isinstance(necesidad, str):
                # A null or numeric category cannot be normalised. Treat it
                # like any other malformed output so @retry asks again.
                raise ValueError(f"Invalid 'necesidad' from LLM: {necesidad!r}")
            data["necesidad"] = self._validate_necesidad(necesidad)

        enrichment = ProductEnrichment(**data)

        logger.info(
            "Enriched: %s... -> %s (conf=%.2f)",
            product_name[:40], enrichment.necesidad, enrichment.confianza,
        )

        return enrichment

    def enrich_product_safe(
        self,
        product_name: str,
        existing_brand: Optional[str] = None,
        existing_necesidad: Optional[str] = None,
    ) -> LLMEnrichmentResult:
        """Variant of enrich_product that never raises.

        Any exception is logged and reported as ``success=False`` with its
        message. ``processing_time_ms`` covers all retry attempts.
        """
        started = time.perf_counter()

        try:
            enrichment = self.enrich_product(
                product_name=product_name,
                existing_brand=existing_brand,
                existing_necesidad=existing_necesidad,
            )
        except Exception as e:
            elapsed_ms = int((time.perf_counter() - started) * 1000)
            logger.error("Error enriching %s...: %s", product_name[:40], e)
            return LLMEnrichmentResult(
                success=False,
                product_name=product_name,
                error_message=str(e),
                processing_time_ms=elapsed_ms,
                model_version=self.model,
            )

        elapsed_ms = int((time.perf_counter() - started) * 1000)

        schema_enrichment = ProductEnrichmentSchema(
            necesidad=enrichment.necesidad,
            subcategoria=enrichment.subcategoria,
            indicaciones=enrichment.indicaciones,
            composicion_principal=enrichment.composicion_principal,
            modo_empleo=enrichment.modo_empleo,
            confianza=enrichment.confianza,
        )

        return LLMEnrichmentResult(
            success=True,
            product_name=product_name,
            enrichment=schema_enrichment,
            processing_time_ms=elapsed_ms,
            model_version=self.model,
        )

    def enrich_batch(
        self,
        products: List[ProductEnrichmentInput],
        checkpoint_every: int = 100,
        checkpoint_dir: Optional[str] = None,
        dry_run: bool = False,
        progress_callback: Optional[Callable[[BatchEnrichmentProgress], None]] = None,
    ) -> BatchEnrichmentResult:
        """
        Enrich products sequentially, optionally writing checkpoints.

        Args:
            products: Products to enrich.
            checkpoint_every: Write a checkpoint after every N products.
                Values <= 0 disable periodic checkpoints; a final checkpoint
                is still written when ``checkpoint_dir`` is set.
            checkpoint_dir: Directory for checkpoint files, created if
                missing. No checkpoints are written when it is empty or None.
            dry_run: NOTE(review): currently not used by this method; confirm
                the intended semantics.
            progress_callback: Called with a BatchEnrichmentProgress before
                each product. Its errors are logged and do not affect the
                product's processing or the counters.

        Returns:
            BatchEnrichmentResult. ``avg_confidence`` is averaged over
            successful products only, and at most 100 errors are kept.
            Checkpoints record the ids of successfully enriched products
            that have a ``product_id``.
        """
        start_time = datetime.now(timezone.utc)
        checkpoint_path = Path(checkpoint_dir) if checkpoint_dir else None

        if checkpoint_path:
            checkpoint_path.mkdir(parents=True, exist_ok=True)

        enriched_count = 0
        failed_count = 0
        skipped_count = 0  # NOTE(review): never incremented here; kept for the result schema.
        total_confidence = 0.0
        errors: List[str] = []
        processed_ids: List[str] = []

        total = len(products)
        logger.info("Starting batch: %d products, model=%s", total, self.model)

        for i, product in enumerate(products):
            if progress_callback:
                try:
                    progress_callback(BatchEnrichmentProgress(
                        total_products=total,
                        processed=i,
                        enriched=enriched_count,
                        failed=failed_count,
                        skipped=skipped_count,
                        current_product=product.product_name[:50],
                    ))
                except Exception:
                    # A broken progress reporter must not drop the product.
                    logger.exception("progress_callback failed at item %d", i)

            try:
                result = self.enrich_product_safe(
                    product_name=product.product_name,
                    existing_brand=product.existing_brand,
                    existing_necesidad=product.existing_necesidad,
                )
            except Exception as e:
                # enrich_product_safe should not raise; stay defensive anyway.
                failed_count += 1
                errors.append(str(e))
            else:
                if result.success and result.enrichment:
                    enriched_count += 1
                    total_confidence += result.enrichment.confianza
                    if product.product_id:
                        processed_ids.append(product.product_id)
                else:
                    failed_count += 1
                    if result.error_message:
                        errors.append(f"{product.product_name[:30]}: {result.error_message}")

            # Periodic checkpoint (best-effort). A failure here is logged and
            # recorded, but must not change the per-product counters.
            if checkpoint_path and checkpoint_every > 0 and (i + 1) % checkpoint_every == 0:
                try:
                    self._save_checkpoint(
                        checkpoint_path, processed_ids,
                        enriched_count, failed_count, product.product_id
                    )
                except Exception as e:
                    logger.exception("Failed to write checkpoint at item %d", i + 1)
                    errors.append(f"checkpoint at item {i + 1}: {e}")

        end_time = datetime.now(timezone.utc)

        if checkpoint_path:
            self._save_checkpoint(
                checkpoint_path, processed_ids,
                enriched_count, failed_count, None, is_final=True
            )

        return BatchEnrichmentResult(
            total_processed=total,
            enriched_count=enriched_count,
            failed_count=failed_count,
            skipped_count=skipped_count,
            avg_confidence=total_confidence / enriched_count if enriched_count else 0.0,
            processing_time_seconds=(end_time - start_time).total_seconds(),
            checkpoint_path=str(checkpoint_path) if checkpoint_path else None,
            errors=errors[:100],
            started_at=start_time,
            completed_at=end_time,
            model_version=self.model,
        )

    def _save_checkpoint(
        self,
        path: Path,
        processed_ids: List[str],
        enriched: int,
        failed: int,
        last_id: Optional[str],
        is_final: bool = False,
    ) -> Path:
        """
        Write a checkpoint JSON file into ``path`` and return the file path.

        Files are named ``checkpoint_[final_]<YYYYmmdd_HHMMSS>.json`` using
        the current UTC time. Two checkpoints of the same kind written within
        the same second therefore overwrite each other.
        """
        timestamp = datetime.now(timezone.utc)
        checkpoint_id = timestamp.strftime("%Y%m%d_%H%M%S")
        filename = f"checkpoint_{'final_' if is_final else ''}{checkpoint_id}.json"

        data = CheckpointData(
            checkpoint_id=checkpoint_id,
            processed_ids=processed_ids,
            enriched_count=enriched,
            failed_count=failed,
            last_product_id=last_id,
            timestamp=timestamp,
            model_version=self.model,
        )

        filepath = path / filename
        filepath.write_text(data.model_dump_json(indent=2), encoding="utf-8")
        logger.info("Checkpoint: %s", filepath)
        return filepath

    def health_check(self) -> Dict[str, Any]:
        """Send a tiny prompt to the active backend and report the outcome.

        Returns a dict with "status" ("healthy"/"unhealthy") and "backend".
        A healthy result also has "model" and the first 20 characters of the
        reply. An unhealthy one has "error". Never raises.
        """
        probe = [{"role": "user", "content": "Responde: OK"}]
        try:
            if self.use_groq and self.groq_client:
                response = self.groq_client.chat.completions.create(
                    model=self._model,
                    messages=probe,
                    max_tokens=10,
                )
                # Content may be None; that still means the API answered.
                reply = response.choices[0].message.content or ""
                return {
                    "status": "healthy",
                    "backend": "groq",
                    "model": self._model,
                    "response": reply[:20],
                }

            client = self._require_ollama_client()
            response = client.chat(
                model=self._model,
                messages=probe,
                options={"num_predict": 10},
            )
            reply = response.get("message", {}).get("content", "") or ""
            return {
                "status": "healthy",
                "backend": "ollama",
                "model": self._model,
                "response": reply[:20],
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "error": str(e),
                "backend": "groq" if self.use_groq else "ollama",
            }


# Module-level cache for the shared service instance (see get_llm_enrichment_service).
_service: Optional[LLMEnrichmentService] = None


def get_llm_enrichment_service() -> LLMEnrichmentService:
    """Return the shared LLMEnrichmentService, building it on first call.

    Later calls return the same cached instance.
    """
    global _service
    if _service is not None:
        return _service
    _service = LLMEnrichmentService()
    return _service
