"""
Utilities for safe database migrations in xFarma

This module provides utilities to prevent common migration errors:
- Duplicate column creation
- Duplicate index creation
- Table existence checks
- Safe rollback operations

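Usage in migrations:
```python
import sqlalchemy as sa

from .migration_utils import safe_add_column, safe_create_index, safe_drop_column, safe_drop_index

def upgrade():
    safe_add_column('table_name', sa.Column('column_name', sa.String(50)))
    safe_create_index('index_name', 'table_name', ['column_name'])

def downgrade():
    safe_drop_index('index_name', 'table_name')
    safe_drop_column('table_name', 'column_name')
```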
"""

from typing import List, Optional

import sqlalchemy as sa
from alembic import op


def table_exists(table_name: str) -> bool:
    """
    Report whether ``table_name`` is present in the migration's database.

    The table list is reflected from the connection bound to ``op``. Any error
    raised while reflecting is printed and reported as "not found" (False).
    """
    try:
        known_tables = sa.inspect(op.get_bind()).get_table_names()
    except Exception as e:
        print(f"Error checking table existence: {e}")
        return False
    return table_name in known_tables


def column_exists(table_name: str, column_name: str) -> bool:
    """
    Report whether ``table_name`` has a column called ``column_name``.

    Returns False when the table itself is missing. Any error raised during
    reflection is printed and reported as "not found" (False).
    """
    try:
        inspector = sa.inspect(op.get_bind())
        if table_exists(table_name):
            return any(col["name"] == column_name for col in inspector.get_columns(table_name))
        return False
    except Exception as e:
        print(f"Error checking column existence: {e}")
        return False


def index_exists(index_name: str, table_name: Optional[str] = None) -> bool:
    """
    Check if an index exists

    Args:
        index_name: Name of the index to look for
        table_name: Table to search. When omitted (or empty), every table
            returned by the inspector is scanned.

    Returns:
        bool: True if an index with that name was found. False if it was not
        found, if ``table_name`` is given but the table does not exist, or if
        reflection failed (the error is printed).
    """
    try:
        connection = op.get_bind()
        inspector = sa.inspect(connection)

        if table_name:
            if not table_exists(table_name):
                return False
            indexes = inspector.get_indexes(table_name)
            return any(idx["name"] == index_name for idx in indexes)

        # Check across all tables if table_name not provided
        for table in inspector.get_table_names():
            try:
                indexes = inspector.get_indexes(table)
            except Exception:
                # Best effort: a table whose indexes cannot be reflected is
                # skipped so it does not hide a match elsewhere. Catching
                # Exception (not a bare except) lets Ctrl-C / SystemExit through.
                continue
            if any(idx["name"] == index_name for idx in indexes):
                return True
        return False
    except Exception as e:
        print(f"Error checking index existence: {e}")
        return False


def constraint_exists(table_name: str, constraint_name: str) -> bool:
    """
    Check if a named constraint exists in a table

    The primary key, foreign key, unique and check constraints reported by
    the SQLAlchemy inspector are searched, in that order. A constraint kind
    the dialect cannot reflect (it raises ``NotImplementedError``) is skipped
    rather than aborting the whole lookup, so a match in another kind is
    still found.

    Args:
        table_name: Name of the table to inspect
        constraint_name: Name of the constraint to look for

    Returns:
        bool: True if a constraint with that name exists. False if it does not
        exist, if the table does not exist, or if reflection failed (the error
        is printed).
    """
    try:
        connection = op.get_bind()
        inspector = sa.inspect(connection)

        if not table_exists(table_name):
            return False

        # Primary key: reflected as a single dict (or empty) rather than a list
        pk = inspector.get_pk_constraint(table_name)
        if pk and pk.get("name") == constraint_name:
            return True

        reflectors = (
            inspector.get_foreign_keys,
            inspector.get_unique_constraints,
            inspector.get_check_constraints,
        )
        for reflect in reflectors:
            try:
                constraints = reflect(table_name)
            except NotImplementedError:
                # This dialect cannot reflect this constraint kind; try the next
                continue
            if any(c.get("name") == constraint_name for c in constraints):
                return True

        return False
    except Exception as e:
        print(f"Error checking constraint existence: {e}")
        return False


def safe_add_column(table_name: str, column: sa.Column, print_info: bool = True) -> bool:
    """
    Add ``column`` to ``table_name`` unless it would fail or duplicate.

    Args:
        table_name: Table that should receive the column
        column: SQLAlchemy Column object to add
        print_info: When True, progress/skip messages are printed

    Returns:
        bool: True when ``op.add_column`` was issued. False when the table is
        missing or the column is already present.
    """

    def note(message: str) -> None:
        if print_info:
            print(message)

    if not table_exists(table_name):
        note(f"Table {table_name} does not exist, cannot add column {column.name}")
        return False

    if column_exists(table_name, column.name):
        note(f"Column {column.name} already exists in table {table_name}, skipping...")
        return False

    note(f"Adding column {column.name} to table {table_name}...")
    op.add_column(table_name, column)
    return True


def safe_drop_column(table_name: str, column_name: str, print_info: bool = True) -> bool:
    """
    Drop ``column_name`` from ``table_name`` only if it is actually there.

    Args:
        table_name: Table to remove the column from
        column_name: Column to remove
        print_info: When True, progress/skip messages are printed

    Returns:
        bool: True when ``op.drop_column`` was issued. False when the table or
        the column is missing.
    """

    def note(message: str) -> None:
        if print_info:
            print(message)

    if not table_exists(table_name):
        note(f"Table {table_name} does not exist, cannot drop column {column_name}")
        return False

    if not column_exists(table_name, column_name):
        note(f"Column {column_name} does not exist in table {table_name}, skipping...")
        return False

    note(f"Dropping column {column_name} from table {table_name}...")
    op.drop_column(table_name, column_name)
    return True


def safe_create_index(
    index_name: str, table_name: str, columns: List[str], unique: bool = False, print_info: bool = True, **kwargs
) -> bool:
    """
    Create ``index_name`` on ``table_name`` unless it already exists there.

    Args:
        index_name: Name of the index to create
        table_name: Table the index belongs to
        columns: Column names the index covers
        unique: Passed through as ``unique=`` to ``op.create_index``
        print_info: When True, progress/skip messages are printed
        **kwargs: Extra keyword arguments forwarded to ``op.create_index``

    Returns:
        bool: True when ``op.create_index`` was issued. False when the table is
        missing or an index with this name already exists on it.
    """

    def note(message: str) -> None:
        if print_info:
            print(message)

    if not table_exists(table_name):
        note(f"Table {table_name} does not exist, cannot create index {index_name}")
        return False

    if index_exists(index_name, table_name):
        note(f"Index {index_name} already exists, skipping...")
        return False

    note(f"Creating index {index_name} on table {table_name}...")
    op.create_index(index_name, table_name, columns, unique=unique, **kwargs)
    return True


def _find_index_table(index_name: str) -> Optional[str]:
    """Return the name of the table owning ``index_name``, or None if no table has it."""
    try:
        inspector = sa.inspect(op.get_bind())
        for table in inspector.get_table_names():
            try:
                indexes = inspector.get_indexes(table)
            except Exception:
                # Best effort: skip tables whose indexes cannot be reflected
                continue
            if any(idx["name"] == index_name for idx in indexes):
                return table
    except Exception as e:
        print(f"Error checking index existence: {e}")
    return None


def safe_drop_index(index_name: str, table_name: Optional[str] = None, print_info: bool = True) -> bool:
    """
    Safely drop an index, checking for existence first

    When ``table_name`` is omitted, all tables are scanned to find the one that
    owns the index, and that table is passed to ``op.drop_index``. Some
    backends need the owning table to drop an index, so dropping by index
    name alone would fail there.

    Args:
        index_name: Name of the index
        table_name: Name of the table (optional)
        print_info: Whether to print information messages

    Returns:
        bool: True if index was dropped, False if it didn't exist
    """
    if table_name:
        found = index_exists(index_name, table_name)
    else:
        table_name = _find_index_table(index_name)
        found = table_name is not None

    if not found:
        if print_info:
            print(f"Index {index_name} does not exist, skipping...")
        return False

    if print_info:
        print(f"Dropping index {index_name}...")
    op.drop_index(index_name, table_name=table_name)
    return True


def safe_create_table(table_name: str, *args, print_info: bool = True, **kwargs) -> bool:
    """
    Create ``table_name`` only if no table of that name exists yet.

    Args:
        table_name: Name of the table to create
        *args: Columns and constraints forwarded to ``op.create_table``
        print_info: When True, progress/skip messages are printed (keyword-only)
        **kwargs: Extra keyword arguments forwarded to ``op.create_table``

    Returns:
        bool: True when ``op.create_table`` was issued, False when the table
        was already present.
    """
    already_present = table_exists(table_name)

    if already_present:
        message = f"Table {table_name} already exists, skipping..."
    else:
        message = f"Creating table {table_name}..."

    if print_info:
        print(message)

    if already_present:
        return False

    op.create_table(table_name, *args, **kwargs)
    return True


def safe_drop_table(table_name: str, print_info: bool = True) -> bool:
    """
    Drop ``table_name`` only if it currently exists.

    Args:
        table_name: Name of the table to drop
        print_info: When True, progress/skip messages are printed

    Returns:
        bool: True when ``op.drop_table`` was issued, False when the table was
        not present.
    """
    present = table_exists(table_name)

    if print_info:
        print(f"Dropping table {table_name}..." if present else f"Table {table_name} does not exist, skipping...")

    if present:
        op.drop_table(table_name)
    return present


def get_database_info() -> dict:
    """
    Get general information about the database state

    Returns:
        dict: ``total_tables``, ``total_columns``, ``total_indexes`` and the list
        of ``tables``. A table whose columns or indexes cannot be reflected is
        left out of the column/index totals, but it is still counted in
        ``total_tables``. If reflection fails outright, ``{"error": <message>}``
        is returned instead.
    """
    try:
        connection = op.get_bind()
        inspector = sa.inspect(connection)

        tables = inspector.get_table_names()
        total_columns = 0
        total_indexes = 0

        for table in tables:
            try:
                columns = inspector.get_columns(table)
                indexes = inspector.get_indexes(table)
            except Exception:
                # Best effort: skip unreflectable tables, but (unlike a bare
                # ``except``) let KeyboardInterrupt/SystemExit propagate.
                continue
            total_columns += len(columns)
            total_indexes += len(indexes)

        return {
            "total_tables": len(tables),
            "total_columns": total_columns,
            "total_indexes": total_indexes,
            "tables": tables,
        }
    except Exception as e:
        return {"error": str(e)}


def validate_migration_safety(table_name: str, operations: List[dict]) -> List[str]:
    """
    Validate that a list of operations can be safely executed

    Each operation is a dict with a ``"type"`` (one of ``add_column``,
    ``drop_column``, ``create_index``, ``drop_index``) and a ``"name"``.
    An operation with an unrecognized type, or without a name, is reported as
    a warning. Such operations are not silently ignored, and a missing name is
    not checked as ``None``.

    Args:
        table_name: Target table name
        operations: List of operations like [{"type": "add_column", "name": "col1", ...}]

    Returns:
        List of validation warnings/errors (empty when nothing was flagged)
    """
    warnings = []

    if not table_exists(table_name):
        warnings.append(f"Table {table_name} does not exist")
        return warnings

    known_types = ("add_column", "drop_column", "create_index", "drop_index")

    for position, op_dict in enumerate(operations):
        op_type = op_dict.get("type")
        name = op_dict.get("name")

        if op_type not in known_types:
            warnings.append(f"Unknown operation type {op_type!r} at position {position}")
            continue
        if not name:
            warnings.append(f"Operation {op_type} at position {position} has no name")
            continue

        if op_type == "add_column":
            if column_exists(table_name, name):
                warnings.append(f"Column {name} already exists in {table_name}")

        elif op_type == "drop_column":
            if not column_exists(table_name, name):
                warnings.append(f"Column {name} does not exist in {table_name}")

        elif op_type == "create_index":
            if index_exists(name, table_name):
                warnings.append(f"Index {name} already exists")

        else:  # drop_index
            if not index_exists(name, table_name):
                warnings.append(f"Index {name} does not exist")

    return warnings
