#!/usr/bin/env python3
"""
Script para corregir registros con product_type NULL que tienen catálogo.

Issue #446: Bug encontrado donde reenrichment_service.py creaba enrichments
sin establecer product_type. Este script corrige los 4,521 registros afectados.

Bug corregido en: backend/app/services/reenrichment_service.py
- Ahora siempre llama a _derive_product_type al crear/actualizar enrichments

Uso:
    python scripts/backfill_null_product_type.py
    python scripts/backfill_null_product_type.py --dry-run
"""
import argparse
import sys
import logging
from pathlib import Path
from datetime import datetime

# Añadir raíz del proyecto al PYTHONPATH
sys.path.insert(0, str(Path(__file__).parent.parent))

from app.database import SessionLocal
from app.models.sales_enrichment import SalesEnrichment
from app.models.product_catalog import ProductCatalog
from app.services.enrichment_service import EnrichmentService

# Configurar logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def analyze_null_records(db):
    """Analiza los registros con NULL product_type que tienen catálogo"""

    # Contar total con NULL y catálogo
    null_with_catalog = db.query(SalesEnrichment).filter(
        SalesEnrichment.product_type.is_(None),
        SalesEnrichment.product_catalog_id.isnot(None)
    ).count()

    logger.info(f"Registros con product_type=NULL y catálogo: {null_with_catalog:,}")

    # Analizar por categoría de catálogo
    from sqlalchemy import func

    breakdown = db.query(
        ProductCatalog.xfarma_prescription_category,
        func.count(SalesEnrichment.id).label('count')
    ).join(
        SalesEnrichment,
        SalesEnrichment.product_catalog_id == ProductCatalog.id
    ).filter(
        SalesEnrichment.product_type.is_(None)
    ).group_by(
        ProductCatalog.xfarma_prescription_category
    ).all()

    logger.info("\nBreakdown por categoría de prescripción:")
    for category, count in sorted(breakdown, key=lambda x: x[1], reverse=True):
        display_cat = category if category else "NULL (venta libre)"
        logger.info(f"  {display_cat:.<40} {count:>8,}")

    return null_with_catalog


def backfill_null_product_type(db, dry_run=False):
    """Corrige registros con product_type NULL que tienen catálogo"""

    # Contar registros a corregir
    null_count = db.query(SalesEnrichment).filter(
        SalesEnrichment.product_type.is_(None),
        SalesEnrichment.product_catalog_id.isnot(None)
    ).count()

    logger.info(f"Registros a corregir: {null_count:,}")

    if null_count == 0:
        logger.info("No hay registros para corregir")
        return 0, 0

    if dry_run:
        logger.info("[DRY-RUN] No se actualizarán registros")
        return 0, 0

    # Inicializar servicio de enriquecimiento (para usar _derive_product_type)
    enrichment_service = EnrichmentService()

    # Procesar en batches
    batch_size = 1000
    updated_count = 0
    to_prescription = 0
    to_free_sale = 0
    errors = 0

    logger.info("Iniciando corrección...")

    while True:
        # Obtener batch de registros NULL con catálogo
        batch = db.query(SalesEnrichment).filter(
            SalesEnrichment.product_type.is_(None),
            SalesEnrichment.product_catalog_id.isnot(None)
        ).limit(batch_size).all()

        if not batch:
            break

        for enrichment in batch:
            try:
                # Obtener producto del catálogo
                product_catalog = db.query(ProductCatalog).filter(
                    ProductCatalog.id == enrichment.product_catalog_id
                ).first()

                if not product_catalog:
                    logger.warning(f"Catálogo no encontrado para enrichment {enrichment.id}")
                    # Sin catálogo, asumir free_sale
                    enrichment.product_type = 'free_sale'
                    to_free_sale += 1
                    updated_count += 1
                    continue

                # Derivar product_type usando la regla estándar
                new_product_type = enrichment_service._derive_product_type(product_catalog)

                # Actualizar
                enrichment.product_type = new_product_type

                if new_product_type == 'prescription':
                    to_prescription += 1
                else:
                    to_free_sale += 1

                updated_count += 1

                if updated_count % 500 == 0:
                    logger.info(f"Progreso: {updated_count:,}/{null_count:,} registros corregidos")
                    db.commit()  # Commit incremental

            except Exception as e:
                logger.error(f"Error procesando enrichment {enrichment.id}: {e}")
                errors += 1
                continue

        # Commit del batch
        db.commit()

    logger.info(f"Corrección completada: {updated_count:,} registros actualizados")
    logger.info(f"  → prescription: {to_prescription:,}")
    logger.info(f"  → free_sale: {to_free_sale:,}")
    if errors > 0:
        logger.warning(f"Errores: {errors}")

    return updated_count, errors


