"""add_subscription_expiration_fields

Add subscription_start and subscription_end fields to users table for
subscription expiration management (Issue #444).

BACKGROUND:
This migration implements subscription expiration tracking for PRO and MAX plans.
When a subscription expires, users will be automatically downgraded to FREE.

PROBLEM:
- Current system has subscription_plan but no expiration dates
- PRO/MAX plans need to expire and downgrade to FREE automatically
- Business requirement: All current PRO accounts expire March 31, 2026

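SOLUTION:
- Add subscription_start: when the subscription started (backfilled from
  created_at for existing PRO/MAX accounts)
- Add subscription_end: when the subscription expires (NULL = never expires)
- Add a partial index on subscription_end (non-NULL rows only) for efficient
  expiration queries
- Initialize existing PRO/MAX accounts with expiration date
  2026-03-31 23:59:59 UTC
- Copy the PRO/MAX users' subscription dates onto their linked pharmacies rows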

BEHAVIORAL CHANGES:
Before:
- subscription_plan tracks plan type only
- No expiration tracking

After:
- subscription_end tracks when the plan expires
- NULL subscription_end = never expires (FREE plan, or legacy)
- An expiration job can query expiring subscriptions efficiently through the
  partial index on subscription_end

Revision ID: 20251210_subscription_exp
Revises: 20251205_add_cima_uppercase
Create Date: 2025-12-10 12:00:00.000000

"""
from typing import Sequence, Union
import logging

from alembic import op
import sqlalchemy as sa

# Module-level logger for migration progress messages. This line only obtains
# a logger; it does not configure handlers or levels. Whether these messages
# are visible depends on the logging setup of the process running Alembic
# (e.g. the config loaded by env.py / alembic.ini).
logger = logging.getLogger(__name__)

# revision identifiers, used by Alembic.
# revision: the unique ID of this migration.
# down_revision: the migration this one is applied on top of.
# branch_labels / depends_on: unused (no named branch, no cross-branch deps).
revision: str = '20251210_subscription_exp'
down_revision: Union[str, None] = '20251205_add_cima_uppercase'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def _column_exists(conn, table: str, column: str) -> bool:
    """Return True if *column* exists on *table* in the current schema.

    Filtering on ``current_schema()`` prevents a same-named table in another
    schema from producing a false positive (which would skip the ALTER and
    make later statements fail on the missing column).
    """
    result = conn.execute(
        sa.text(
            "SELECT 1 FROM information_schema.columns "
            "WHERE table_schema = current_schema() "
            "AND table_name = :table AND column_name = :column"
        ),
        {"table": table, "column": column},
    )
    return result.first() is not None


def _index_exists(conn, index: str) -> bool:
    """Return True if an index named *index* exists in the current schema."""
    result = conn.execute(
        sa.text(
            "SELECT 1 FROM pg_indexes "
            "WHERE schemaname = current_schema() AND indexname = :index"
        ),
        {"index": index},
    )
    return result.first() is not None


def upgrade() -> None:
    """
    Add subscription_start and subscription_end columns to users table.

    Steps:
      1. Add nullable timestamptz columns users.subscription_start and
         users.subscription_end.
      2. Create the partial index idx_users_subscription_end
         (rows with a non-NULL subscription_end only).
      3. Backfill existing PRO/MAX users: subscription_start = created_at and
         subscription_end = 2026-03-31 23:59:59 UTC, only where still NULL.
      4. Copy PRO/MAX users' subscription dates onto their linked pharmacies.

    Idempotent: column/index existence is checked in the current schema
    before creating, the backfills only fill NULL values, and the pharmacy
    sync simply re-copies the same values on a re-run.
    """
    conn = op.get_bind()

    # =================================================================
    # 1. ADD COLUMNS: users.subscription_start, users.subscription_end
    # =================================================================
    for column in ('subscription_start', 'subscription_end'):
        logger.info("Checking if users.%s column exists...", column)
        if _column_exists(conn, 'users', column):
            logger.info("users.%s column already exists (skipping)", column)
            continue
        logger.info("Adding %s column to users...", column)
        op.add_column(
            'users',
            sa.Column(column, sa.DateTime(timezone=True), nullable=True),
        )
        logger.info("%s column added to users", column)

    # =================================================================
    # 2. CREATE INDEX: idx_users_subscription_end for efficient queries
    # =================================================================
    logger.info("Checking if index idx_users_subscription_end exists...")
    if _index_exists(conn, 'idx_users_subscription_end'):
        logger.info("Index idx_users_subscription_end already exists (skipping)")
    else:
        logger.info("Creating index idx_users_subscription_end...")
        # Partial index: rows with NULL subscription_end (never expires) are
        # excluded, keeping the index small.
        conn.execute(sa.text(
            "CREATE INDEX idx_users_subscription_end ON users (subscription_end) "
            "WHERE subscription_end IS NOT NULL"
        ))
        logger.info("Index idx_users_subscription_end created")

    # =================================================================
    # 3. INITIALIZE: Set subscription dates for existing PRO/MAX accounts
    # Business requirement: All PRO accounts expire March 31, 2026
    # =================================================================
    logger.info("Initializing subscription dates for existing PRO/MAX accounts...")

    # Use created_at as the best available start date for existing users.
    start_result = conn.execute(sa.text("""
        UPDATE users
        SET subscription_start = created_at
        WHERE subscription_plan IN ('pro', 'max')
        AND subscription_start IS NULL
    """))
    logger.info(
        "Set subscription_start on %s PRO/MAX accounts", start_result.rowcount
    )

    # Only fill NULLs so an expiration date set elsewhere is not overwritten.
    end_result = conn.execute(sa.text("""
        UPDATE users
        SET subscription_end = '2026-03-31 23:59:59+00'::timestamptz
        WHERE subscription_plan IN ('pro', 'max')
        AND subscription_end IS NULL
    """))
    # rowcount reflects the rows actually updated, not every PRO/MAX account.
    logger.info(
        "Initialized %s PRO/MAX accounts with expiration 2026-03-31",
        end_result.rowcount,
    )

    # =================================================================
    # 4. SYNC: Update pharmacy subscription dates to match user
    # =================================================================
    # NOTE(review): assumes pharmacies already has subscription_start and
    # subscription_end columns (they are not created by this migration).
    # If several PRO/MAX users share one pharmacy, PostgreSQL's
    # UPDATE ... FROM uses an arbitrary matching user row — confirm that
    # cannot happen or is acceptable.
    logger.info("Syncing pharmacy subscription dates with user dates...")
    sync_result = conn.execute(sa.text("""
        UPDATE pharmacies p
        SET
            subscription_start = u.subscription_start,
            subscription_end = u.subscription_end
        FROM users u
        WHERE u.pharmacy_id = p.id
        AND u.subscription_plan IN ('pro', 'max')
    """))
    logger.info("Synced subscription dates on %s pharmacies", sync_result.rowcount)

    logger.info("Migration completed successfully!")


def downgrade() -> None:
    """
    Rollback subscription expiration fields.

    Drops idx_users_subscription_end, then users.subscription_end and
    users.subscription_start, each only if it exists in the current schema.

    WARNING: This will remove expiration dates from users. The
    pharmacies.subscription_start / subscription_end values written by the
    upgrade's sync step are NOT restored; whatever those columns held before
    the upgrade is lost.
    """
    conn = op.get_bind()

    def _exists(query: str, **params) -> bool:
        # Existence checks are scoped to current_schema() so a same-named
        # object in another schema cannot cause a false positive.
        return conn.execute(sa.text(query), params).first() is not None

    # Drop index
    logger.info("Dropping index idx_users_subscription_end...")
    if _exists(
        "SELECT 1 FROM pg_indexes "
        "WHERE schemaname = current_schema() AND indexname = :index",
        index='idx_users_subscription_end',
    ):
        op.drop_index('idx_users_subscription_end', table_name='users')

    # Drop columns (reverse order of creation)
    column_query = (
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_schema = current_schema() "
        "AND table_name = :table AND column_name = :column"
    )
    for column in ('subscription_end', 'subscription_start'):
        logger.info("Dropping %s column from users...", column)
        if _exists(column_query, table='users', column=column):
            op.drop_column('users', column)

    logger.info("Rollback completed!")
