# backend/app/measures/prescription_measures.py
"""
Medidas de prescripción farmacéutica - Issue #484

Medidas incluidas:
- PrescriptionSales: Total de ventas de productos de prescripción
- PrescriptionUnits: Total de unidades de productos de prescripción
- PrescriptionCategoryDistribution: Distribución por categoría de prescripción
- PrescriptionATCDistribution: Distribución por código ATC
- PrescriptionPercentage: % de ventas de prescripción sobre total farmacia
"""

import logging
from typing import Any, Dict, List

from sqlalchemy import func, literal

from app.models.sales_data import SalesData
from app.models.sales_enrichment import SalesEnrichment
from app.models.product_catalog import ProductCatalog

from .base import BaseMeasure, QueryContext

logger = logging.getLogger(__name__)


class PrescriptionSales(BaseMeasure):
    """
    Total de ventas de productos de prescripción.
    Filtra por product_type='prescription' en SalesEnrichment.
    """

    def __init__(self):
        super().__init__()
        self.description = "Importe total de ventas de prescripción"
        self.unit = "€"
        self.category = "Prescripción"

    def calculate(self, context: QueryContext) -> float:
        """Calcular suma de ventas de prescripción."""
        # Usar sales_enrichment_query para acceder a product_type
        query = context.sales_enrichment_query

        # Filtrar solo prescripción
        query = query.filter(SalesEnrichment.product_type == "prescription")

        result = query.with_entities(
            func.sum(SalesData.total_amount).label("total")
        ).scalar()

        return float(result or 0)


class PrescriptionUnits(BaseMeasure):
    """
    Total de unidades de productos de prescripción.
    Filtra por product_type='prescription' en SalesEnrichment.
    """

    def __init__(self):
        super().__init__()
        self.description = "Unidades totales vendidas de prescripción"
        self.unit = "unidades"
        self.category = "Prescripción"

    def calculate(self, context: QueryContext) -> int:
        """Calcular suma de unidades de prescripción."""
        query = context.sales_enrichment_query
        query = query.filter(SalesEnrichment.product_type == "prescription")

        result = query.with_entities(
            func.sum(SalesData.quantity).label("total")
        ).scalar()

        return int(result or 0)


class PrescriptionPercentage(BaseMeasure):
    """
    Porcentaje de ventas de prescripción sobre el total.
    Equivale a [PrescriptionSales] / [TotalVentas] * 100.
    """

    def __init__(self):
        super().__init__()
        self.description = "Porcentaje de ventas de prescripción"
        self.unit = "%"
        self.category = "Prescripción"
        self.dependencies = ["PrescriptionSales"]

    def calculate(self, context: QueryContext) -> float:
        """Calcular porcentaje de prescripción."""
        # Ventas de prescripción
        prescription_sales = PrescriptionSales().calculate(context)

        # Ventas totales (sin filtro de product_type)
        total_result = context.base_query.with_entities(
            func.sum(SalesData.total_amount).label("total")
        ).scalar()
        total_sales = float(total_result or 0)

        if total_sales == 0:
            return 0.0

        percentage = (prescription_sales / total_sales) * 100
        return round(percentage, 2)


class PrescriptionCategoryDistribution(BaseMeasure):
    """
    Distribución de ventas por categoría de prescripción.

    Retorna diccionario con:
    - category_summary: Lista de categorías con ventas, unidades, %
    - total_sales: Total de ventas de prescripción
    - total_units: Total de unidades de prescripción
    """

    def __init__(self):
        super().__init__()
        self.description = "Distribución de ventas por categoría de prescripción"
        self.unit = "distribución"
        self.category = "Prescripción"

    def calculate(self, context: QueryContext) -> Dict[str, Any]:
        """Calcular distribución por categoría."""
        query = context.sales_enrichment_query

        # Filtrar solo prescripción
        query = query.filter(SalesEnrichment.product_type == "prescription")

        # Agrupar por prescription_category
        results = query.with_entities(
            SalesEnrichment.prescription_category,
            func.sum(SalesData.total_amount).label("sales"),
            func.sum(SalesData.quantity).label("units"),
            func.count(SalesData.id).label("records"),
        ).group_by(
            SalesEnrichment.prescription_category
        ).all()

        # Calcular total
        total_sales = sum(float(r.sales or 0) for r in results)
        total_units = sum(int(r.units or 0) for r in results)

        # Construir resumen de categorías
        category_summary = []
        for row in results:
            category = row.prescription_category or "sin_clasificar"
            sales = float(row.sales or 0)
            units = int(row.units or 0)
            pct = (sales / total_sales * 100) if total_sales > 0 else 0

            category_summary.append({
                "category": category,
                "total_sales": round(sales, 2),
                "total_units": units,
                "percentage": round(pct, 2),
                "records": row.records,
            })

        # Ordenar por ventas descendente
        category_summary.sort(key=lambda x: x["total_sales"], reverse=True)

        return {
            "category_summary": category_summary,
            "total_sales": round(total_sales, 2),
            "total_units": total_units,
            "categories_count": len(category_summary),
        }


