# backend/app/services/taxonomy_labeler_service.py
"""
TaxonomyLabelerService - Issue #462

Servicio para etiquetar clusters con taxonomía jerárquica:
- Tier 1 (Macro): Mapeo estático determinístico
- Tier 2 (Sub): Voto ponderado por ventas
- Tier 3 (Nombre): Generado por LLM

Modelo Híbrido:
- Estructura (Tier 1/2): Determinística, auditable
- Identidad (Tier 3): Creativa, LLM como "copywriter"

Uso:
    service = TaxonomyLabelerService(db)
    result = service.batch_label_all(pharmacy_id)
"""

import hashlib
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID

from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from ..constants.taxonomy import TIER1_MAPPING, get_tier1_for_category
from ..core.taxonomy_tier1 import ALL_CATEGORIES
from ..models.product_cluster import ClusterState, ProductCluster
from ..models.sales_enrichment import SalesEnrichment

logger = logging.getLogger(__name__)


# === CONFIGURATION ===

# LLM Configuration (API key loaded lazily for flexibility)
LLM_MODEL = "llama-3.3-70b-versatile"
LLM_FALLBACK_MODEL = "llama-3.1-8b-instant"


def _get_groq_api_key() -> str:
    """
    Get GROQ API key from environment (lazy load).

    Lazy loading allows key changes without service restart.
    """
    key = os.getenv("GROQ_API_KEY")
    if not key:
        raise ValueError(
            "GROQ_API_KEY no configurada. "
            "Se requiere para generar nombres de clusters. "
            "Obtener key gratis en: https://console.groq.com/"
        )
    return key


# === PYDANTIC MODELS ===

class TaxonomyLabel(BaseModel):
    """Resultado del etiquetado de un cluster."""
    tier1: str = Field(..., description="Macro-categoría (higiene_bucal, dermocosmetica, etc.)")
    tier2: str = Field(..., description="Categoría específica (sensibilidad_dental, acne, etc.)")
    display_name: str = Field(..., description="Nombre generado para el cluster")
    confidence: float = Field(0.85, ge=0.0, le=1.0)


class ClusterLabelingResult(BaseModel):
    """Resultado del etiquetado de un cluster individual."""
    cluster_id: str
    tier1: str
    tier2: str
    display_name: str
    confidence: float
    status: str  # "updated", "skipped", "error"
    reason: Optional[str] = None


class BatchLabelingResult(BaseModel):
    """Resultado del etiquetado batch de clusters."""
    success: bool
    message: str
    clusters_processed: int
    clusters_skipped: int
    clusters_updated: int
    clusters_errored: int
    llm_calls_made: int
    estimated_cost_usd: float
    execution_time_seconds: float
    details: List[ClusterLabelingResult]


# === NAMING PROMPT ===

TAXONOMY_NAMING_PROMPT = """Actúa como experto en taxonomía farmacéutica española.

Instrucción:
Analiza este grupo de productos de la categoría "{tier2}" y genera un nombre comercial para el cluster.

Reglas:
- Idioma: ESPAÑOL
- Longitud: Máximo 4 palabras
- Tono: Profesional y descriptivo
- Evita: Palabras genéricas como "Productos", "Varios", "Otros"

Ejemplos:
✅ Bueno: "Cuidado Facial Antiedad" o "Higiene Bucal Infantil"
❌ Malo: "Productos de la categoría facial" o "Varios de higiene"

Productos (nombre | precio medio):
{product_list}

Responde SOLO con el nombre generado (sin explicaciones, sin comillas, sin prefijos).
"""


