"""
Backfill eficiente de códigos ATC desde CIMA.

Este script:
1. Obtiene nregistros únicos que no tienen ATC
2. Consulta la API CIMA una vez por nregistro único
3. Actualiza TODOS los productos con ese nregistro
4. Marca productos sin ATC disponible para evitar reprocesar
"""

import asyncio
import httpx
from sqlalchemy import or_, update
from datetime import datetime

# Make the project root (hard-coded as /app) importable when run as a script
import sys
sys.path.insert(0, '/app')

from app.database import SessionLocal
from app.models.product_catalog import ProductCatalog


# HTTP statuses that signal a temporary failure rather than "no data".
_TRANSIENT_HTTP_STATUSES = frozenset({408, 429})


def _is_transient_status(status_code):
    """Return True when the HTTP status denotes a retryable failure (408, 429, 5xx)."""
    return status_code in _TRANSIENT_HTTP_STATUSES or status_code >= 500


def _select_principal_atc(atcs):
    """Return the main ATC code from a CIMA ``atcs`` list, or None if there is none.

    The first entry with ``nivel == 5`` (most specific level) wins; otherwise
    the last entry carrying a non-empty ``codigo`` is used. Entries that are
    not dicts or have no ``codigo`` are ignored instead of raising.
    """
    fallback = None
    for entry in atcs:
        if not isinstance(entry, dict):
            continue
        code = entry.get('codigo')
        if not code:
            continue
        if entry.get('nivel') == 5:
            return code
        fallback = code
    return fallback


async def backfill_atc_efficient():
    """Backfill missing ATC codes in the product catalog from the CIMA API.

    Steps:
    1. Collect the distinct ``cima_nregistro`` values whose products have a
       NULL or empty ``cima_atc_code``.
    2. Query the CIMA REST API once per nregistro (10 concurrent requests,
       batches of 50).
    3. Update every product sharing that nregistro with the ATC found.
    4. Stamp products with ``'NO_ATC_DISPONIBLE'`` when CIMA has no ATC for
       them, so later runs skip them.

    Transient failures (network errors, HTTP 408/429/5xx, malformed payloads)
    are counted as errors and left untouched so a later run retries them;
    they are never stamped as "no ATC available".

    Returns:
        None. Progress and a final summary are printed to stdout; per-item
        errors are printed to stderr.

    Raises:
        Whatever the database layer raises on query/commit failure. The
        session is always closed before the exception propagates.
    """
    db = SessionLocal()
    start = datetime.now()

    processed = 0
    with_atc = 0
    no_atc = 0
    errors = 0

    try:
        # Distinct nregistros that still need an ATC. Rows already stamped
        # with NO_ATC_DISPONIBLE are excluded implicitly: their code is
        # neither NULL nor empty.
        unique_nregs_query = db.query(
            ProductCatalog.cima_nregistro
        ).filter(
            ProductCatalog.cima_nregistro.isnot(None),
            ProductCatalog.cima_nregistro != '',
            or_(
                ProductCatalog.cima_atc_code.is_(None),
                ProductCatalog.cima_atc_code == ''
            )
        ).distinct().all()

        unique_nregs = [r[0] for r in unique_nregs_query]
        total_nregs = len(unique_nregs)

        print(f'Total nregistros únicos sin ATC: {total_nregs:,}')
        print(f'Iniciando backfill a las {start.strftime("%H:%M:%S")}...')

        semaphore = asyncio.Semaphore(10)  # 10 concurrent requests

        async def fetch_atc(client, nregistro):
            """Fetch the ATC data for one nregistro.

            Returns ``{'atc': code, 'atcs': [...]}`` on success, ``None`` when
            CIMA definitively has no ATC, or ``{'error': msg}`` on a failure
            that should be retried in a later run.
            """
            async with semaphore:
                try:
                    resp = await client.get(
                        'https://cima.aemps.es/cima/rest/medicamento',
                        params={'nregistro': nregistro}
                    )
                    if _is_transient_status(resp.status_code):
                        # Rate limiting / server trouble: do NOT mark the
                        # product as "no ATC", or it would never be retried.
                        return {'error': f'HTTP {resp.status_code}'}
                    if resp.status_code != 200:
                        return None  # No data available for this nregistro
                    data = resp.json()
                    if not isinstance(data, dict):
                        return {'error': 'respuesta inesperada de CIMA'}
                    atcs = data.get('atcs') or []
                    atc_principal = _select_principal_atc(atcs)
                    if atc_principal is None:
                        return None  # No usable ATC code in the payload
                    return {'atc': atc_principal, 'atcs': atcs}
                except Exception as e:
                    # Best-effort per item: one bad request must not abort
                    # the whole backfill.
                    return {'error': str(e)}

        async with httpx.AsyncClient(timeout=15.0) as client:
            # Process in batches of 50 nregistros
            batch_size = 50
            for i in range(0, total_nregs, batch_size):
                batch = unique_nregs[i:i+batch_size]

                # Fetch ATCs concurrently
                tasks = [fetch_atc(client, nreg) for nreg in batch]
                results = await asyncio.gather(*tasks)

                # Update database
                for nreg, result in zip(batch, results):
                    processed += 1

                    if result is None:
                        # No ATC available - mark with special value
                        db.execute(
                            update(ProductCatalog)
                            .where(ProductCatalog.cima_nregistro == nreg)
                            .values(cima_atc_code='NO_ATC_DISPONIBLE')
                        )
                        no_atc += 1
                    elif 'error' in result:
                        errors += 1
                        print(
                            f'Error en nregistro {nreg}: {result["error"]}',
                            file=sys.stderr
                        )
                    else:
                        # Has ATC - update all products with this nregistro
                        db.execute(
                            update(ProductCatalog)
                            .where(ProductCatalog.cima_nregistro == nreg)
                            .values(
                                cima_atc_code=result['atc'],
                                cima_atc_codes=result['atcs']
                            )
                        )
                        with_atc += 1

                # One commit per batch keeps transactions short and preserves
                # progress if a later batch fails.
                db.commit()

                # Progress log every 200 nregistros
                if processed % 200 == 0 or processed == total_nregs:
                    elapsed = (datetime.now() - start).total_seconds()
                    rate = processed / elapsed if elapsed > 0 else 0
                    remaining = (total_nregs - processed) / rate if rate > 0 else 0
                    pct = processed * 100 / total_nregs
                    print(f'[{pct:.1f}%] {processed:,}/{total_nregs:,} | Con ATC: {with_atc:,} | Sin ATC: {no_atc:,} | Errores: {errors} | {rate:.1f}/s | ETA: {remaining/60:.1f}min')
    finally:
        # Always release the session, even when a query or commit fails.
        db.close()

    elapsed = (datetime.now() - start).total_seconds()
    print()
    print('=== COMPLETADO ===')
    print(f'nregistros procesados: {processed:,}')
    print(f'Con ATCs encontrados: {with_atc:,}')
    print(f'Sin ATC disponible: {no_atc:,}')
    print(f'Errores: {errors:,}')
    print(f'Tiempo total: {elapsed/60:.1f} minutos')

if __name__ == '__main__':
    # Script entry point: build the backfill coroutine and drive it to
    # completion on a fresh event loop.
    backfill_job = backfill_atc_efficient()
    asyncio.run(backfill_job)
