# backend/app/middleware/auth_middleware.py
"""
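Global authentication middleware for xFarma API.
Validates JWT tokens on all requests except public endpoints, and also
provides audit-logging and rate-limiting middlewares.

Pivot 2026: JWT libraries are imported conditionally; in local mode
(KAIFARMA_LOCAL=true) JWT validation is skipped in favor of PIN-based auth.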
"""

import logging
import os
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

# Pivot 2026: Conditional JWT imports.
# KAIFARMA_LOCAL=true (case-insensitive) selects local, PIN-based mode.
IS_LOCAL_MODE = os.getenv("KAIFARMA_LOCAL", "").lower() == "true"

if IS_LOCAL_MODE:
    # Local mode: JWT is not used (AuthenticationMiddleware.dispatch returns
    # before touching it), so provide inert stand-ins for the jose names this
    # module references. Dedicated exception classes, rather than aliasing
    # ``Exception``, guarantee that an ``except JWTError`` clause can never
    # swallow unrelated errors.

    class JWTError(Exception):
        """Local-mode placeholder for ``jose.JWTError``; never raised here."""

    class ExpiredSignatureError(JWTError):
        """Local-mode placeholder for ``jose.ExpiredSignatureError``; never raised here."""

    jwt = None  # type: ignore
else:
    # Cloud mode: Full JWT support
    from jose import ExpiredSignatureError, JWTError, jwt

from app.core.security import ALGORITHM, SECRET_KEY

logger = logging.getLogger(__name__)

# Authentication policy (cloud mode): a request must carry a valid JWT access
# token unless its path is public, i.e. listed in PUBLIC_ENDPOINTS or matched by
# the public prefixes checked in AuthenticationMiddleware._is_public_endpoint.
# There is no environment-specific bypass: development and production share the
# same flow. In local mode (KAIFARMA_LOCAL=true) JWT validation is skipped
# entirely in favor of PIN-based authentication.

# Public endpoints that don't require authentication (exact path matches)
PUBLIC_ENDPOINTS = [
    "/",  # Root
    "/health",  # Health check
    "/api/health",  # API health check
    "/docs",  # API documentation
    "/redoc",  # API documentation
    "/openapi.json",  # OpenAPI schema
    "/api/v1/auth/login",  # Login endpoint
    "/api/v1/auth/register",  # Registration endpoint
    "/api/v1/auth/oauth/google",  # OAuth Google
    "/api/v1/auth/oauth/microsoft",  # OAuth Microsoft
    "/api/v1/auth/oauth/google/callback",  # OAuth Google callback
    "/api/v1/auth/oauth/microsoft/callback",  # OAuth Microsoft callback
    "/api/v1/auth/refresh",  # Token refresh
    "/api/v1/auth/forgot-password",  # Password reset
    "/api/v1/auth/reset-password",  # Password reset confirmation
    "/api/v1/auth/oauth/providers",  # OAuth providers info
    # Pivot 2026: Local PIN authentication (needed to unlock terminal)
    "/api/v1/auth/local/status",  # Check lock status
    "/api/v1/auth/local/unlock",  # Unlock with PIN
    "/api/v1/auth/local/lock",  # Lock terminal
    "/_favicon.ico",  # Browser favicon
    "/_dash-layout",  # Dash layout endpoint (required for initial page load)
    "/_dash-dependencies",  # Dash dependencies endpoint (required for callbacks)
    "/metrics",  # Prometheus metrics (Issue #114 - Fase 3)
]

# Admin path prefixes. AuthenticationMiddleware logs authenticated access to
# them at INFO level; the permission checks themselves are handled at the
# endpoint level, not by this list.
ADMIN_ENDPOINTS = [
    "/api/v1/admin/",
    "/api/v1/system/",
    "/api/v1/developer/",
]


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """
    Global middleware to validate JWT tokens on protected routes.

    A request passes through without a token when its path is an exact match
    in ``exclude_paths`` (default ``PUBLIC_ENDPOINTS``) or starts with one of
    ``_PUBLIC_PREFIXES``. Every other request must send
    ``Authorization: Bearer <token>``, where the token decodes with
    ``SECRET_KEY``/``ALGORITHM`` and has ``"type": "access"``. Failures produce
    a JSON 401 (or 500 for unexpected errors) with ``detail`` and ``error``
    keys. On success the ``sub``, ``email`` and ``role`` claims are stored on
    ``request.state`` as ``user_id``, ``user_email`` and ``user_role``, and
    basic security headers are added to the downstream response.

    In local mode (``IS_LOCAL_MODE``) the middleware does nothing; the
    terminal uses PIN-based authentication instead.
    """

    # Path prefixes that never require a JWT. A tuple lets one
    # ``str.startswith`` call test all of them.
    _PUBLIC_PREFIXES = (
        # NOTE(review): this skips the JWT check for *every* /api/v1/auth/*
        # route, not only the login/refresh flow; any such route that needs a
        # logged-in user must authenticate on its own — confirm that it does.
        "/api/v1/auth/",  # All auth endpoints are public
        "/static/",  # Static files
        "/media/",  # Media files
        "/assets/",  # Dash CSS/JS assets
        "/_dash-component-suites/",  # Dash internal components
        "/_dash",  # All Dash internal callbacks (_dash-update-component, _dash-dependencies, etc.)
    )

    # Headers added to every successfully authenticated response.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
    }

    def __init__(self, app, exclude_paths: Optional[List[str]] = None):
        """
        Args:
            app: The wrapped ASGI application.
            exclude_paths: Exact paths that skip authentication. ``None`` or an
                empty list falls back to ``PUBLIC_ENDPOINTS``.
        """
        super().__init__(app)
        self.exclude_paths = exclude_paths or PUBLIC_ENDPOINTS

    @staticmethod
    def _auth_error(status_code: int, detail: str, error: str) -> JSONResponse:
        """Build the JSON error response returned when authentication fails."""
        return JSONResponse(
            status_code=status_code,
            content={
                "detail": detail,
                "error": error,
            },
        )

    async def dispatch(self, request: Request, call_next):
        """
        Validate the request's bearer token before handing it downstream.

        Returns the downstream response, or a JSON error response when the
        token is missing, malformed, expired, invalid or of the wrong type.
        """
        # Pivot 2026: Skip JWT middleware entirely in local mode
        # Local mode uses PIN-based auth via LocalSecurityManager
        if IS_LOCAL_MODE:
            return await call_next(request)

        path = request.url.path

        # Skip authentication for public endpoints
        if self._is_public_endpoint(path):
            return await call_next(request)

        authorization = request.headers.get("Authorization")
        if not authorization:
            logger.warning("Unauthorized access attempt to %s - No token provided", path)
            return self._auth_error(401, "Autenticación requerida", "missing_token")

        # Expect exactly "<scheme> <token>"; too few or too many parts make the
        # tuple unpacking raise ValueError, handled like a wrong scheme.
        try:
            scheme, token = authorization.split()
            if scheme.lower() != "bearer":
                raise ValueError("Invalid authentication scheme")
        except ValueError:
            logger.warning("Invalid authorization header format for %s", path)
            return self._auth_error(401, "Credenciales de autenticación inválidas", "invalid_format")

        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])

            # Only tokens whose "type" claim is "access" may authenticate requests.
            if payload.get("type") != "access":
                logger.warning("Invalid token type for %s", path)
                return self._auth_error(401, "Tipo de token inválido", "invalid_token_type")

            # Expose the caller's identity to downstream handlers. request.state
            # belongs to this request, so nothing leaks between requests.
            request.state.user_id = payload.get("sub")
            request.state.user_email = payload.get("email")
            request.state.user_role = payload.get("role")

            if any(path.startswith(admin_path) for admin_path in ADMIN_ENDPOINTS):
                logger.info("Admin access: %s accessing %s", request.state.user_email, path)

        # ExpiredSignatureError must precede JWTError: in jose it is a subclass.
        except ExpiredSignatureError:
            logger.warning("Expired token for %s", path)
            return self._auth_error(401, "Token expirado. Por favor, inicie sesión nuevamente", "token_expired")
        except JWTError as e:
            logger.warning("JWT validation failed for %s: %s", path, e)
            return self._auth_error(401, "No se pudieron validar las credenciales", "invalid_token")
        except Exception as e:
            # logger.exception also records the traceback, which is needed to
            # diagnose a failure that is neither an expired nor an invalid token.
            logger.exception("Unexpected error in auth middleware: %s", e)
            return self._auth_error(500, "Error en el servicio de autenticación", "internal_error")

        response = await call_next(request)

        for header_name, header_value in self._SECURITY_HEADERS.items():
            response.headers[header_name] = header_value

        return response

    def _is_public_endpoint(self, path: str) -> bool:
        """
        Return True when ``path`` needs no authentication.

        A path is public if it exactly matches an entry in ``exclude_paths``
        or starts with any of ``_PUBLIC_PREFIXES``.
        """
        if path in self.exclude_paths:
            return True
        return path.startswith(self._PUBLIC_PREFIXES)


class AuditLoggingMiddleware(BaseHTTPMiddleware):
    """
    Middleware to log all critical operations for audit trail.

    A request is critical when its HTTP method matches and its path starts
    with the prefix of an entry in ``self.critical_operations``. Critical
    requests are logged at WARNING level before processing, and again
    afterwards with the final status code. If downstream processing raises,
    the failure (exception type) is logged and the exception is re-raised.
    """

    def __init__(self, app):
        super().__init__(app)
        # (method, path prefix) pairs. The broad ("DELETE"/"PUT", "/api/v1/")
        # entries already cover every DELETE/PUT under /api/v1/; the specific
        # admin entries are kept for explicitness.
        self.critical_operations = [
            ("DELETE", "/api/v1/admin/delete-all-data"),
            ("POST", "/api/v1/admin/vacuum-database"),
            ("POST", "/api/v1/admin/clean-catalog"),
            ("DELETE", "/api/v1/"),
            ("PUT", "/api/v1/"),
        ]

    @staticmethod
    def _identity(request: Request) -> Tuple[str, str]:
        """Return ``(user_email, user_id)`` from request.state, defaulting to "unknown"."""
        return (
            getattr(request.state, "user_email", "unknown"),
            getattr(request.state, "user_id", "unknown"),
        )

    async def dispatch(self, request: Request, call_next):
        """
        Log critical operations with user information.
        """
        method = request.method
        path = request.url.path

        is_critical = any(
            method == op_method and path.startswith(op_path) for op_method, op_path in self.critical_operations
        )
        if not is_critical:
            return await call_next(request)

        user_email, user_id = self._identity(request)
        client_host = request.client.host if request.client else "unknown"
        logger.warning(
            "AUDIT: Critical operation attempted - User: %s (ID: %s), Method: %s, Path: %s, IP: %s",
            user_email,
            user_id,
            method,
            path,
            client_host,
        )

        try:
            response = await call_next(request)
        except Exception as exc:
            # Without this, an operation that crashes downstream would leave
            # only the "attempted" entry in the audit trail.
            user_email, _ = self._identity(request)
            logger.warning(
                "AUDIT: Critical operation failed - User: %s, Method: %s, Path: %s, Error: %s",
                user_email,
                method,
                path,
                type(exc).__name__,
            )
            raise

        # Re-read the identity: if authentication runs inside this middleware,
        # request.state is only populated during call_next. This assumes state is
        # shared through the ASGI scope, as in current Starlette — TODO confirm.
        user_email, _ = self._identity(request)
        status_code = response.status_code
        success = 200 <= status_code < 300
        logger.warning(
            "AUDIT: Critical operation %s - User: %s, Method: %s, Path: %s, Status: %s",
            "completed" if success else "failed",
            user_email,
            method,
            path,
            status_code,
        )

        return response


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware for authentication endpoints to prevent brute force attacks.

    Requests to ``_RATE_LIMITED_PATHS`` (exact match) are counted per client IP.
    Once an IP has made ``max_requests`` requests within the last
    ``time_window`` seconds, its next request locks it out for
    ``lockout_time`` seconds. Rejected requests receive HTTP 429 with the wait
    time in both the JSON body (``retry_after``) and the ``Retry-After``
    header. A successful (HTTP 200) login clears the IP's history and lockout.

    State lives in this instance's dictionaries (process memory), so it is
    not shared across processes.
    """

    # Exact paths subject to rate limiting.
    _RATE_LIMITED_PATHS = frozenset(
        {
            "/api/v1/auth/login",
            "/api/v1/auth/register",
            "/api/v1/auth/forgot-password",
        }
    )
    _LOGIN_PATH = "/api/v1/auth/login"

    def __init__(
        self,
        app,
        max_requests: int = 5,  # Maximum requests
        time_window: int = 60,  # Time window in seconds
        lockout_time: int = 300,  # Lockout time in seconds (5 minutes)
    ):
        super().__init__(app)
        self.max_requests = max_requests
        self.time_window = time_window
        self.lockout_time = lockout_time
        # Store IP attempts: {ip: [(timestamp, path), ...]}
        self.attempts: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
        # Store locked IPs: {ip: lockout_until_timestamp}
        self.locked_ips: Dict[str, float] = {}
        # Time of the last sweep over all IPs (see _purge_stale).
        self._last_purge = 0.0

    def _purge_stale(self, now: float) -> None:
        """
        Drop expired lockouts and fully expired attempt histories for every IP.

        Without this sweep, entries for IPs that never return would accumulate
        forever. It runs at most once per ``time_window`` so the full scan is
        amortized across requests.
        """
        if now - self._last_purge < self.time_window:
            return
        self._last_purge = now

        expired_locks = [ip for ip, until in self.locked_ips.items() if now >= until]
        for ip in expired_locks:
            del self.locked_ips[ip]

        idle_ips = [
            ip
            for ip, history in self.attempts.items()
            if all(now - ts >= self.time_window for ts, _ in history)
        ]
        for ip in idle_ips:
            del self.attempts[ip]

    @staticmethod
    def _too_many_requests(detail: str, retry_after: int) -> JSONResponse:
        """Build a 429 response with ``retry_after`` in the body and the Retry-After header."""
        return JSONResponse(
            status_code=429,
            content={
                "detail": detail,
                "error": "rate_limit_exceeded",
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )

    async def dispatch(self, request: Request, call_next):
        """
        Check rate limits for authentication endpoints.
        """
        path = request.url.path
        if path not in self._RATE_LIMITED_PATHS:
            return await call_next(request)

        client_ip = request.client.host if request.client else "unknown"
        current_time = time.time()
        self._purge_stale(current_time)

        locked_until = self.locked_ips.get(client_ip)
        if locked_until is not None:
            if current_time < locked_until:
                remaining_lockout = int(locked_until - current_time)
                logger.warning("Rate limit: Locked IP %s attempted to access %s", client_ip, path)
                return self._too_many_requests(
                    f"Demasiados intentos. Por favor, espere {remaining_lockout} segundos",
                    remaining_lockout,
                )
            # Lockout expired, remove from locked list
            del self.locked_ips[client_ip]

        # Keep only this IP's attempts inside the sliding window. .get() avoids
        # creating an entry in the defaultdict just by looking.
        history = [
            (ts, p) for ts, p in self.attempts.get(client_ip, ()) if current_time - ts < self.time_window
        ]
        request_count = len(history)

        if request_count >= self.max_requests:
            self.attempts[client_ip] = history
            self.locked_ips[client_ip] = current_time + self.lockout_time
            logger.warning(
                "Rate limit exceeded: IP %s locked for %s seconds after %s attempts on %s",
                client_ip,
                self.lockout_time,
                request_count,
                path,
            )
            return self._too_many_requests(
                f"Límite de intentos excedido. Bloqueado por {self.lockout_time} segundos",
                self.lockout_time,
            )

        # Every allowed request counts as an attempt, whatever its outcome.
        history.append((current_time, path))
        self.attempts[client_ip] = history

        response = await call_next(request)

        # A successful login resets this IP's counters.
        # NOTE(review): a client holding one valid account can use this to reset
        # its counter between guesses from the same IP — confirm this is acceptable.
        if path == self._LOGIN_PATH and response.status_code == 200:
            self.attempts.pop(client_ip, None)
            self.locked_ips.pop(client_ip, None)

        return response
