# backend/app/services/subscription_expiration_service.py
"""
Subscription expiration service for managing PRO/MAX plan expirations.

Issue #444: Implementar caducidad automatica de planes PRO y MAX.

This service handles:
- Checking for expired subscriptions
- Downgrading users to FREE when their subscription expires
- Providing subscription status information (days remaining, etc.)
- Syncing User and Pharmacy subscription dates
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional, Tuple
from uuid import UUID

from sqlalchemy import and_
from sqlalchemy.orm import Session

from app.core.subscription_limits import SubscriptionPlan
from app.models.audit_log import AuditAction
from app.models.pharmacy import Pharmacy
from app.models.user import User
from app.services.audit_service import get_audit_service
from app.utils.datetime_utils import utc_now

logger = logging.getLogger(__name__)


@dataclass
class ExpirationResult:
    """
    Result of a subscription expiration check.

    All fields default to "nothing processed yet", so ``ExpirationResult()``
    is a valid starting accumulator. Positional construction is unchanged.
    """
    # Number of expired PRO/MAX users found by the check.
    total_checked: int = 0
    # Number of those users successfully downgraded to FREE.
    expired_count: int = 0
    # List of user IDs (GDPR-compliant, not emails).
    # default_factory gives each instance its own list instead of a shared one.
    downgraded_user_ids: List[str] = field(default_factory=list)
    # Human-readable error messages collected during the run.
    errors: List[str] = field(default_factory=list)


@dataclass
class SubscriptionStatus:
    """
    Snapshot of a single user's subscription state.

    Produced by ``SubscriptionExpirationService.get_subscription_status``.
    Fields are declared in constructor order.
    """

    # Plan identifier as stored on the user (e.g. 'free', 'pro', 'max').
    plan: str
    # Mirrors the user's ``is_active`` flag.
    is_active: bool
    # Start of the current subscription period, if any.
    subscription_start: Optional[datetime]
    # End of the current subscription period; None when there is no expiration.
    subscription_end: Optional[datetime]
    # Whole days left (0 once expired); None when there is no end date.
    days_remaining: Optional[int]
    # True if <= 7 days remaining
    is_expiring_soon: bool
    # True once the end date has been reached.
    is_expired: bool


class SubscriptionExpirationService:
    """
    Service for managing PRO/MAX subscription expiration.

    Provides:
    - per-user subscription status (days remaining, expiring soon, expired),
    - a batch job that downgrades expired PRO/MAX users (and their pharmacy) to FREE,
    - lookups of users whose subscription expires soon, for notifications,
    - admin plan/date updates that keep the linked pharmacy in sync,
    - aggregate statistics for the admin dashboard.

    Plan changes are recorded through the audit service. Methods that write
    commit (or roll back) the session passed to the constructor.
    """

    # Notification thresholds, in days before expiration. Intended to be passed
    # one at a time to get_users_by_expiration_threshold() by the notification job
    # (not referenced inside this class).
    WARNING_THRESHOLDS = [7, 3, 1]

    # Window, in days, within which an active paid subscription counts as
    # "expiring soon" (status flag, get_expiring_soon() default, dashboard stats).
    EXPIRING_SOON_DAYS = 7

    # Length of the default period granted by update_subscription() when a
    # PRO/MAX plan needs an end date and none is supplied.
    DEFAULT_PERIOD_DAYS = 365

    def __init__(self, db: Session):
        self.db = db
        self.audit_service = get_audit_service(db)

    @staticmethod
    def _paid_subscription_criteria() -> list:
        """Return base filter criteria: non-deleted PRO/MAX users that have an end date."""
        return [
            User.subscription_plan.in_([SubscriptionPlan.PRO, SubscriptionPlan.MAX]),
            User.subscription_end.isnot(None),
            User.deleted_at.is_(None),  # Exclude soft-deleted users
        ]

    def get_subscription_status(self, user: User) -> SubscriptionStatus:
        """
        Get the current subscription status for a user.

        Args:
            user: User model instance

        Returns:
            SubscriptionStatus with current plan details. For FREE users,
            ``subscription_end``/``days_remaining`` are None and the user is
            never expired or expiring.
        """
        now = utc_now()

        # FREE plan never expires. Dates kept for history after a downgrade
        # are not reported as an active expiration.
        if user.subscription_plan == SubscriptionPlan.FREE:
            return SubscriptionStatus(
                plan=user.subscription_plan,
                is_active=user.is_active,
                subscription_start=user.subscription_start,
                subscription_end=None,
                days_remaining=None,
                is_expiring_soon=False,
                is_expired=False,
            )

        is_expired = False
        days_remaining = None

        if user.subscription_end:
            if user.subscription_end <= now:
                is_expired = True
                days_remaining = 0
            else:
                # timedelta.days floors: less than 24h left reports 0 days.
                days_remaining = (user.subscription_end - now).days

        # Active and within the window. Do not require days_remaining > 0:
        # because of the floor above, that would hide subscriptions with less
        # than one full day left, which are the most urgent ones.
        is_expiring_soon = (
            not is_expired
            and days_remaining is not None
            and days_remaining <= self.EXPIRING_SOON_DAYS
        )

        return SubscriptionStatus(
            plan=user.subscription_plan,
            is_active=user.is_active,
            subscription_start=user.subscription_start,
            subscription_end=user.subscription_end,
            days_remaining=days_remaining,
            is_expiring_soon=is_expiring_soon,
            is_expired=is_expired,
        )

    def check_and_expire_subscriptions(self) -> ExpirationResult:
        """
        Find PRO/MAX subscriptions whose end date has passed and downgrade them to FREE.

        This method should be called by a daily job. Per-user failures are
        collected in ``errors`` and do not stop the batch. All successful
        downgrades are committed together at the end. If that commit fails,
        the session is rolled back and no user is reported as downgraded.

        Returns:
            ExpirationResult with counts, downgraded user IDs and error messages
        """
        now = utc_now()
        result = ExpirationResult(
            total_checked=0,
            expired_count=0,
            downgraded_user_ids=[],
            errors=[],
        )

        try:
            expired_users = (
                self.db.query(User)
                .filter(and_(*self._paid_subscription_criteria(), User.subscription_end <= now))
                .all()
            )
        except Exception as e:
            error_msg = f"Error in subscription expiration check: {e}"
            result.errors.append(error_msg)
            logger.exception(error_msg)
            self.db.rollback()
            return result

        result.total_checked = len(expired_users)

        for user in expired_users:
            # Capture the id up front. After a failed flush, touching attributes
            # of `user` inside the except handler could raise again and abort
            # the whole batch. IDs (not emails) keep logs GDPR-friendly.
            user_id = str(user.id)
            try:
                self._downgrade_to_free(user, user.subscription_plan)
            except Exception as e:
                error_msg = f"Failed to downgrade user {user_id}: {e}"
                result.errors.append(error_msg)
                logger.error(error_msg)
                continue
            result.expired_count += 1
            result.downgraded_user_ids.append(user_id)

        try:
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            error_msg = f"Error in subscription expiration check: {e}"
            result.errors.append(error_msg)
            logger.exception(error_msg)
            # The rollback discarded every downgrade in this batch. Do not
            # report users as downgraded when nothing was persisted.
            result.expired_count = 0
            result.downgraded_user_ids = []
            return result

        logger.info(
            "Subscription expiration check completed: %d/%d expired",
            result.expired_count,
            result.total_checked,
        )
        return result

    def _downgrade_to_free(self, user: User, old_plan: str) -> None:
        """
        Downgrade a user (and their pharmacy, if any) to the FREE plan.

        Does not commit. The caller owns the transaction.

        Args:
            user: User to downgrade
            old_plan: Previous plan name for audit log
        """
        # Do everything that can fail (relationship load, audit insert) before
        # mutating the user. If it raises, the user is untouched, so a failure
        # counted by the caller is never committed as a silent, unaudited downgrade.
        pharmacy = user.pharmacy

        self.audit_service.log_action(
            action=AuditAction.SYSTEM_CHANGE,
            method="SYSTEM",
            endpoint="/system/subscription-expiration",
            user=None,  # System action
            resource_type="user",
            resource_id=str(user.id),
            description=f"Subscription expired: {old_plan} -> free for {user.email}",
            details={
                "user_id": str(user.id),
                "user_email": user.email,
                "old_plan": old_plan,
                "new_plan": SubscriptionPlan.FREE,
                "subscription_end": user.subscription_end.isoformat() if user.subscription_end else None,
            },
            success=True,
        )

        user.subscription_plan = SubscriptionPlan.FREE
        # Keep subscription_start and subscription_end for historical reference

        # Sync pharmacy subscription
        if pharmacy:
            pharmacy.subscription_plan = SubscriptionPlan.FREE

        logger.info("User %s downgraded from %s to FREE (subscription expired)", user.id, old_plan)

    def get_expiring_soon(self, days: int = EXPIRING_SOON_DAYS) -> List[User]:
        """
        Get users whose subscription expires within N days.

        Args:
            days: Number of days to look ahead (default EXPIRING_SOON_DAYS, i.e. 7)

        Returns:
            Non-deleted PRO/MAX users with ``now < subscription_end <= now + days``
        """
        now = utc_now()
        cutoff = now + timedelta(days=days)

        return (
            self.db.query(User)
            .filter(
                and_(
                    *self._paid_subscription_criteria(),
                    User.subscription_end > now,  # Not yet expired
                    User.subscription_end <= cutoff,  # Expires within N days
                )
            )
            .all()
        )

    def get_users_by_expiration_threshold(self, threshold_days: int) -> List[User]:
        """
        Get users whose subscription expires on the calendar day N days from today.

        Used for sending notifications at specific thresholds (see
        WARNING_THRESHOLDS). Day boundaries use the timezone of ``utc_now()``.

        Args:
            threshold_days: Number of calendar days until expiration

        Returns:
            Non-deleted PRO/MAX users whose ``subscription_end`` falls in
            ``[today + N days, today + N + 1 days)``
        """
        now = utc_now()
        start_of_day = datetime(now.year, now.month, now.day, tzinfo=now.tzinfo)
        target_day_start = start_of_day + timedelta(days=threshold_days)
        target_day_end = target_day_start + timedelta(days=1)

        return (
            self.db.query(User)
            .filter(
                and_(
                    *self._paid_subscription_criteria(),
                    User.subscription_end >= target_day_start,
                    User.subscription_end < target_day_end,
                )
            )
            .all()
        )

    def update_subscription(
        self,
        user: User,
        plan: str,
        subscription_end: Optional[datetime] = None,
        subscription_start: Optional[datetime] = None,
        admin_user: Optional[User] = None,
    ) -> Tuple[bool, str]:
        """
        Update a user's subscription plan and dates, syncing the pharmacy.

        For PRO/MAX, a new period (start = now, end = now + DEFAULT_PERIOD_DAYS)
        is used for any date not given explicitly when the user is upgrading
        from FREE or the previous period has already ended. Otherwise existing
        dates are kept, and only missing ones are filled in. FREE clears both dates.

        Args:
            user: User to update
            plan: New subscription plan ('free', 'pro', 'max')
            subscription_end: New expiration date (None for FREE plans)
            subscription_start: New start date (defaults to now if upgrading)
            admin_user: Admin user making the change (for audit)

        Returns:
            Tuple of (success, message). On failure the session is rolled back.
        """
        try:
            old_plan = user.subscription_plan
            old_end = user.subscription_end
            now = utc_now()

            # Validate plan
            if plan not in [SubscriptionPlan.FREE, SubscriptionPlan.PRO, SubscriptionPlan.MAX]:
                return False, f"Invalid plan: {plan}"

            user.subscription_plan = plan

            if plan == SubscriptionPlan.FREE:
                # FREE plans don't have expiration - clear dates for consistency
                user.subscription_start = None
                user.subscription_end = None
            else:
                # _downgrade_to_free keeps the old dates for history. Reusing them
                # would make the user PRO/MAX with an already-past end date, and the
                # next expiration run would immediately downgrade them again.
                previous_period_over = old_plan == SubscriptionPlan.FREE or (
                    old_end is not None and old_end <= now
                )

                if subscription_start:
                    user.subscription_start = subscription_start
                elif previous_period_over or not user.subscription_start:
                    user.subscription_start = now

                if subscription_end:
                    user.subscription_end = subscription_end
                elif previous_period_over or not user.subscription_end:
                    # Default period if no (still valid) expiration specified
                    user.subscription_end = now + timedelta(days=self.DEFAULT_PERIOD_DAYS)

            # Sync with pharmacy
            if user.pharmacy:
                user.pharmacy.subscription_plan = plan
                user.pharmacy.subscription_start = user.subscription_start
                user.pharmacy.subscription_end = user.subscription_end

            self.audit_service.log_action(
                action=AuditAction.ADMIN_ACTION,
                method="PUT",
                endpoint="/api/v1/admin/users/subscription",
                user=admin_user,
                resource_type="user",
                resource_id=str(user.id),
                description=f"Subscription updated: {old_plan} -> {plan}",
                details={
                    "user_id": str(user.id),
                    "user_email": user.email,
                    "old_plan": old_plan,
                    "new_plan": plan,
                    "old_subscription_end": old_end.isoformat() if old_end else None,
                    "new_subscription_end": user.subscription_end.isoformat() if user.subscription_end else None,
                },
                success=True,
            )

            self.db.commit()
            logger.info("Subscription updated for user %s: %s -> %s", user.id, old_plan, plan)
            return True, f"Subscription updated to {plan}"

        except Exception as e:
            self.db.rollback()
            error_msg = f"Failed to update subscription: {e}"
            logger.exception(error_msg)
            return False, error_msg

    def get_subscription_stats(self) -> dict:
        """
        Get subscription statistics for admin dashboard.

        Returns:
            Dictionary with:
            - ``by_plan``: non-deleted user count per plan,
            - ``expiring_soon_7_days``: active PRO/MAX users expiring within
              EXPIRING_SOON_DAYS,
            - ``expired_not_downgraded``: PRO/MAX users past their end date that
              the expiration job has not processed yet,
            - ``total_users``: sum of ``by_plan``.
        """
        from sqlalchemy import func

        plan_counts = (
            self.db.query(
                User.subscription_plan,
                func.count(User.id)
            )
            .filter(User.deleted_at.is_(None))
            .group_by(User.subscription_plan)
            .all()
        )

        now = utc_now()
        expiring_soon = (
            self.db.query(func.count(User.id))
            .filter(
                and_(
                    *self._paid_subscription_criteria(),
                    User.subscription_end > now,
                    User.subscription_end <= now + timedelta(days=self.EXPIRING_SOON_DAYS),
                )
            )
            .scalar()
        )

        # Already expired but not yet downgraded (edge case between job runs)
        expired = (
            self.db.query(func.count(User.id))
            .filter(
                and_(
                    *self._paid_subscription_criteria(),
                    User.subscription_end <= now,
                )
            )
            .scalar()
        )

        return {
            "by_plan": {plan: count for plan, count in plan_counts},
            # Key name kept for API compatibility (window is EXPIRING_SOON_DAYS).
            "expiring_soon_7_days": expiring_soon or 0,
            "expired_not_downgraded": expired or 0,
            "total_users": sum(count for _, count in plan_counts),
        }


def get_subscription_expiration_service(db: Session) -> SubscriptionExpirationService:
    """
    Build a fresh SubscriptionExpirationService bound to ``db``.

    Instances are intentionally not cached. The service only holds the session
    and an audit service, so constructing one per request is cheap. A cache
    keyed by session ``id()`` would keep growing across requests in
    multi-worker (Render) deployments.
    """
    service = SubscriptionExpirationService(db)
    return service
