"""Unit tests for v2 Alembic migrations (Task 6).
|
|
|
|
Tests the Alembic migration scripts for v2 schema changes including
|
|
upgrade and downgrade functionality, data migration, and rollback procedures.
|
|
"""

import pytest
import os
import tempfile
import shutil
from datetime import datetime

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from alembic import command
from alembic.config import Config
from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext

from src.database.models import Base
from src.database.connection import get_database_url


class TestV2Migrations:
    """Test suite for v2 schema migrations."""

    @pytest.fixture(scope="class")
    def test_db_url(self):
        """Get the test database URL (assumes a trax_test database exists)."""
        return get_database_url().replace("/trax", "/trax_test")

    @pytest.fixture(scope="class")
    def test_db_engine(self, test_db_url):
        """Create the test database engine."""
        engine = create_engine(test_db_url)

        # Create the initial schema from the ORM models (assumed to be the
        # pre-v2 definitions at this point)
        Base.metadata.create_all(engine)
        yield engine

        # Cleanup
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def alembic_config(self, test_db_url):
        """Create a self-contained Alembic configuration for testing."""
        # Create a temporary directory for the migration environment
        temp_dir = tempfile.mkdtemp()

        # Copy the project's alembic.ini into the temp directory
        original_ini = "alembic.ini"
        temp_ini = os.path.join(temp_dir, "alembic.ini")
        shutil.copy2(original_ini, temp_ini)

        # Create the migrations directory
        migrations_dir = os.path.join(temp_dir, "migrations")
        os.makedirs(migrations_dir, exist_ok=True)

        # Create the versions directory
        versions_dir = os.path.join(migrations_dir, "versions")
        os.makedirs(versions_dir, exist_ok=True)

        # Create env.py
        env_py_content = '''\
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

from src.database.models import Base

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = Base.metadata


def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
'''

        with open(os.path.join(migrations_dir, "env.py"), "w") as f:
            f.write(env_py_content)

        # Create script.py.mako
        script_mako_content = '''"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
'''

        with open(os.path.join(migrations_dir, "script.py.mako"), "w") as f:
            f.write(script_mako_content)

        # Point alembic.ini at the test database
        with open(temp_ini, "r") as f:
            ini_content = f.read()

        ini_content = ini_content.replace(
            "sqlalchemy.url = driver://user:pass@localhost/dbname",
            f"sqlalchemy.url = {test_db_url}"
        )

        with open(temp_ini, "w") as f:
            f.write(ini_content)

        # Create the Alembic config
        config = Config(temp_ini)
        config.set_main_option("script_location", migrations_dir)

        yield config

        # Cleanup
        shutil.rmtree(temp_dir)

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Create a database session for tests."""
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()

    def test_migration_script_creation(self, alembic_config):
        """Test that a migration script can be created."""
        # Generate a migration
        revision = command.revision(alembic_config, message="Add v2 schema")

        assert revision is not None, "Migration script should be created"
        assert revision.revision is not None, "Migration should have a revision ID"

        # The generated revision should be discoverable in the script directory
        script = ScriptDirectory.from_config(alembic_config)
        assert script.get_revision(revision.revision) is not None
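
    # Hedged companion check, not in the original suite: after generating the
    # two revisions used throughout these tests, the revision graph should
    # expose exactly one head. Built only on the public ScriptDirectory API.
    def test_single_migration_head(self, alembic_config):
        """The revision graph should have a single head after generation."""
        command.revision(alembic_config, message="Initial schema")
        command.revision(alembic_config, message="Add v2 schema")

        script = ScriptDirectory.from_config(alembic_config)
        heads = script.get_heads()
        assert len(heads) == 1, f"Expected a single head revision, got {heads}"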

    def _write_v2_migration(self, alembic_config, migration):
        """Overwrite the generated v2 revision file with the real DDL.

        Several tests below need the same migration content; the original
        file repeated "(Implementation would be the same)" placeholders, so
        the content is centralized here.
        """
        migration_file = os.path.join(
            alembic_config.get_main_option("script_location"),
            "versions",
            f"{migration.revision}_add_v2_schema.py"
        )

        migration_content = '''"""Add v2 schema

Revision ID: {revision}
Revises: {down_revision}
Create Date: {create_date}

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB

# revision identifiers
revision = {revision!r}
down_revision = {down_revision!r}
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create speaker_profiles table
    op.create_table(
        'speaker_profiles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(255), nullable=False),
        sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP')),
        sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP')),
        sa.Column('characteristics', JSONB, nullable=True),
        sa.Column('embedding', sa.LargeBinary(), nullable=True),
        sa.Column('sample_count', sa.Integer(), server_default='0'),
        sa.Column('user_id', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )

    # Create processing_jobs table
    op.create_table(
        'processing_jobs',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('status', sa.String(50), server_default='pending', nullable=False),
        sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP')),
        sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('CURRENT_TIMESTAMP')),
        sa.Column('completed_at', sa.TIMESTAMP(timezone=True), nullable=True),
        sa.Column('transcript_id', sa.Integer(), nullable=True),
        sa.Column('job_type', sa.String(50), nullable=False),
        sa.Column('parameters', JSONB, nullable=True),
        sa.Column('progress', sa.Float(), server_default='0'),
        sa.Column('error_message', sa.Text(), nullable=True),
        sa.Column('result_data', JSONB, nullable=True),
        sa.PrimaryKeyConstraint('id')
    )

    # Add v2 columns to the transcription_results table
    op.add_column('transcription_results', sa.Column('pipeline_version', sa.String(20), nullable=True))
    op.add_column('transcription_results', sa.Column('enhanced_content', JSONB, nullable=True))
    op.add_column('transcription_results', sa.Column('diarization_content', JSONB, nullable=True))
    op.add_column('transcription_results', sa.Column('merged_content', JSONB, nullable=True))
    op.add_column('transcription_results', sa.Column('model_used', sa.String(100), nullable=True))
    op.add_column('transcription_results', sa.Column('domain_used', sa.String(100), nullable=True))
    op.add_column('transcription_results', sa.Column('accuracy_estimate', sa.Float(), nullable=True))
    op.add_column('transcription_results', sa.Column('confidence_scores', JSONB, nullable=True))
    op.add_column('transcription_results', sa.Column('speaker_count', sa.Integer(), nullable=True))
    op.add_column('transcription_results', sa.Column('quality_warnings', JSONB, nullable=True))
    op.add_column('transcription_results', sa.Column('processing_metadata', JSONB, nullable=True))


def downgrade() -> None:
    # Remove v2 columns from the transcription_results table
    op.drop_column('transcription_results', 'processing_metadata')
    op.drop_column('transcription_results', 'quality_warnings')
    op.drop_column('transcription_results', 'speaker_count')
    op.drop_column('transcription_results', 'confidence_scores')
    op.drop_column('transcription_results', 'accuracy_estimate')
    op.drop_column('transcription_results', 'domain_used')
    op.drop_column('transcription_results', 'model_used')
    op.drop_column('transcription_results', 'merged_content')
    op.drop_column('transcription_results', 'diarization_content')
    op.drop_column('transcription_results', 'enhanced_content')
    op.drop_column('transcription_results', 'pipeline_version')

    # Drop processing_jobs table
    op.drop_table('processing_jobs')

    # Drop speaker_profiles table
    op.drop_table('speaker_profiles')
'''.format(
            revision=migration.revision,
            # repr() keeps a None down_revision as None rather than the
            # string 'None', which would break the revision chain
            down_revision=migration.down_revision,
            create_date=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        )

        with open(migration_file, "w") as f:
            f.write(migration_content)

    def test_migration_upgrade(self, alembic_config, test_db_engine):
        """Test migration upgrade functionality."""
        # Create the initial and v2 migration scripts
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")

        # Overwrite the generated v2 revision with the real DDL
        self._write_v2_migration(alembic_config, v2_migration)

        # Run the migration
        command.upgrade(alembic_config, "head")

        # Verify the tables were created
        with test_db_engine.connect() as conn:
            # Check speaker_profiles table
            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'speaker_profiles'
            """))
            assert result.fetchone() is not None, "speaker_profiles table should exist"

            # Check processing_jobs table
            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'processing_jobs'
            """))
            assert result.fetchone() is not None, "processing_jobs table should exist"

            # Check v2 columns in transcription_results
            result = conn.execute(text("""
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = 'transcription_results'
                AND column_name IN ('pipeline_version', 'enhanced_content', 'diarization_content')
            """))
            columns = [row[0] for row in result.fetchall()]
            assert 'pipeline_version' in columns, "pipeline_version column should exist"
            assert 'enhanced_content' in columns, "enhanced_content column should exist"
            assert 'diarization_content' in columns, "diarization_content column should exist"
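
    # Hedged addition: after an upgrade, the stamp recorded in the
    # alembic_version table should match the script head. MigrationContext is
    # imported above but was unused in the original suite; this sketch puts
    # it to work.
    def test_current_revision_matches_head(self, alembic_config, test_db_engine):
        """The database revision stamp should equal the script head after upgrade."""
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")
        self._write_v2_migration(alembic_config, v2_migration)

        command.upgrade(alembic_config, "head")

        script = ScriptDirectory.from_config(alembic_config)
        with test_db_engine.connect() as conn:
            context = MigrationContext.configure(conn)
            assert context.get_current_revision() == script.get_current_head()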

    def test_migration_downgrade(self, alembic_config, test_db_engine):
        """Test migration downgrade functionality."""
        # Create the migration scripts
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")

        # Write the same migration content as in test_migration_upgrade
        self._write_v2_migration(alembic_config, v2_migration)

        # Run upgrade
        command.upgrade(alembic_config, "head")

        # Verify the tables exist
        with test_db_engine.connect() as conn:
            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'speaker_profiles'
            """))
            assert result.fetchone() is not None, "speaker_profiles table should exist before downgrade"

        # Run downgrade
        command.downgrade(alembic_config, "-1")

        # Verify the tables were removed
        with test_db_engine.connect() as conn:
            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'speaker_profiles'
            """))
            assert result.fetchone() is None, "speaker_profiles table should not exist after downgrade"

            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'processing_jobs'
            """))
            assert result.fetchone() is None, "processing_jobs table should not exist after downgrade"
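
    # Hedged round-trip check: downgrading all the way back to base should
    # clear the alembic_version stamp entirely. A sketch reusing the shared
    # migration helper above.
    def test_downgrade_to_base_round_trip(self, alembic_config, test_db_engine):
        """Upgrade to head followed by downgrade to base should clear the stamp."""
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")
        self._write_v2_migration(alembic_config, v2_migration)

        command.upgrade(alembic_config, "head")
        command.downgrade(alembic_config, "base")

        with test_db_engine.connect() as conn:
            context = MigrationContext.configure(conn)
            assert context.get_current_revision() is None, "Version stamp should be cleared at base"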

    def test_migration_idempotency(self, alembic_config, test_db_engine):
        """Test that running the migration twice doesn't cause errors."""
        # Create the migration scripts and write the shared v2 content
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")
        self._write_v2_migration(alembic_config, v2_migration)

        # Run the migration a first time
        command.upgrade(alembic_config, "head")

        # Run it a second time; upgrading an already-current database
        # should be a no-op, not an error
        try:
            command.upgrade(alembic_config, "head")
        except Exception as e:
            pytest.fail(f"Running migration twice should not fail: {e}")

    def test_data_migration_script(self, test_db_engine, db_session):
        """Test data migration script functionality."""
        from src.database.models import TranscriptionResult

        # Create some v1 transcripts
        v1_transcripts = []
        for i in range(5):
            transcript = TranscriptionResult(
                content={"text": f"V1 transcript {i}"},
                accuracy=0.85 + (i * 0.02)
            )
            db_session.add(transcript)
            v1_transcripts.append(transcript)

        db_session.commit()

        # Run the data migration
        from src.migrations.data_migration import migrate_existing_data
        migrate_existing_data(test_db_engine.url)

        # Verify the migration results
        for transcript in v1_transcripts:
            db_session.refresh(transcript)
            assert transcript.pipeline_version == "v1", "Existing transcripts should be marked as v1"
            assert transcript.enhanced_content is not None, "Enhanced content should be set"
            assert transcript.confidence_scores is not None, "Confidence scores should be set"
            assert transcript.quality_warnings is not None, "Quality warnings should be set"
            assert transcript.processing_metadata is not None, "Processing metadata should be set"
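
    # For reference, the backfill exercised above amounts to roughly the
    # following SQL. This is a hedged sketch inferred from the assertions; the
    # real logic lives in src.migrations.data_migration and may differ:
    #
    #   UPDATE transcription_results
    #   SET pipeline_version    = 'v1',
    #       enhanced_content    = content,
    #       confidence_scores   = '{}'::jsonb,
    #       quality_warnings    = '[]'::jsonb,
    #       processing_metadata = '{}'::jsonb
    #   WHERE pipeline_version IS NULL;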

    def test_migration_with_existing_data(self, alembic_config, test_db_engine, db_session):
        """Test that migrating with existing data preserves data integrity."""
        from src.database.models import TranscriptionResult

        # Create existing data before the migration
        existing_transcript = TranscriptionResult(
            content={"text": "Existing transcript content"},
            accuracy=0.90
        )
        db_session.add(existing_transcript)
        db_session.commit()
        original_id = existing_transcript.id

        # Create the migration scripts and write the shared v2 content
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")
        self._write_v2_migration(alembic_config, v2_migration)

        command.upgrade(alembic_config, "head")

        # Verify the existing data is preserved
        db_session.refresh(existing_transcript)
        assert existing_transcript.id == original_id, "Transcript ID should be preserved"
        assert existing_transcript.content["text"] == "Existing transcript content", "Content should be preserved"
        assert existing_transcript.accuracy == 0.90, "Accuracy should be preserved"

    @pytest.mark.xfail(reason="The shared v2 migration content defines no foreign keys yet")
    def test_foreign_key_constraints_after_migration(self, alembic_config, test_db_engine):
        """Test that foreign key constraints are properly set up after migration."""
        # Create the migration scripts; a complete implementation would also
        # include the foreign key DDL in the migration content
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")
        self._write_v2_migration(alembic_config, v2_migration)

        command.upgrade(alembic_config, "head")

        # Verify foreign key constraints
        with test_db_engine.connect() as conn:
            result = conn.execute(text("""
                SELECT
                    tc.table_name,
                    kcu.column_name,
                    ccu.table_name AS foreign_table_name
                FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                    ON tc.constraint_name = kcu.constraint_name
                JOIN information_schema.constraint_column_usage AS ccu
                    ON ccu.constraint_name = tc.constraint_name
                WHERE tc.constraint_type = 'FOREIGN KEY'
                AND tc.table_name IN ('speaker_profiles', 'processing_jobs')
            """))

            constraints = {row[0]: row[1:] for row in result.fetchall()}

            # Verify the expected foreign key constraints exist
            assert 'speaker_profiles' in constraints or 'processing_jobs' in constraints, \
                "Foreign key constraints should exist"

    def test_migration_rollback_procedure(self, alembic_config, test_db_engine):
        """Test the rollback procedure in case of migration failure."""
        # The project-specific rollback requirements are not defined yet, so
        # this remains a placeholder; see the sketch below for one concrete drill
        pytest.skip("Rollback procedure depends on project-specific requirements")
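
    # Hedged sketch of one concrete rollback drill. PostgreSQL applies DDL
    # inside the migration transaction, so a revision that raises mid-upgrade
    # should leave no partial schema behind. The failing revision written here
    # is hypothetical test scaffolding, not a real project migration.
    def test_failed_upgrade_leaves_no_partial_schema(self, alembic_config, test_db_engine):
        """A migration that raises mid-upgrade should be rolled back atomically."""
        failing = command.revision(alembic_config, message="Failing migration")
        migration_file = os.path.join(
            alembic_config.get_main_option("script_location"),
            "versions",
            f"{failing.revision}_failing_migration.py"
        )
        with open(migration_file, "w") as f:
            f.write(f'''"""Failing migration"""
from alembic import op
import sqlalchemy as sa

revision = {failing.revision!r}
down_revision = {failing.down_revision!r}
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table('half_created', sa.Column('id', sa.Integer(), primary_key=True))
    raise RuntimeError("simulated mid-migration failure")


def downgrade() -> None:
    op.drop_table('half_created')
''')

        with pytest.raises(Exception):
            command.upgrade(alembic_config, "head")

        # The partially created table should have been rolled back
        with test_db_engine.connect() as conn:
            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'half_created'
            """))
            assert result.fetchone() is None, "Partially applied DDL should be rolled back"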

    def test_migration_performance(self, alembic_config, test_db_engine):
        """Test that the migration completes in a reasonable time."""
        import time

        # Create the migration scripts and write the shared v2 content
        command.revision(alembic_config, message="Initial schema")
        v2_migration = command.revision(alembic_config, message="Add v2 schema")
        self._write_v2_migration(alembic_config, v2_migration)

        # Time the migration
        start_time = time.time()
        command.upgrade(alembic_config, "head")
        end_time = time.time()

        migration_time = end_time - start_time

        # The migration should complete in a reasonable time
        # (< 30 seconds against the test database)
        assert migration_time < 30.0, f"Migration took {migration_time:.1f} seconds, should be under 30"

    def test_migration_logging(self, alembic_config, test_db_engine):
        """Test that the migration provides appropriate logging."""
        # The project's logging requirements are not specified yet, so this
        # remains a placeholder; see the sketch below for a caplog-based check
        pytest.skip("Logging requirements for migrations are not yet defined")
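
    # Hedged sketch for the logging check above: Alembic reports progress
    # through loggers under the "alembic" namespace (e.g. "Running upgrade"
    # messages), so pytest's caplog fixture can capture them. This assumes the
    # fileConfig call in env.py does not suppress propagation to the root logger.
    def test_upgrade_emits_alembic_log_records(self, alembic_config, test_db_engine, caplog):
        """An upgrade should produce log records from the alembic logger."""
        import logging

        command.revision(alembic_config, message="Initial schema")

        with caplog.at_level(logging.INFO, logger="alembic"):
            command.upgrade(alembic_config, "head")

        assert any(record.name.startswith("alembic") for record in caplog.records), \
            "Expected log output from the alembic logger during upgrade"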