class TaxonomyLabelerService:
    """
    Servicio de etiquetado jerárquico de clusters.

    Modelo híbrido:
    - Tier 1/2: Determinístico (voto ponderado + mapeo estático)
    - Tier 3: LLM naming
    """

    def __init__(self, db: Session):
        """
        Inicializa el servicio.

        Args:
            db: Sesión de SQLAlchemy

        Raises:
            RuntimeError: Si TIER1_MAPPING no cubre todas las categorías
        """
        self.db = db
        self._validate_tier1_coverage()
        self._groq_client = None

    def _validate_tier1_coverage(self) -> None:
        """Valida que todas las categorías tengan mapeo a Tier 1."""
        missing = ALL_CATEGORIES - set(TIER1_MAPPING.keys())
        if missing:
            logger.warning(
                f"Categorías sin mapear a Tier 1: {sorted(missing)}. "
                "Se usará 'otros_parafarmacia' como fallback."
            )

    # === COMPOSITION HASH ===

    def compute_composition_hash(self, product_ids: List[str]) -> str:
        """
        Calcula hash SHA256 determinístico de la composición del cluster.

        Args:
            product_ids: Lista de UUIDs de productos en el cluster

        Returns:
            Hash SHA256 de 64 caracteres
        """
        # Ordenar para determinismo
        sorted_ids = sorted(str(pid) for pid in product_ids)
        concatenated = "|".join(sorted_ids)
        return hashlib.sha256(concatenated.encode()).hexdigest()

    def needs_relabeling(
        self,
        cluster: ProductCluster,
        new_hash: str,
        force: bool = False
    ) -> bool:
        """
        Determina si un cluster necesita re-etiquetado.

        Usa comparación binaria de hash: cualquier cambio en la composición
        del cluster (productos añadidos/eliminados) dispara re-etiquetado.

        Nota: Para implementar threshold de cambio parcial (ej: 10% cambio),
        se requeriría almacenar product_ids y calcular Jaccard similarity.
        La implementación actual es conservadora: cualquier cambio re-etiqueta.

        Args:
            cluster: Cluster a evaluar
            new_hash: Hash actual de la composición
            force: Si True, re-etiqueta siempre

        Returns:
            True si necesita re-etiquetado
        """
        if force:
            return True

        if cluster.composition_hash is None:
            return True

        # Comparación binaria: cualquier cambio dispara re-etiquetado
        return cluster.composition_hash != new_hash

    # === WEIGHTED CATEGORY CALCULATION ===

    def calculate_weighted_category(
        self,
        cluster_id: UUID
    ) -> Tuple[str, str, Dict[str, float]]:
        """
        Calcula la categoría del cluster por voto ponderado de ventas.

        Algoritmo:
        1. Agrupa productos por predicted_necesidad (o ml_category)
        2. Suma ventas por categoría
        3. La categoría con más ventas gana (Tier 2)
        4. Mapeo estático para Tier 1

        Args:
            cluster_id: ID del cluster

        Returns:
            Tuple (tier1, tier2, sales_distribution)
        """
        # Query: suma de ventas por categoría
        # Usamos sales_enrichment unido a sales_data para obtener ventas
        from ..models.sales_data import SalesData

        results = (
            self.db.query(
                SalesEnrichment.ml_category,
                func.sum(SalesData.total_amount).label("total_sales")
            )
            .join(SalesData, SalesEnrichment.sales_data_id == SalesData.id)
            .filter(SalesEnrichment.product_cluster_id == cluster_id)
            .filter(SalesEnrichment.ml_category.isnot(None))
            .group_by(SalesEnrichment.ml_category)
            .all()
        )

        if not results:
            # Fallback si no hay ventas
            return "otros_parafarmacia", "otros", {}

        # Construir distribución
        sales_by_cat = {r.ml_category: float(r.total_sales or 0) for r in results}

        # Categoría ganadora (Tier 2)
        winning_tier2 = max(sales_by_cat, key=sales_by_cat.get)

        # Tier 1 por mapeo estático
        winning_tier1 = get_tier1_for_category(winning_tier2)

        return winning_tier1, winning_tier2, sales_by_cat

    # === LLM NAMING ===

    def _get_groq_client(self):
        """Obtiene o crea cliente Groq con lazy-loaded API key."""
        if self._groq_client is None:
            api_key = _get_groq_api_key()  # Lazy load from environment
            try:
                from groq import Groq
                self._groq_client = Groq(api_key=api_key)
            except ImportError:
                raise ImportError(
                    "groq package no instalado. "
                    "Ejecutar: pip install groq"
                )
        return self._groq_client

    def _sample_products_for_naming(
        self,
        cluster_id: UUID,
        limit: int = 15
    ) -> List[Dict[str, Any]]:
        """
        Obtiene muestra de productos para naming.

        Estrategia híbrida:
        - Top 10 por ventas (representan el core del cluster)
        - Top 5 más cercanos al centroide (representan la identidad)

        Args:
            cluster_id: ID del cluster
            limit: Máximo de productos a retornar

        Returns:
            Lista de dicts con product_name y avg_price
        """
        from ..models.sales_data import SalesData

        # Top por ventas
        results = (
            self.db.query(
                SalesData.product_name,
                func.avg(SalesData.unit_price).label("avg_price"),
                func.sum(SalesData.total_amount).label("total_sales")
            )
            .join(SalesEnrichment, SalesData.id == SalesEnrichment.sales_data_id)
            .filter(SalesEnrichment.product_cluster_id == cluster_id)
            .group_by(SalesData.product_name)
            .order_by(func.sum(SalesData.total_amount).desc())
            .limit(limit)
            .all()
        )

        return [
            {
                "product_name": r.product_name,
                "avg_price": float(r.avg_price or 0)
            }
            for r in results
        ]

    @retry(
        retry=retry_if_exception_type((Exception,)),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
    )
    def generate_naming_llm(
        self,
        tier2: str,
        products: List[Dict[str, Any]]
    ) -> Tuple[str, float]:
        """
        Genera nombre para el cluster usando LLM.

        Args:
            tier2: Categoría del cluster (para contexto)
            products: Lista de productos con nombre y precio

        Returns:
            Tuple (display_name, confidence)
        """
        if not products:
            return self._fallback_name(tier2), 0.5

        # Formatear lista de productos
        product_list = "\n".join(
            f"- {p['product_name']} | €{p['avg_price']:.2f}"
            for p in products[:15]
        )

        prompt = TAXONOMY_NAMING_PROMPT.format(
            tier2=tier2.replace("_", " ").title(),
            product_list=product_list
        )

        try:
            client = self._get_groq_client()

            response = client.chat.completions.create(
                model=LLM_MODEL,
                messages=[
                    {"role": "system", "content": "Eres un experto en taxonomía farmacéutica española."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=50,
                temperature=0.3,  # Bajo para consistencia
            )

            raw_name = response.choices[0].message.content.strip()

            # Limpiar respuesta
            display_name = self._clean_llm_response(raw_name)

            # Validar longitud
            if len(display_name.split()) > 5:
                display_name = " ".join(display_name.split()[:4])

            logger.info(f"LLM generated name for {tier2}: '{display_name}'")
            return display_name, 0.90

        except Exception as e:
            logger.warning(f"LLM naming failed for {tier2}: {e}. Using fallback.")
            return self._fallback_name(tier2), 0.5

    def _clean_llm_response(self, response: str) -> str:
        """Limpia la respuesta del LLM."""
        # Eliminar comillas, prefijos comunes
        response = response.strip('"\'')
        response = response.replace("Nombre: ", "")
        response = response.replace("Nombre sugerido: ", "")
        return response.strip()

    def _fallback_name(self, tier2: str) -> str:
        """Genera nombre fallback sin LLM."""
        return tier2.replace("_", " ").title()

    # === PERSIST LABELS ===

    def persist_labels(
        self,
        cluster: ProductCluster,
        tier1: str,
        tier2: str,
        display_name: str,
        confidence: float,
        composition_hash: str,
        llm_version: str = LLM_MODEL
    ) -> None:
        """
        Persiste las etiquetas en el cluster.

        Args:
            cluster: Cluster a actualizar
            tier1: Macro-categoría
            tier2: Categoría específica
            display_name: Nombre generado
            confidence: Confianza del naming
            composition_hash: Hash de composición actual
            llm_version: Versión del modelo LLM usado
        """
        cluster.primary_necesidad = tier1
        cluster.primary_subcategory = tier2
        cluster.llm_generated_name = display_name
        cluster.llm_name_confidence = confidence
        cluster.composition_hash = composition_hash
        cluster.labeled_at = datetime.now(timezone.utc)
        cluster.llm_version = llm_version

        self.db.add(cluster)

    # === LABEL SINGLE CLUSTER ===

    def label_cluster(
        self,
        cluster: ProductCluster,
        force: bool = False
    ) -> ClusterLabelingResult:
        """
        Etiqueta un cluster individual.

        Args:
            cluster: Cluster a etiquetar
            force: Forzar re-etiquetado

        Returns:
            ClusterLabelingResult con detalles del proceso
        """
        cluster_id_str = str(cluster.id)

        try:
            # 1. Obtener product_ids del cluster
            product_ids = (
                self.db.query(SalesEnrichment.id)
                .filter(SalesEnrichment.product_cluster_id == cluster.id)
                .all()
            )
            product_ids = [str(p.id) for p in product_ids]

            if not product_ids:
                return ClusterLabelingResult(
                    cluster_id=cluster_id_str,
                    tier1="otros_parafarmacia",
                    tier2="otros",
                    display_name=cluster.name or "Sin productos",
                    confidence=0.0,
                    status="skipped",
                    reason="Cluster sin productos"
                )

            # 2. Calcular composition hash
            new_hash = self.compute_composition_hash(product_ids)

            # 3. Verificar si necesita re-etiquetado
            if not self.needs_relabeling(cluster, new_hash, force):
                return ClusterLabelingResult(
                    cluster_id=cluster_id_str,
                    tier1=cluster.primary_necesidad or "otros_parafarmacia",
                    tier2=cluster.primary_subcategory or "otros",
                    display_name=cluster.llm_generated_name or cluster.name,
                    confidence=float(cluster.llm_name_confidence or 0),
                    status="skipped",
                    reason="Composición sin cambios significativos"
                )

            # 4. Calcular categoría por voto ponderado
            tier1, tier2, _ = self.calculate_weighted_category(cluster.id)

            # 5. Generar nombre con LLM
            products = self._sample_products_for_naming(cluster.id)
            display_name, confidence = self.generate_naming_llm(tier2, products)

            # 6. Persistir
            self.persist_labels(
                cluster=cluster,
                tier1=tier1,
                tier2=tier2,
                display_name=display_name,
                confidence=confidence,
                composition_hash=new_hash
            )

            return ClusterLabelingResult(
                cluster_id=cluster_id_str,
                tier1=tier1,
                tier2=tier2,
                display_name=display_name,
                confidence=confidence,
                status="updated"
            )

        except Exception as e:
            logger.error(f"Error labeling cluster {cluster_id_str}: {e}")
            return ClusterLabelingResult(
                cluster_id=cluster_id_str,
                tier1="otros_parafarmacia",
                tier2="otros",
                display_name=cluster.name or "Error",
                confidence=0.0,
                status="error",
                reason=str(e)
            )

    # === BATCH LABELING ===

    def batch_label_all(
        self,
        cluster_ids: Optional[List[UUID]] = None,
        force: bool = False
    ) -> BatchLabelingResult:
        """
        Etiqueta múltiples clusters en batch.

        Args:
            cluster_ids: Lista específica de clusters (opcional)
            force: Forzar re-etiquetado de todos

        Returns:
            BatchLabelingResult con estadísticas
        """
        import time
        start_time = time.time()

        # Query base
        query = self.db.query(ProductCluster).filter(
            ProductCluster.state.in_([ClusterState.PROVISIONAL, ClusterState.LOCKED])
        )

        if cluster_ids:
            query = query.filter(ProductCluster.id.in_(cluster_ids))

        clusters = query.all()

        if not clusters:
            return BatchLabelingResult(
                success=True,
                message="No hay clusters para etiquetar",
                clusters_processed=0,
                clusters_skipped=0,
                clusters_updated=0,
                clusters_errored=0,
                llm_calls_made=0,
                estimated_cost_usd=0.0,
                execution_time_seconds=0.0,
                details=[]
            )

        # Procesar cada cluster
        details = []
        llm_calls = 0

        for cluster in clusters:
            result = self.label_cluster(cluster, force)
            details.append(result)

            if result.status == "updated":
                llm_calls += 1

        # Commit cambios
        self.db.commit()

        # Estadísticas
        updated = sum(1 for d in details if d.status == "updated")
        skipped = sum(1 for d in details if d.status == "skipped")
        errored = sum(1 for d in details if d.status == "error")

        # Coste estimado (Groq es gratis pero estimamos por si cambia)
        # ~500 tokens por llamada, $0.05/1M tokens
        estimated_cost = (llm_calls * 500 * 0.05) / 1_000_000

        execution_time = time.time() - start_time

        return BatchLabelingResult(
            success=errored == 0,
            message=f"Etiquetado completado: {updated} actualizados, {skipped} sin cambios, {errored} errores",
            clusters_processed=len(clusters),
            clusters_skipped=skipped,
            clusters_updated=updated,
            clusters_errored=errored,
            llm_calls_made=llm_calls,
            estimated_cost_usd=estimated_cost,
            execution_time_seconds=execution_time,
            details=details
        )

    # === STATUS ===

    def get_labeling_status(self) -> Dict[str, Any]:
        """
        Obtiene estado actual del etiquetado de clusters.

        Returns:
            Dict con estadísticas de cobertura
        """
        query = self.db.query(ProductCluster).filter(
            ProductCluster.state.in_([ClusterState.PROVISIONAL, ClusterState.LOCKED])
        )

        total = query.count()
        labeled = query.filter(ProductCluster.labeled_at.isnot(None)).count()
        unlabeled = total - labeled

        # Clusters "stale" (etiquetados hace más de 7 días)
        from datetime import timedelta
        stale_threshold = datetime.now(timezone.utc) - timedelta(days=7)
        stale = query.filter(
            ProductCluster.labeled_at < stale_threshold
        ).count()

        # Distribución por tier1
        tier1_dist = (
            self.db.query(
                ProductCluster.primary_necesidad,
                func.count(ProductCluster.id)
            )
            .filter(ProductCluster.state.in_([ClusterState.PROVISIONAL, ClusterState.LOCKED]))
            .filter(ProductCluster.primary_necesidad.isnot(None))
            .group_by(ProductCluster.primary_necesidad)
            .all()
        )

        # Último etiquetado
        last_labeled = (
            self.db.query(func.max(ProductCluster.labeled_at))
            .filter(ProductCluster.state.in_([ClusterState.PROVISIONAL, ClusterState.LOCKED]))
            .scalar()
        )

        return {
            "total_clusters": total,
            "labeled_clusters": labeled,
            "unlabeled_clusters": unlabeled,
            "stale_clusters": stale,
            "last_labeled_at": last_labeled.isoformat() if last_labeled else None,
            "coverage_percent": (labeled / total * 100) if total > 0 else 0,
            "tier1_distribution": {r[0]: r[1] for r in tier1_dist}
        }


# === FACTORY FUNCTION ===

def get_taxonomy_labeler_service(db: Session) -> TaxonomyLabelerService:
    """Factory function para obtener instancia del servicio."""
    return TaxonomyLabelerService(db)
