﻿# backend/app/services/audit_service.py
"""
Audit service for logging critical operations.
"""

import logging
import weakref
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from uuid import UUID

from fastapi import Request
from sqlalchemy import and_, desc, or_
from sqlalchemy.orm import Session

from app.models.audit_log import AuditAction, AuditLog, AuditSeverity
from app.models.user import User
from app.utils.datetime_utils import utc_now

logger = logging.getLogger(__name__)


class AuditService:
    """Service for recording, querying and pruning audit logs.

    All work goes through the SQLAlchemy session given to the constructor;
    the service keeps no other state.
    """

    def __init__(self, db: Session):
        """Bind the service to *db*, the session used for every read and write."""
        self.db = db

    def log_action(
        self,
        action: AuditAction,
        method: str,
        endpoint: str,
        user: Optional[User] = None,
        request: Optional[Request] = None,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        description: Optional[str] = None,
        details: Optional[Dict[str, Any]] = None,
        success: bool = True,
        status_code: Optional[int] = None,
        error_message: Optional[str] = None,
        duration_ms: Optional[float] = None,
    ) -> Optional[AuditLog]:
        """
        Log an audit entry for a critical operation.

        This method is best-effort: any exception raised while building or
        persisting the entry is logged and swallowed so that audit logging
        never breaks the calling operation.

        Note that the entry shares the caller's session. Entries that are
        CRITICAL or describe a failed operation are committed immediately
        (which commits whatever else is pending on the session); other
        entries are only flushed and rely on the caller's later commit.
        If persisting fails, the session is rolled back.

        Args:
            action: Type of action performed
            method: HTTP method
            endpoint: API endpoint called
            user: User performing the action (None is logged as "system")
            request: FastAPI request object for extracting metadata
            resource_type: Type of resource affected
            resource_id: ID of the affected resource
            description: Human-readable description
            details: Additional structured data
            success: Whether the operation succeeded
            status_code: HTTP status code
            error_message: Error message if failed
            duration_ms: Request duration in milliseconds

        Returns:
            The created AuditLog entry, or None if the entry could not be
            created.
        """
        try:
            # Severity depends on both the action type and the outcome.
            severity = self._determine_severity(action, success)

            # Read the user attributes once, up front, so the application
            # log line below does not have to touch the ORM object again
            # after the commit/flush.
            user_id = user.id if user is not None else None
            user_email = user.email if user is not None else None
            user_role = user.role if user is not None else None

            # Extract request metadata; request.client may be None.
            ip_address = None
            user_agent = None
            if request is not None:
                if request.client:
                    ip_address = request.client.host
                user_agent = request.headers.get("User-Agent")

            audit_log = AuditLog(
                user_id=user_id,
                user_email=user_email,
                user_role=user_role,
                action=action,
                severity=severity,
                resource_type=resource_type,
                resource_id=resource_id,
                method=method,
                endpoint=endpoint,
                ip_address=ip_address,
                user_agent=user_agent,
                description=description,
                details=details,
                success="success" if success else "failure",
                # Compare against None explicitly: a duration of 0 ms (or a
                # status of 0) is a real value and must not be stored as NULL.
                status_code=str(status_code) if status_code is not None else None,
                error_message=error_message,
                duration_ms=str(duration_ms) if duration_ms is not None else None,
            )

            self.db.add(audit_log)

            # Critical or failed operations are committed immediately so the
            # record is persisted right away; everything else is flushed and
            # batched into the caller's commit.
            if severity == AuditSeverity.CRITICAL or not success:
                self.db.commit()
            else:
                self.db.flush()

            # Mirror the entry in the application log.
            actor = user_email if user is not None else "system"
            if success:
                logger.info("AUDIT: %s by %s on %s", action.value, actor, endpoint)
            else:
                logger.warning(
                    "AUDIT: %s by %s on %s - Failed: %s",
                    action.value,
                    actor,
                    endpoint,
                    error_message,
                )

            return audit_log

        except Exception as e:
            logger.exception("Failed to create audit log: %s", e)
            # The rollback itself can fail (e.g. broken connection); guard it
            # so this method keeps its promise of never raising.
            try:
                self.db.rollback()
            except Exception:
                logger.exception("Rollback after failed audit log write also failed")
            # Don't raise - audit logging should not break the application
            return None

    def _determine_severity(self, action: AuditAction, success: bool) -> AuditSeverity:
        """
        Determine severity level based on action type and success status.

        Mapping:
            - failed DELETE / ADMIN_ACTION / SYSTEM_CHANGE -> CRITICAL
            - any other failure                            -> WARNING
            - successful DELETE / ADMIN_ACTION / SYSTEM_CHANGE -> WARNING
            - everything else                              -> INFO
        """
        is_sensitive = action in (
            AuditAction.DELETE,
            AuditAction.ADMIN_ACTION,
            AuditAction.SYSTEM_CHANGE,
        )

        if not success:
            # Failed sensitive actions are critical; other failures are warnings.
            return AuditSeverity.CRITICAL if is_sensitive else AuditSeverity.WARNING

        # Successful sensitive actions are notable but not problematic.
        return AuditSeverity.WARNING if is_sensitive else AuditSeverity.INFO

    def _recent_logs(self, hours: int, limit: int, criterion: Any) -> List[AuditLog]:
        """Return up to *limit* logs from the last *hours* hours that match
        *criterion* (a SQLAlchemy filter expression), newest first."""
        since = utc_now() - timedelta(hours=hours)
        return (
            self.db.query(AuditLog)
            .filter(and_(AuditLog.timestamp >= since, criterion))
            .order_by(desc(AuditLog.timestamp))
            .limit(limit)
            .all()
        )

    def get_user_activity(
        self,
        user_id: UUID,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: int = 100,
    ) -> List[AuditLog]:
        """
        Get audit logs for a specific user, newest first.

        Args:
            user_id: User ID to filter by
            start_date: Inclusive lower bound on the log timestamp
            end_date: Inclusive upper bound on the log timestamp
            limit: Maximum number of records to return

        Returns:
            List of audit logs
        """
        query = self.db.query(AuditLog).filter(AuditLog.user_id == user_id)

        if start_date is not None:
            query = query.filter(AuditLog.timestamp >= start_date)
        if end_date is not None:
            query = query.filter(AuditLog.timestamp <= end_date)

        return query.order_by(desc(AuditLog.timestamp)).limit(limit).all()

    def get_critical_actions(
        self,
        hours: int = 24,
        limit: int = 100,
    ) -> List[AuditLog]:
        """
        Get notable audit logs from the last N hours, newest first.

        Despite the name, this returns entries of severity CRITICAL *and*
        WARNING.

        Args:
            hours: Number of hours to look back
            limit: Maximum number of records to return

        Returns:
            List of CRITICAL and WARNING audit logs
        """
        return self._recent_logs(
            hours,
            limit,
            or_(
                AuditLog.severity == AuditSeverity.CRITICAL,
                AuditLog.severity == AuditSeverity.WARNING,
            ),
        )

    def get_failed_operations(
        self,
        hours: int = 24,
        limit: int = 100,
    ) -> List[AuditLog]:
        """
        Get failed operations from the last N hours, newest first.

        Args:
            hours: Number of hours to look back
            limit: Maximum number of records to return

        Returns:
            List of failed operation logs
        """
        # The success column stores the strings "success" / "failure"
        # (see log_action), not a boolean.
        return self._recent_logs(hours, limit, AuditLog.success == "failure")

    def get_admin_actions(
        self,
        hours: int = 168,  # 1 week
        limit: int = 100,
    ) -> List[AuditLog]:
        """
        Get admin actions from the last N hours, newest first.

        An entry counts as an admin action when its action is ADMIN_ACTION
        or its endpoint starts with "/api/v1/admin/".

        Args:
            hours: Number of hours to look back
            limit: Maximum number of records to return

        Returns:
            List of admin action logs
        """
        return self._recent_logs(
            hours,
            limit,
            or_(
                AuditLog.action == AuditAction.ADMIN_ACTION,
                AuditLog.endpoint.like("/api/v1/admin/%"),
            ),
        )

    def cleanup_old_logs(
        self,
        days: int = 90,
        warning_days: int = 730,
        critical_days: int = 1095,
    ) -> int:
        """
        Remove audit logs older than their severity's retention period.
        Keep critical logs for longer to comply with pharmaceutical regulations.

        Spanish pharmaceutical regulations require:
        - 5 years for medicine dispensing records
        - 3 years for system access and critical operations
        - 2 years for general transaction records

        NOTE(review): retention here is decided by severity only. Nothing in
        this method keeps an entry for 5 years - confirm that dispensing
        records are retained elsewhere or are never pruned by this job.

        Args:
            days: Number of days to keep INFO logs (default: 90)
            warning_days: Number of days to keep WARNING logs
                (default: 730, i.e. 2 years)
            critical_days: Number of days to keep CRITICAL logs
                (default: 1095, i.e. 3 years)

        Returns:
            Number of records deleted; 0 if a retention period is negative
            or the cleanup failed (the failure is logged and rolled back).
        """
        retention_by_severity = (
            (AuditSeverity.INFO, days),
            (AuditSeverity.WARNING, warning_days),
            (AuditSeverity.CRITICAL, critical_days),
        )

        # A negative period would put the cutoff in the future and delete
        # every log of that severity; refuse rather than wipe audit history.
        if any(period < 0 for _, period in retention_by_severity):
            logger.error(
                "Refusing audit log cleanup: retention periods must be "
                "non-negative (days=%s, warning_days=%s, critical_days=%s)",
                days,
                warning_days,
                critical_days,
            )
            return 0

        try:
            # Take "now" once so all cutoffs are measured from the same instant.
            now = utc_now()
            total_deleted = 0

            for severity, period in retention_by_severity:
                cutoff = now - timedelta(days=period)
                total_deleted += (
                    self.db.query(AuditLog)
                    .filter(
                        and_(
                            AuditLog.timestamp < cutoff,
                            AuditLog.severity == severity,
                        )
                    )
                    # Bulk delete; objects already loaded in the session are
                    # not synchronized with the deletion.
                    .delete(synchronize_session=False)
                )

            # Single commit so the three deletions succeed or fail together.
            self.db.commit()

            logger.info("Cleaned up %s old audit logs", total_deleted)
            return total_deleted

        except Exception as e:
            logger.exception("Failed to cleanup audit logs: %s", e)
            self.db.rollback()
            return 0


# Per-session service cache, keyed by id(session).
# Values are held weakly: a plain dict here would keep every AuditService -
# and, through its ``db`` attribute, every Session ever passed in - alive for
# the lifetime of the process. With weak values an entry disappears as soon
# as no caller references the service any more. Keying by id() stays safe
# because a live service holds its session strongly, so a session's id cannot
# be reused while its entry still exists.
_audit_service_instances = weakref.WeakValueDictionary()


def get_audit_service(db: Session) -> AuditService:
    """
    Get or create an AuditService instance for the given database session.

    While a previously returned service for *db* is still referenced
    somewhere, the same instance is returned; otherwise a new one is created.

    Args:
        db: Database session the service should operate on

    Returns:
        AuditService bound to *db*
    """
    session_id = id(db)
    service = _audit_service_instances.get(session_id)
    if service is None:
        # Keep a local strong reference: storing the new instance only in the
        # weak-valued cache would let it be collected before we return it.
        service = AuditService(db)
        _audit_service_instances[session_id] = service
    return service
