# backend/app/services/cluster_management_service.py
"""
Service for managing VentaLibre product clusters (Issue #464).

This service provides operations to split and merge clusters.
Clusters are represented by the `ml_category` field in ProductCatalogVentaLibre.

Key features:
- Atomic transactions for split/merge operations
- SalesEnrichment sync after updates (pattern from FeedbackServiceV2)
- Audit trail using AuditService
- Preview operations before execution

Categories are validated against NecesidadEspecifica enum (~130 values).
"""

import logging
from datetime import datetime, timezone
from typing import List, Optional, Tuple
from uuid import UUID

from sqlalchemy import Integer, func
from sqlalchemy.orm import Session

from app.models.audit_log import AuditAction
from app.models.product_catalog_venta_libre import ProductCatalogVentaLibre
from app.models.sales_enrichment import SalesEnrichment
from app.models.user import User
from app.schemas.cluster_management import (
    CategoryImpact,
    CategoryValidation,
    ClusterMergeRequest,
    ClusterMergeResponse,
    ClusterSplitRequest,
    ClusterSplitResponse,
    ClusterStats,
    MergePreviewResponse,
    ProductPreview,
    SplitPreviewResponse,
)
from app.schemas.symptom_taxonomy import (
    NecesidadEspecifica,
    get_display_name,
    get_parent_category,
)
from app.services.audit_service import AuditService

logger = logging.getLogger(__name__)


class ClusterManagementService:
    """
    Servicio para gestionar clusters (categorías) de productos VentaLibre.

    Los clusters están representados por el campo `ml_category` en
    ProductCatalogVentaLibre. Este servicio permite:
    - Dividir un cluster moviendo productos a una nueva categoría
    - Fusionar múltiples clusters en uno solo
    - Previsualizar operaciones antes de ejecutarlas
    - Obtener estadísticas de clusters
    """

    def __init__(self, db: Session):
        """
        Initialize the service.

        Args:
            db: SQLAlchemy database session
        """
        self.db = db
        self.audit_service = AuditService(db)

    def validate_category_name(self, name: str) -> CategoryValidation:
        """
        Valida si un nombre de categoría es válido.

        Una categoría es válida si existe en el enum NecesidadEspecifica.

        Args:
            name: Nombre de la categoría a validar

        Returns:
            CategoryValidation con resultado de la validación
        """
        name_lower = name.lower().strip()

        try:
            # Validar existencia en enum (raises ValueError si no existe)
            _ = NecesidadEspecifica(name_lower)
            parent = get_parent_category(name_lower)

            return CategoryValidation(
                category=name_lower,
                is_valid=True,
                display_name=get_display_name(name_lower),
                parent_category=parent,
                reason=None,
            )
        except ValueError:
            # Verificar si es un alias común
            aliases = {
                "dolor": "dolor_fiebre",
                "alergia": "alergias",
                "flora_intestinal": "probioticos",
            }

            if name_lower in aliases:
                return CategoryValidation(
                    category=name_lower,
                    is_valid=False,
                    display_name=None,
                    parent_category=None,
                    reason=f"'{name_lower}' es un alias. Usa '{aliases[name_lower]}' en su lugar.",
                )

            return CategoryValidation(
                category=name_lower,
                is_valid=False,
                display_name=None,
                parent_category=None,
                reason=f"'{name_lower}' no es una categoría válida en NecesidadEspecifica.",
            )

    def preview_split(
        self,
        source_category: str,
        product_ids: List[UUID],
        new_category: str,
    ) -> SplitPreviewResponse:
        """
        Previsualiza una operación de split sin ejecutarla.

        Muestra qué productos se moverían y el impacto en ambas categorías.

        Args:
            source_category: Categoría origen
            product_ids: IDs de productos a mover
            new_category: Categoría destino

        Returns:
            SplitPreviewResponse con detalles de la operación
        """
        source_cat = source_category.lower().strip()
        new_cat = new_category.lower().strip()

        # Validar categoría destino
        category_validation = self.validate_category_name(new_cat)
        warnings = []

        if not category_validation.is_valid:
            warnings.append(f"Categoría destino inválida: {category_validation.reason}")

        # Obtener productos a mover
        products = (
            self.db.query(ProductCatalogVentaLibre)
            .filter(ProductCatalogVentaLibre.id.in_(product_ids))
            .all()
        )

        products_preview = []
        total_sales_impact = 0

        for product in products:
            # Verificar que el producto está en la categoría origen
            if product.ml_category != source_cat:
                warnings.append(
                    f"Producto '{product.product_name_display[:30]}...' no está en categoría origen"
                )

            products_preview.append(ProductPreview(
                id=product.id,
                product_name=product.product_name_display,
                current_category=product.ml_category,
                detected_brand=product.detected_brand,
                total_sales_count=product.total_sales_count or 0,
                human_verified=product.human_verified,
            ))
            total_sales_impact += product.total_sales_count or 0

        # Contar productos restantes en origen
        source_remaining = (
            self.db.query(func.count(ProductCatalogVentaLibre.id))
            .filter(ProductCatalogVentaLibre.ml_category == source_cat)
            .filter(~ProductCatalogVentaLibre.id.in_(product_ids))
            .scalar()
        )

        # Productos no encontrados
        found_ids = {p.id for p in products}
        not_found = [pid for pid in product_ids if pid not in found_ids]
        if not_found:
            warnings.append(f"{len(not_found)} producto(s) no encontrado(s)")

        return SplitPreviewResponse(
            source_category=source_cat,
            new_category=new_cat,
            is_valid_category=category_validation.is_valid,
            products_to_move=products_preview,
            products_count=len(products),
            source_category_remaining=source_remaining or 0,
            total_sales_impact=total_sales_impact,
            warnings=warnings,
        )

    def preview_merge(
        self,
        source_categories: List[str],
        destination_category: str,
    ) -> MergePreviewResponse:
        """
        Previsualiza una operación de merge sin ejecutarla.

        Muestra qué categorías se fusionarían y el impacto.

        Args:
            source_categories: Lista de categorías origen
            destination_category: Categoría destino

        Returns:
            MergePreviewResponse con detalles de la operación
        """
        source_cats = [cat.lower().strip() for cat in source_categories]
        dest_cat = destination_category.lower().strip()

        # Validar categoría destino
        category_validation = self.validate_category_name(dest_cat)
        warnings = []

        if not category_validation.is_valid:
            warnings.append(f"Categoría destino inválida: {category_validation.reason}")

        # Verificar que destino no está en origen
        if dest_cat in source_cats:
            warnings.append("La categoría destino no puede estar en las categorías origen")

        # Obtener impacto por categoría origen
        source_impacts = []
        total_products = 0
        total_sales = 0
        total_verified = 0

        for source_cat in source_cats:
            stats = (
                self.db.query(
                    func.count(ProductCatalogVentaLibre.id).label("count"),
                    func.coalesce(func.sum(ProductCatalogVentaLibre.total_sales_count), 0).label("sales"),
                    func.sum(
                        func.cast(ProductCatalogVentaLibre.human_verified, Integer)
                    ).label("verified"),
                )
                .filter(ProductCatalogVentaLibre.ml_category == source_cat)
                .first()
            )

            verified_count = int(stats.verified or 0) if stats else 0

            impact = CategoryImpact(
                category=source_cat,
                product_count=stats.count if stats else 0,
                total_sales=int(stats.sales) if stats else 0,
                verified_count=verified_count,
            )
            source_impacts.append(impact)

            total_products += impact.product_count
            total_sales += impact.total_sales
            total_verified += impact.verified_count

        # Productos actuales en destino
        dest_current = (
            self.db.query(func.count(ProductCatalogVentaLibre.id))
            .filter(ProductCatalogVentaLibre.ml_category == dest_cat)
            .scalar()
        ) or 0

        return MergePreviewResponse(
            source_categories=source_cats,
            destination_category=dest_cat,
            is_valid_destination=category_validation.is_valid,
            source_impacts=source_impacts,
            total_products_to_merge=total_products,
            total_sales_impact=total_sales,
            total_verified_products=total_verified,
            destination_current_count=dest_current,
            destination_after_merge=dest_current + total_products,
            warnings=warnings,
        )

    def split_cluster(
        self,
        request: ClusterSplitRequest,
        user: Optional[User] = None,
    ) -> ClusterSplitResponse:
        """
        Divide un cluster moviendo productos seleccionados a una nueva categoría.

        Esta operación es atómica: o todos los productos se mueven o ninguno.
        Después de mover, sincroniza los registros de SalesEnrichment.

        Args:
            request: Detalles de la operación de split
            user: Usuario que ejecuta la operación (para audit)

        Returns:
            ClusterSplitResponse con resultado de la operación

        Raises:
            ValueError: Si la categoría destino es inválida
        """
        source_cat = request.source_category.lower().strip()
        new_cat = request.new_category.lower().strip()

        # Validar categoría destino
        validation = self.validate_category_name(new_cat)
        if not validation.is_valid:
            raise ValueError(f"Categoría inválida: {validation.reason}")

        logger.info(
            f"Iniciando split: {len(request.product_ids)} productos de "
            f"'{source_cat}' a '{new_cat}'"
        )

        try:
            # Obtener productos a mover
            products = (
                self.db.query(ProductCatalogVentaLibre)
                .filter(ProductCatalogVentaLibre.id.in_(request.product_ids))
                .all()
            )

            if not products:
                raise ValueError("No se encontraron productos con los IDs proporcionados")

            # Mover productos (actualizar ml_category y marcar como verificado)
            moved_count = 0
            for product in products:
                product.ml_category = new_cat
                product.verified_category = new_cat
                product.human_verified = True
                product.verified_at = datetime.now(timezone.utc)
                product.reviewer_notes = (
                    f"Split de '{source_cat}' a '{new_cat}'. "
                    f"{request.notes or 'Sin notas adicionales.'}"
                )
                product.prediction_source = "HUMAN"
                moved_count += 1

            # Commit cambios en catálogo
            self.db.flush()

            # Sincronizar SalesEnrichment (patrón de FeedbackServiceV2)
            enrichments_synced = self._sync_enrichments_batch(products)

            self.db.commit()

            # Contar productos restantes en origen y total en destino
            source_remaining = (
                self.db.query(func.count(ProductCatalogVentaLibre.id))
                .filter(ProductCatalogVentaLibre.ml_category == source_cat)
                .scalar()
            ) or 0

            new_total = (
                self.db.query(func.count(ProductCatalogVentaLibre.id))
                .filter(ProductCatalogVentaLibre.ml_category == new_cat)
                .scalar()
            ) or 0

            # Audit log
            audit_log = self.audit_service.log_action(
                action=AuditAction.UPDATE,
                method="POST",
                endpoint="/api/v1/ventalibre/clusters/split",
                user=user,
                resource_type="cluster_split",
                resource_id=f"{source_cat}->{new_cat}",
                description=f"Split cluster: {moved_count} productos de '{source_cat}' a '{new_cat}'",
                details={
                    "source_category": source_cat,
                    "new_category": new_cat,
                    "products_moved": moved_count,
                    "product_ids": [str(pid) for pid in request.product_ids],
                    "enrichments_synced": enrichments_synced,
                    "notes": request.notes,
                },
                success=True,
            )

            logger.info(
                f"Split completado: {moved_count} productos, "
                f"{enrichments_synced} enrichments sincronizados"
            )

            return ClusterSplitResponse(
                success=True,
                message=f"Split completado: {moved_count} productos movidos a '{new_cat}'",
                source_category=source_cat,
                new_category=new_cat,
                products_moved=moved_count,
                enrichments_synced=enrichments_synced,
                audit_log_id=str(audit_log.id) if audit_log else None,
                executed_at=datetime.now(timezone.utc),
                executed_by=user.email if user else None,
                source_category_remaining=source_remaining,
                new_category_total=new_total,
            )

        except Exception as e:
            self.db.rollback()
            logger.error(f"Error en split: {e}")

            # Log failed operation
            self.audit_service.log_action(
                action=AuditAction.UPDATE,
                method="POST",
                endpoint="/api/v1/ventalibre/clusters/split",
                user=user,
                resource_type="cluster_split",
                resource_id=f"{source_cat}->{new_cat}",
                description=f"Split cluster fallido",
                details={
                    "source_category": source_cat,
                    "new_category": new_cat,
                    "error": str(e),
                },
                success=False,
                error_message=str(e),
            )

            raise

    def merge_clusters(
        self,
        request: ClusterMergeRequest,
        user: Optional[User] = None,
    ) -> ClusterMergeResponse:
        """
        Fusiona múltiples categorías en una categoría destino.

        Mueve TODOS los productos de las categorías origen a la categoría destino.
        Esta operación es atómica.

        Args:
            request: Detalles de la operación de merge
            user: Usuario que ejecuta la operación (para audit)

        Returns:
            ClusterMergeResponse con resultado de la operación

        Raises:
            ValueError: Si la categoría destino es inválida
        """
        source_cats = [cat.lower().strip() for cat in request.source_categories]
        dest_cat = request.destination_category.lower().strip()

        # Validar categoría destino
        validation = self.validate_category_name(dest_cat)
        if not validation.is_valid:
            raise ValueError(f"Categoría destino inválida: {validation.reason}")

        # Verificar que destino no está en origen
        if dest_cat in source_cats:
            raise ValueError("La categoría destino no puede estar en las categorías origen")

        logger.info(
            f"Iniciando merge: {source_cats} -> '{dest_cat}'"
        )

        try:
            products_per_source = {}
            all_products = []

            # Obtener y mover productos de cada categoría origen
            for source_cat in source_cats:
                products = (
                    self.db.query(ProductCatalogVentaLibre)
                    .filter(ProductCatalogVentaLibre.ml_category == source_cat)
                    .all()
                )

                products_per_source[source_cat] = len(products)
                all_products.extend(products)

                # Actualizar cada producto
                for product in products:
                    product.ml_category = dest_cat
                    product.verified_category = dest_cat
                    product.human_verified = True
                    product.verified_at = datetime.now(timezone.utc)
                    product.reviewer_notes = (
                        f"Merge de '{source_cat}' a '{dest_cat}'. "
                        f"{request.notes or 'Sin notas adicionales.'}"
                    )
                    product.prediction_source = "HUMAN"

            total_merged = len(all_products)

            if total_merged == 0:
                raise ValueError("No se encontraron productos en las categorías origen")

            # Commit cambios en catálogo
            self.db.flush()

            # Sincronizar SalesEnrichment
            enrichments_synced = self._sync_enrichments_batch(all_products)

            self.db.commit()

            # Contar total en destino después del merge
            dest_total = (
                self.db.query(func.count(ProductCatalogVentaLibre.id))
                .filter(ProductCatalogVentaLibre.ml_category == dest_cat)
                .scalar()
            ) or 0

            # Audit log
            audit_log = self.audit_service.log_action(
                action=AuditAction.UPDATE,
                method="POST",
                endpoint="/api/v1/ventalibre/clusters/merge",
                user=user,
                resource_type="cluster_merge",
                resource_id=f"{','.join(source_cats)}->{dest_cat}",
                description=f"Merge clusters: {total_merged} productos de {source_cats} a '{dest_cat}'",
                details={
                    "source_categories": source_cats,
                    "destination_category": dest_cat,
                    "total_products_merged": total_merged,
                    "products_per_source": products_per_source,
                    "enrichments_synced": enrichments_synced,
                    "notes": request.notes,
                },
                success=True,
            )

            logger.info(
                f"Merge completado: {total_merged} productos, "
                f"{enrichments_synced} enrichments sincronizados"
            )

            return ClusterMergeResponse(
                success=True,
                message=f"Merge completado: {total_merged} productos movidos a '{dest_cat}'",
                source_categories=source_cats,
                destination_category=dest_cat,
                total_products_merged=total_merged,
                enrichments_synced=enrichments_synced,
                products_per_source=products_per_source,
                audit_log_id=str(audit_log.id) if audit_log else None,
                executed_at=datetime.now(timezone.utc),
                executed_by=user.email if user else None,
                destination_total=dest_total,
            )

        except Exception as e:
            self.db.rollback()
            logger.error(f"Error en merge: {e}")

            # Log failed operation
            self.audit_service.log_action(
                action=AuditAction.UPDATE,
                method="POST",
                endpoint="/api/v1/ventalibre/clusters/merge",
                user=user,
                resource_type="cluster_merge",
                resource_id=f"{','.join(source_cats)}->{dest_cat}",
                description=f"Merge clusters fallido",
                details={
                    "source_categories": source_cats,
                    "destination_category": dest_cat,
                    "error": str(e),
                },
                success=False,
                error_message=str(e),
            )

            raise

    def get_cluster_stats(self, category: str) -> ClusterStats:
        """
        Obtiene estadísticas detalladas de un cluster (categoría).

        Args:
            category: Nombre de la categoría

        Returns:
            ClusterStats con estadísticas del cluster
        """
        cat = category.lower().strip()

        # Query base para productos en esta categoría
        base_query = self.db.query(ProductCatalogVentaLibre).filter(
            ProductCatalogVentaLibre.ml_category == cat,
            ProductCatalogVentaLibre.is_active == True,
        )

        # Conteos básicos
        total = base_query.count()
        verified = base_query.filter(
            ProductCatalogVentaLibre.human_verified == True
        ).count()

        # Estadísticas agregadas
        stats = (
            self.db.query(
                func.coalesce(func.sum(ProductCatalogVentaLibre.total_sales_count), 0).label("sales"),
                func.coalesce(func.sum(ProductCatalogVentaLibre.pharmacies_count), 0).label("pharmacies"),
                func.avg(ProductCatalogVentaLibre.ml_confidence).label("avg_confidence"),
            )
            .filter(
                ProductCatalogVentaLibre.ml_category == cat,
                ProductCatalogVentaLibre.is_active == True,
            )
            .first()
        )

        # Top brands
        top_brands_query = (
            self.db.query(
                ProductCatalogVentaLibre.detected_brand,
                func.count(ProductCatalogVentaLibre.id).label("count"),
            )
            .filter(
                ProductCatalogVentaLibre.ml_category == cat,
                ProductCatalogVentaLibre.is_active == True,
                ProductCatalogVentaLibre.detected_brand.isnot(None),
            )
            .group_by(ProductCatalogVentaLibre.detected_brand)
            .order_by(func.count(ProductCatalogVentaLibre.id).desc())
            .limit(5)
            .all()
        )

        top_brands = [brand for brand, _ in top_brands_query]

        # Conteo de marcas únicas
        brand_count = (
            self.db.query(func.count(func.distinct(ProductCatalogVentaLibre.detected_brand)))
            .filter(
                ProductCatalogVentaLibre.ml_category == cat,
                ProductCatalogVentaLibre.is_active == True,
                ProductCatalogVentaLibre.detected_brand.isnot(None),
            )
            .scalar()
        ) or 0

        verification_rate = round(100 * verified / total, 1) if total > 0 else 0.0

        return ClusterStats(
            category=cat,
            display_name=get_display_name(cat),
            total_products=total,
            verified_products=verified,
            pending_verification=total - verified,
            total_sales_count=int(stats.sales) if stats else 0,
            pharmacies_count=int(stats.pharmacies) if stats else 0,
            top_brands=top_brands,
            brand_count=brand_count,
            avg_confidence=float(stats.avg_confidence) if stats and stats.avg_confidence else None,
            verification_rate=verification_rate,
        )

    def get_all_clusters_stats(self, limit: int = 50) -> Tuple[List[ClusterStats], int, int]:
        """
        Obtiene estadísticas de todos los clusters ordenados por número de productos.

        Args:
            limit: Máximo de clusters a devolver

        Returns:
            Tuple con (lista de ClusterStats, total clusters, total productos)
        """
        # Obtener categorías únicas con conteos
        categories_query = (
            self.db.query(
                ProductCatalogVentaLibre.ml_category,
                func.count(ProductCatalogVentaLibre.id).label("count"),
            )
            .filter(
                ProductCatalogVentaLibre.is_active == True,
                ProductCatalogVentaLibre.ml_category.isnot(None),
            )
            .group_by(ProductCatalogVentaLibre.ml_category)
            .order_by(func.count(ProductCatalogVentaLibre.id).desc())
            .limit(limit)
            .all()
        )

        clusters = []
        total_products = 0

        for cat, count in categories_query:
            stats = self.get_cluster_stats(cat)
            clusters.append(stats)
            total_products += count

        total_verified = sum(c.verified_products for c in clusters)

        return clusters, len(clusters), total_products

    def _sync_enrichments_batch(self, products: List[ProductCatalogVentaLibre]) -> int:
        """
        Sincroniza SalesEnrichment records con la clasificación actualizada del catálogo.

        Patrón de FeedbackServiceV2._sync_enrichments_from_catalog().
        Actualiza ml_category y ml_confidence en todos los SalesEnrichment
        que referencian los productos actualizados.

        Args:
            products: Lista de productos del catálogo actualizados

        Returns:
            Número de registros SalesEnrichment actualizados
        """
        if not products:
            return 0

        total_updated = 0

        for product in products:
            category = product.verified_category or product.ml_category

            updated = (
                self.db.query(SalesEnrichment)
                .filter(SalesEnrichment.venta_libre_product_id == product.id)
                .update(
                    {
                        SalesEnrichment.ml_category: category,
                        SalesEnrichment.ml_confidence: 1.0 if product.human_verified else product.ml_confidence,
                    },
                    synchronize_session=False,
                )
            )
            total_updated += updated

        return total_updated


# === Factory function ===


def get_cluster_management_service(db: Session) -> ClusterManagementService:
    """
    Factory para obtener instancia del servicio.

    Args:
        db: Sesión de SQLAlchemy

    Returns:
        Instancia de ClusterManagementService
    """
    return ClusterManagementService(db)