def verify_fix(db):
    """Verifica que no quedan registros NULL con catálogo"""
    null_with_catalog = db.query(SalesEnrichment).filter(
        SalesEnrichment.product_type.is_(None),
        SalesEnrichment.product_catalog_id.isnot(None)
    ).count()

    if null_with_catalog == 0:
        logger.info("✅ Verificación exitosa: 0 registros NULL con catálogo")
    else:
        logger.warning(f"⚠️  Todavía quedan {null_with_catalog:,} registros NULL con catálogo")

    # Mostrar distribución actualizada
    from sqlalchemy import func

    distribution = db.query(
        SalesEnrichment.product_type,
        func.count(SalesEnrichment.id)
    ).group_by(
        SalesEnrichment.product_type
    ).all()

    logger.info("\nDistribución actualizada de product_type:")
    for ptype, count in sorted(distribution, key=lambda x: x[1], reverse=True):
        display_type = ptype if ptype else "NULL"
        logger.info(f"  {display_type:.<30} {count:>8,}")

    # Verificar NULL restantes (sin catálogo, es esperado)
    null_no_catalog = db.query(SalesEnrichment).filter(
        SalesEnrichment.product_type.is_(None),
        SalesEnrichment.product_catalog_id.is_(None)
    ).count()

    if null_no_catalog > 0:
        logger.info(f"\nNota: {null_no_catalog:,} registros NULL sin catálogo (esperado, son productos no encontrados)")

    return null_with_catalog == 0


def main():
    parser = argparse.ArgumentParser(
        description="Corregir registros con product_type NULL que tienen catálogo"
    )
    parser.add_argument(
        '--dry-run',
        action='store_true',
        help="Solo mostrar estadísticas sin actualizar"
    )
    args = parser.parse_args()

    print("=" * 80)
    print("BACKFILL: CORREGIR product_type NULL CON CATÁLOGO")
    print("=" * 80)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Modo: {'DRY-RUN' if args.dry_run else 'ACTUALIZACIÓN REAL'}")
    print("=" * 80)
    print()

    db = SessionLocal()

    try:
        # Analizar estado actual
        print("ANÁLISIS INICIAL")
        print("-" * 80)
        analyze_null_records(db)
        print()

        # Ejecutar corrección
        print("CORRECCIÓN")
        print("-" * 80)
        updated, errors = backfill_null_product_type(db, args.dry_run)

        if not args.dry_run and updated > 0:
            print()
            print("=" * 80)
            print("VERIFICACIÓN POST-CORRECCIÓN")
            print("=" * 80)

            # Verificar fix
            is_fixed = verify_fix(db)

            print()
            print("=" * 80)
            print("RESUMEN")
            print("=" * 80)
            print(f"  Registros corregidos: {updated:,}")
            print(f"  Errores: {errors}")
            print(f"  Estado: {'✅ CORREGIDO' if is_fixed else '⚠️  REVISAR'}")
            print("=" * 80)

    except Exception as e:
        logger.error(f"Error durante backfill: {e}", exc_info=True)
        db.rollback()
        sys.exit(1)
    finally:
        db.close()


if __name__ == "__main__":
    main()