class PrescriptionATCDistribution(BaseMeasure):
    """
    Distribución de ventas por código ATC.

    Retorna diccionario con:
    - atc_distribution: Lista de códigos ATC con ventas, unidades, %
    - uncategorized: Ventas sin código ATC
    """

    def __init__(self):
        super().__init__()
        self.description = "Distribución de ventas por código ATC"
        self.unit = "distribución"
        self.category = "Prescripción"

    def calculate(self, context: QueryContext) -> Dict[str, Any]:
        """Calcular distribución por código ATC - optimizado (1 query)."""
        # Determinar nivel ATC desde filtros (default nivel 1)
        atc_level = getattr(context.filters, 'atc_level', 1) or 1

        # Necesitamos enriched_query para acceder a atc_code
        query = context.enriched_query

        # Filtrar solo prescripción
        query = query.filter(SalesEnrichment.product_type == "prescription")

        # Definir longitud de código ATC según nivel
        atc_lengths = {1: 1, 2: 3, 3: 4, 4: 5, 5: 7}
        substr_len = atc_lengths.get(atc_level, 1)

        # Usar COALESCE para agrupar NULL/empty como '__UNCATEGORIZED__'
        atc_substr = func.left(ProductCatalog.atc_code, substr_len)
        atc_grouped = func.coalesce(
            func.nullif(atc_substr, ''),
            literal('__UNCATEGORIZED__')
        )

        # Single query con agregación condicional
        results = query.with_entities(
            atc_grouped.label("atc_code"),
            func.sum(SalesData.total_amount).label("sales"),
            func.sum(SalesData.quantity).label("units"),
        ).group_by(
            atc_grouped
        ).order_by(
            func.sum(SalesData.total_amount).desc()
        ).all()

        # Procesar resultados en una sola pasada
        atc_distribution = []
        uncategorized = {"sales": 0.0, "units": 0, "percentage": 0.0}
        total_sales = sum(float(r.sales or 0) for r in results)

        for row in results:
            sales = float(row.sales or 0)
            pct = (sales / total_sales * 100) if total_sales > 0 else 0

            if row.atc_code == '__UNCATEGORIZED__':
                uncategorized = {
                    "sales": round(sales, 2),
                    "units": int(row.units or 0),
                    "percentage": round(pct, 2),
                }
            else:
                atc_distribution.append({
                    "atc_code": row.atc_code,
                    "atc_name": _get_atc_name(row.atc_code),
                    "sales": round(sales, 2),
                    "units": int(row.units or 0),
                    "percentage": round(pct, 2),
                })

        # Calcular cobertura ATC
        categorized_sales = total_sales - uncategorized["sales"]
        atc_coverage = (categorized_sales / total_sales * 100) if total_sales > 0 else 0

        return {
            "atc_distribution": atc_distribution,
            "uncategorized": uncategorized,
            "atc_coverage": round(atc_coverage, 2),
            "atc_level": atc_level,
            "total_sales": round(total_sales, 2),
        }


class PrescriptionKPIs(BaseMeasure):
    """
    KPIs agregados de prescripción - Medida compuesta.

    Retorna todos los KPIs principales en una sola llamada:
    - total_sales: Ventas totales de prescripción
    - total_units: Unidades totales
    - prescription_percentage: % sobre ventas totales
    - atc_coverage: % de ventas con código ATC
    - ticket_promedio: Ticket promedio de prescripción
    """

    def __init__(self):
        super().__init__()
        self.description = "KPIs agregados de prescripción"
        self.unit = "kpis"
        self.category = "Prescripción"
        self.dependencies = [
            "PrescriptionSales",
            "PrescriptionUnits",
            "PrescriptionPercentage",
        ]

    def calculate(self, context: QueryContext) -> Dict[str, Any]:
        """Calcular todos los KPIs de prescripción - usando registry."""
        # Import local para evitar circular import
        from app.measures import measure_registry

        # Usar registry para beneficios de cache
        prescription_sales_measure = measure_registry.get_measure("prescription_sales")
        prescription_units_measure = measure_registry.get_measure("prescription_units")
        prescription_pct_measure = measure_registry.get_measure("prescription_percentage")
        atc_distribution_measure = measure_registry.get_measure("prescription_atc_distribution")

        total_sales = prescription_sales_measure.calculate(context)
        total_units = prescription_units_measure.calculate(context)
        prescription_pct = prescription_pct_measure.calculate(context)

        # Calcular ticket promedio (inline - más simple que medida separada)
        query = context.sales_enrichment_query
        query = query.filter(SalesEnrichment.product_type == "prescription")

        num_transactions = query.with_entities(
            func.count(func.distinct(SalesData.operation_id)).label("count")
        ).scalar()
        num_transactions = int(num_transactions or 0)

        ticket_promedio = (total_sales / num_transactions) if num_transactions > 0 else 0

        # Cobertura ATC desde registry
        atc_result = atc_distribution_measure.calculate(context)
        atc_coverage = atc_result.get("atc_coverage", 0)

        return {
            "total_sales": round(total_sales, 2),
            "total_units": total_units,
            "prescription_percentage": round(prescription_pct, 2),
            "atc_coverage": round(atc_coverage, 2),
            "ticket_promedio": round(ticket_promedio, 2),
            "num_transactions": num_transactions,
        }


def _get_atc_name(atc_code: str) -> str:
    """
    Helper para obtener nombre de código ATC.

    Nivel 1 (A-V): Grupos anatómicos principales
    """
    atc_level1_names = {
        "A": "Tracto alimentario y metabolismo",
        "B": "Sangre y órganos hematopoyéticos",
        "C": "Sistema cardiovascular",
        "D": "Dermatológicos",
        "G": "Sistema genitourinario y hormonas sexuales",
        "H": "Preparados hormonales sistémicos",
        "J": "Antiinfecciosos sistémicos",
        "L": "Agentes antineoplásicos e inmunomoduladores",
        "M": "Sistema musculoesquelético",
        "N": "Sistema nervioso",
        "P": "Antiparasitarios, insecticidas y repelentes",
        "R": "Sistema respiratorio",
        "S": "Órganos de los sentidos",
        "V": "Varios",
    }

    if not atc_code:
        return "Sin clasificar"

    # Nivel 1
    if len(atc_code) == 1:
        return atc_level1_names.get(atc_code.upper(), atc_code)

    # Niveles 2-5: Solo devolver código por ahora
    # TODO: Integrar con nomenclator para nombres completos
    return atc_code
