# frontend/utils/auth/auth_cloud.py
"""
Cloud Authentication Manager - JWT/OAuth (Pivot 2026).

This implementation handles JWT tokens and OAuth flows for
cloud/Hub deployments.

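Requires: PyJWT, requests. Optional: cryptography (enables encryption of
persisted tokens; without it tokens are stored unencrypted).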

Author: Pivot 2026 - Auth Refactor
Date: 2026-01
"""

import base64
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import jwt
import requests

from .base import AuthManagerBase

logger = logging.getLogger(__name__)

# Try to import cryptography for token encryption.
# CRYPTO_AVAILABLE records the outcome: when False, AuthManagerCloud keeps
# cipher_suite as None and persists tokens in plaintext instead of
# Fernet-encrypting them. This must run after `logger` is defined because the
# fallback branch logs a warning at import time.
try:
    from cryptography.fernet import Fernet
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

    CRYPTO_AVAILABLE = True
except ImportError:
    CRYPTO_AVAILABLE = False
    logger.warning("[AUTH_CLOUD] cryptography not available, tokens will be stored unencrypted")

# Backend configuration
# Base URL that every auth request in this module is built on. It is read once
# at import time, so changing the BACKEND_URL environment variable afterwards
# has no effect on this process.
BACKEND_URL = os.getenv("BACKEND_URL", "http://127.0.0.1:8000")


class AuthManagerCloud(AuthManagerBase):
    """
    JWT/OAuth authentication manager for cloud deployments.

    Handles:
    - JWT token storage and client-side expiry checks
    - OAuth flows (the provider list is fetched from the backend)
    - Token refresh
    - Encrypted token persistence for dcc.Store

    Instance state:
    - ``access_token`` / ``refresh_token``: the current plaintext tokens.
    - ``user_info``: cached user payload for the current access token; it is
      reset whenever a different access token is saved or restored.
    - ``_encrypted_tokens``: the persistable form of the token pair
      (Fernet-encrypted + base64 when a cipher is configured, plaintext
      otherwise), or None when there is nothing to persist.
    """

    # Lifetime of an OAuth CSRF state token, in seconds.
    _OAUTH_STATE_TTL_SECONDS = 600

    # Fallback secrets used when AUTH_CRYPTO_SALT / AUTH_CRYPTO_KEY are unset.
    # They live in source control, so tokens encrypted with them are not
    # meaningfully protected.
    _DEV_CRYPTO_SALT = "xfarma_salt_dev_2024"
    _DEV_CRYPTO_KEY = "xfarma_secret_key"

    def __init__(self):
        """Initialize the cloud auth manager with no active session."""
        self.access_token: Optional[str] = None
        self.refresh_token: Optional[str] = None
        self.user_info: Optional[Dict[str, Any]] = None
        # Persistable token pair; None until tokens are saved or restored.
        self._encrypted_tokens: Optional[Dict[str, Optional[str]]] = None
        # Outstanding OAuth CSRF states: state -> {"timestamp": float, "used": bool}.
        self._oauth_states: Dict[str, Dict[str, Any]] = {}
        self._init_secure_storage()
        logger.info("[AUTH_CLOUD] AuthManagerCloud initialized")

    def _init_secure_storage(self):
        """
        Derive the Fernet cipher used to encrypt persisted tokens.

        The Fernet key is derived with PBKDF2-HMAC-SHA256 (100,000 iterations,
        32-byte output) from ``AUTH_CRYPTO_KEY`` and ``AUTH_CRYPTO_SALT``. A
        salt prefixed with ``base64:`` is base64-decoded; otherwise its UTF-8
        bytes are used. When ``cryptography`` is unavailable, ``cipher_suite``
        is None and tokens are persisted unencrypted.
        """
        if not CRYPTO_AVAILABLE:
            self.cipher_suite = None
            return

        salt_string = os.getenv("AUTH_CRYPTO_SALT", self._DEV_CRYPTO_SALT)
        key_string = os.getenv("AUTH_CRYPTO_KEY", self._DEV_CRYPTO_KEY)
        if salt_string == self._DEV_CRYPTO_SALT or key_string == self._DEV_CRYPTO_KEY:
            logger.warning(
                "[AUTH_CLOUD] AUTH_CRYPTO_KEY/AUTH_CRYPTO_SALT not configured; "
                "using insecure development defaults for token encryption"
            )

        if salt_string.startswith("base64:"):
            salt = base64.b64decode(salt_string[len("base64:"):])
        else:
            salt = salt_string.encode("utf-8")

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(key_string.encode("utf-8")))
        self.cipher_suite = Fernet(key)

    # =========================================================================
    # Core Authentication State
    # =========================================================================

    def is_authenticated(self) -> bool:
        """
        Return True if an access token is held and its ``exp`` claim is in the future.

        The JWT signature is NOT verified here, so this only tells whether the
        token looks unexpired; it is not proof that the token is genuine. A
        token without ``exp``, or one that cannot be decoded, counts as not
        authenticated.
        """
        if not self.access_token:
            return False

        try:
            payload = jwt.decode(self.access_token, options={"verify_signature": False})
            exp = datetime.fromtimestamp(payload.get("exp", 0), tz=timezone.utc)
        except Exception:
            # Malformed token, or an ``exp`` that is not a valid timestamp.
            return False
        return exp > datetime.now(timezone.utc)

    def get_access_token(self) -> Optional[str]:
        """Get the current JWT access token."""
        return self.access_token

    # =========================================================================
    # Session Management
    # =========================================================================

    def _clear_session(self) -> None:
        """Drop all local session state, including the persistable token pair."""
        self.access_token = None
        self.refresh_token = None
        self.user_info = None
        # Without this, get_encrypted_tokens() would keep returning the old
        # session after logout, and it could be re-persisted and restored.
        self._encrypted_tokens = None

    def logout(self) -> None:
        """
        Log out: best-effort backend invalidation, then clear all local state.

        A failure of the backend call is logged and ignored so that the local
        session is always cleared.
        """
        if self.access_token:
            try:
                requests.post(
                    f"{BACKEND_URL}/api/v1/auth/logout",
                    headers={"Authorization": f"Bearer {self.access_token}"},
                    timeout=5,
                )
            except Exception as e:
                logger.debug(f"[AUTH_CLOUD] Backend logout failed (ignored): {e}")

        self._clear_session()
        logger.info("[AUTH_CLOUD] User logged out")

    def get_current_user(self) -> Optional[Dict[str, Any]]:
        """
        Return the current user's info, or None when not authenticated.

        Uses the cached ``user_info`` when present; otherwise fetches
        ``/api/v1/auth/me`` and caches a successful (HTTP 200) response.
        Returns None if the request fails or the backend answers non-200.
        """
        if not self.is_authenticated():
            return None

        if self.user_info:
            return self.user_info

        try:
            response = requests.get(
                f"{BACKEND_URL}/api/v1/auth/me",
                headers={"Authorization": f"Bearer {self.access_token}"},
                timeout=10,
            )
            if response.status_code == 200:
                self.user_info = response.json()
                return self.user_info
            logger.debug(f"[AUTH_CLOUD] /auth/me returned HTTP {response.status_code}")
        except (requests.RequestException, ValueError) as e:
            # ValueError covers an undecodable JSON body.
            logger.warning(f"[AUTH_CLOUD] Failed to fetch current user: {e}")

        return None

    # =========================================================================
    # Token Persistence
    # =========================================================================

    def _encrypt_for_storage(self, token: Optional[str]) -> Optional[str]:
        """
        Return *token* in its persisted form.

        With a cipher configured: Fernet-encrypt the UTF-8 bytes, then base64
        the result into a str. Without one: return the token unchanged.
        None is passed through as None.
        """
        if token is None or not (CRYPTO_AVAILABLE and self.cipher_suite):
            return token
        encrypted = self.cipher_suite.encrypt(token.encode("utf-8"))
        return base64.b64encode(encrypted).decode("utf-8")

    def _decrypt_from_storage(self, stored: Optional[str]) -> Optional[str]:
        """Inverse of ``_encrypt_for_storage``; raises if *stored* cannot be decoded."""
        if stored is None or not (CRYPTO_AVAILABLE and self.cipher_suite):
            return stored
        encrypted = base64.b64decode(stored)
        return self.cipher_suite.decrypt(encrypted).decode("utf-8")

    def save_tokens(self, access_token: Optional[str], refresh_token: Optional[str]) -> None:
        """
        Store the token pair in memory and prepare its persistable form.

        Args:
            access_token: JWT access token (None if the backend omitted it).
            refresh_token: Refresh token (None if the backend omitted it).

        The persistable form is exposed via ``get_encrypted_tokens()``. If
        encryption fails, nothing is made persistable (None). The previous
        code fell back to plaintext, which ``restore_from_encrypted_tokens``
        could never decrypt, so the fallback only leaked tokens.
        """
        if access_token != self.access_token:
            # Cached user info belongs to the previous token's owner.
            self.user_info = None
        self.access_token = access_token
        self.refresh_token = refresh_token

        try:
            self._encrypted_tokens = {
                "access": self._encrypt_for_storage(access_token),
                "refresh": self._encrypt_for_storage(refresh_token),
            }
        except Exception as e:
            logger.warning(f"[AUTH_CLOUD] Token encryption failed: {e}")
            self._encrypted_tokens = None

    def get_encrypted_tokens(self) -> Optional[Dict[str, str]]:
        """Return the persistable token dict for dcc.Store, or None if there is nothing to persist."""
        return self._encrypted_tokens

    def restore_from_encrypted_tokens(
        self, encrypted_tokens: Dict[str, str], **kwargs
    ) -> bool:
        """
        Restore the session from a dict produced by ``get_encrypted_tokens()``.

        Both tokens are decoded before any state is modified, so a failure
        leaves the current session untouched.

        Args:
            encrypted_tokens: Dict with "access" and "refresh" entries in
                persisted form (values may be None).
            **kwargs: Accepted and ignored here (presumably for compatibility
                with the AuthManagerBase interface — verify).

        Returns:
            True on success; False if the dict is empty or cannot be decoded.
        """
        if not encrypted_tokens:
            return False

        try:
            access_token = self._decrypt_from_storage(encrypted_tokens["access"])
            refresh_token = self._decrypt_from_storage(encrypted_tokens["refresh"])
        except Exception as e:
            logger.warning(f"[AUTH_CLOUD] Token restoration failed: {e}")
            return False

        if access_token != self.access_token:
            # Do not serve another session's cached user info.
            self.user_info = None
        self.access_token = access_token
        self.refresh_token = refresh_token
        self._encrypted_tokens = encrypted_tokens
        return True

    def get_refresh_token(self) -> Optional[str]:
        """Get the current refresh token."""
        return self.refresh_token

    # =========================================================================
    # Token Refresh
    # =========================================================================

    def refresh_access_token(self) -> bool:
        """
        Obtain a new access token from ``/api/v1/auth/refresh``.

        Returns True only when the backend answers HTTP 200 with a non-empty
        ``access_token``. If the response also carries a ``refresh_token``, it
        replaces the stored one; this is harmless if the backend does not
        rotate refresh tokens. The persistable form is regenerated only when
        tokens had already been persisted.
        """
        if not self.refresh_token:
            return False

        try:
            response = requests.post(
                f"{BACKEND_URL}/api/v1/auth/refresh",
                json={"refresh_token": self.refresh_token},
                timeout=10,
            )
            if response.status_code != 200:
                logger.debug(f"[AUTH_CLOUD] Token refresh returned HTTP {response.status_code}")
                return False
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            logger.warning(f"[AUTH_CLOUD] Token refresh failed: {e}")
            return False

        new_access = data.get("access_token") if isinstance(data, dict) else None
        if not new_access:
            logger.warning("[AUTH_CLOUD] Token refresh response contained no access_token")
            return False

        self.access_token = new_access
        self.refresh_token = data.get("refresh_token") or self.refresh_token

        if self._encrypted_tokens is not None:
            # Build a new dict instead of mutating the one that may have been
            # handed to restore_from_encrypted_tokens() by the caller.
            try:
                self._encrypted_tokens = {
                    "access": self._encrypt_for_storage(self.access_token),
                    "refresh": self._encrypt_for_storage(self.refresh_token),
                }
            except Exception as e:
                logger.warning(f"[AUTH_CLOUD] Token encryption failed: {e}")
                self._encrypted_tokens = None

        return True

    # =========================================================================
    # Auth Mode Detection
    # =========================================================================

    @property
    def auth_mode(self) -> str:
        """Return 'cloud' to identify JWT/OAuth auth."""
        return "cloud"

    # =========================================================================
    # Cloud-Specific Methods
    # =========================================================================

    @staticmethod
    def _extract_error_detail(response, default: str) -> Any:
        """
        Build an error message from a non-success backend response.

        Reads ``detail`` from the JSON body, falling back to *default* when the
        body is not JSON, is not an object, or has no ``detail``. A list
        ``detail`` whose items carry ``loc`` and ``msg`` keys is flattened to
        ``"a -> b: msg; ..."``; non-dict items are stringified.
        """
        try:
            body = response.json()
        except ValueError:
            return default
        if not isinstance(body, dict):
            return default

        detail = body.get("detail", default)
        if isinstance(detail, list):
            messages = []
            for error in detail:
                if not isinstance(error, dict):
                    messages.append(str(error))
                    continue
                field = " -> ".join(str(loc) for loc in error.get("loc", []))
                msg = error.get("msg", "Error de validación")
                messages.append(f"{field}: {msg}")
            detail = "; ".join(messages)
        return detail

    def login(self, email: str, password: str) -> Dict[str, Any]:
        """
        Log in with email and password.

        Returns:
            ``{"success": True, "data": <response JSON>}`` on HTTP 200, after
            saving the tokens and caching ``data["user"]``; otherwise
            ``{"success": False, "error": <message>}``.
        """
        try:
            response = requests.post(
                f"{BACKEND_URL}/api/v1/auth/login",
                json={"email": email, "password": password},
                timeout=10,
            )
        except requests.RequestException as e:
            return {"success": False, "error": f"Connection error: {str(e)}"}

        if response.status_code != 200:
            return {"success": False, "error": self._extract_error_detail(response, "Login failed")}

        try:
            data = response.json()
        except ValueError:
            return {"success": False, "error": "Login failed"}

        self.save_tokens(data.get("access_token"), data.get("refresh_token"))
        self.user_info = data.get("user")
        return {"success": True, "data": data}

    def register(
        self,
        email: str,
        username: str,
        password: str,
        full_name: Optional[str] = None,
        phone: Optional[str] = None,
        pharmacy_name: Optional[str] = None,
        pharmacy_code: Optional[str] = None,
        pharmacy_email: Optional[str] = None,
        pharmacy_address: Optional[str] = None,
        pharmacy_city: Optional[str] = None,
        pharmacy_postal_code: Optional[str] = None,
        pharmacy_phone: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Register a new user together with their pharmacy.

        All arguments are sent as-is (None included) to ``/api/v1/auth/register``.

        Returns:
            ``{"success": True, "data": <response JSON>}`` on HTTP 200,
            otherwise ``{"success": False, "error": <message>}``. No tokens
            are stored by this call.
        """
        payload = {
            "email": email,
            "username": username,
            "password": password,
            "full_name": full_name,
            "phone": phone,
            "pharmacy_name": pharmacy_name,
            "pharmacy_code": pharmacy_code,
            "pharmacy_email": pharmacy_email,
            "pharmacy_address": pharmacy_address,
            "pharmacy_city": pharmacy_city,
            "pharmacy_postal_code": pharmacy_postal_code,
            "pharmacy_phone": pharmacy_phone,
        }

        try:
            response = requests.post(
                f"{BACKEND_URL}/api/v1/auth/register",
                json=payload,
                timeout=10,
            )
        except requests.RequestException as e:
            return {"success": False, "error": f"Connection error: {str(e)}"}

        if response.status_code != 200:
            return {
                "success": False,
                "error": self._extract_error_detail(response, "Registration failed"),
            }

        try:
            data = response.json()
        except ValueError:
            return {"success": False, "error": "Registration failed"}
        return {"success": True, "data": data}

    def forgot_password(self, email: str) -> Dict[str, Any]:
        """
        Ask the backend to send password-reset instructions to *email*.

        Returns:
            ``{"success": True, "message": <backend message or None>}`` on
            HTTP 200, otherwise ``{"success": False, "error": <message>}``.
        """
        try:
            response = requests.post(
                f"{BACKEND_URL}/api/v1/auth/forgot-password",
                json={"email": email},
                timeout=10,
            )
        except requests.RequestException as e:
            return {"success": False, "error": f"Connection error: {str(e)}"}

        if response.status_code != 200:
            return {"success": False, "error": "Failed to send reset instructions"}

        try:
            message = response.json().get("message")
        except (ValueError, AttributeError):
            # The request was accepted; only the informational message is missing.
            message = None
        return {"success": True, "message": message}

    def get_oauth_providers(self) -> Dict[str, Any]:
        """Return the backend's OAuth provider listing, or ``{}`` on any failure."""
        try:
            response = requests.get(
                f"{BACKEND_URL}/api/v1/auth/oauth/providers",
                timeout=10,
            )
            if response.status_code == 200:
                return response.json()
            logger.debug(f"[AUTH_CLOUD] OAuth providers returned HTTP {response.status_code}")
        except (requests.RequestException, ValueError) as e:
            logger.warning(f"[AUTH_CLOUD] Failed to fetch OAuth providers: {e}")
        return {}

    def generate_oauth_state(self) -> str:
        """
        Create a single-use OAuth ``state`` token for CSRF protection.

        The token is 64 hex characters drawn from ``secrets``. It stays valid
        for ``_OAUTH_STATE_TTL_SECONDS``; expired entries are pruned every time
        a new state is generated.
        """
        import secrets
        import time

        now = time.time()
        ttl = self._OAUTH_STATE_TTL_SECONDS
        self._oauth_states = {
            key: entry
            for key, entry in self._oauth_states.items()
            if now - entry["timestamp"] < ttl
        }

        state = secrets.token_hex(32)
        self._oauth_states[state] = {"timestamp": now, "used": False}
        return state

    def validate_oauth_state(self, state: str) -> bool:
        """
        Consume an OAuth state token.

        Returns True exactly once per state generated by
        ``generate_oauth_state`` within its TTL. Unknown, already-used or
        expired states return False; expired ones are also removed.
        """
        import time

        entry = self._oauth_states.get(state)
        if entry is None or entry["used"]:
            return False

        if time.time() - entry["timestamp"] > self._OAUTH_STATE_TTL_SECONDS:
            del self._oauth_states[state]
            return False

        entry["used"] = True
        return True

    def handle_oauth_callback(
        self, state_key: str, provider: str, state: Optional[str] = None
    ) -> bool:
        """
        Finish an OAuth login by exchanging *state_key* for tokens.

        Args:
            state_key: Key passed to ``/api/v1/auth/oauth/exchange``.
            provider: Not used by this implementation.
            state: CSRF state from ``generate_oauth_state``; validated when given.

        Returns:
            True if the exchange succeeded and tokens were saved, else False.
        """
        if state:
            if not self.validate_oauth_state(state):
                logger.warning("[AUTH_CLOUD] OAuth CSRF validation failed")
                return False
        else:
            # NOTE(review): with no state the CSRF check is skipped entirely;
            # confirm every caller that can supply one actually does.
            logger.debug("[AUTH_CLOUD] OAuth callback without state; CSRF check skipped")

        try:
            response = requests.post(
                f"{BACKEND_URL}/api/v1/auth/oauth/exchange",
                params={"state_key": state_key},
                timeout=10,
            )

            if response.status_code == 200:
                data = response.json()
                self.save_tokens(data.get("access_token"), data.get("refresh_token"))
                self.user_info = data.get("user")
                return True
            logger.warning(f"[AUTH_CLOUD] OAuth exchange returned HTTP {response.status_code}")
        except Exception as e:
            logger.error(f"[AUTH_CLOUD] OAuth callback error: {e}")

        return False


# Utility function for backwards compatibility
def create_authenticated_client(auth_manager: AuthManagerCloud):
    """
    Build a ``requests.Session`` that carries the manager's bearer token.

    The Authorization header is added only when the manager reports an
    authenticated session at call time. The header is a snapshot of the
    token at that moment; it is not updated if the token is refreshed later.
    Kept for backwards compatibility.
    """
    client = requests.Session()
    if not auth_manager.is_authenticated():
        return client

    bearer = f"Bearer {auth_manager.get_access_token()}"
    client.headers.update({"Authorization": bearer})
    return client


def require_auth(func):
    """
    Legacy decorator kept for backwards compatibility; it does NOT enforce authentication.

    The returned wrapper forwards all positional and keyword arguments to
    *func* and returns its result unchanged. ``functools.wraps`` copies
    *func*'s ``__name__``, ``__doc__`` and related metadata onto the wrapper
    (and sets ``__wrapped__``), so introspection still sees the original
    callback.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper
