# trax/tests/test_v2_schema_migration.py

"""Unit tests for v2 schema migration (Task 6).
Tests the database schema migration for v2 features including speaker profiles,
processing jobs, enhanced transcripts, and new v2-specific columns while maintaining
backward compatibility.
"""
import time
import uuid

import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker

from src.database.models import Base
from src.database.connection import get_database_url


class TestV2SchemaMigration:
"""Test suite for v2 schema migration components."""
def _create_test_transcript(self, db_session, pipeline_version="v2", **kwargs):
"""Helper method to create a test transcription result with required dependencies."""
from src.database.models import TranscriptionResult, TranscriptionJob, MediaFile
# Create required records in dependency order
media_file = MediaFile(
id=uuid.uuid4(),
filename="test_audio.wav",
file_size=1024 * 1024,
source_path="/path/to/source.wav",
status="ready"
)
db_session.add(media_file)
db_session.commit()
transcription_job = TranscriptionJob(
id=uuid.uuid4(),
media_file_id=media_file.id,
status="completed"
)
db_session.add(transcription_job)
db_session.commit()
# Create transcription result with provided parameters
transcript_data = {
'id': uuid.uuid4(),
'job_id': transcription_job.id,
'media_file_id': media_file.id,
'pipeline_version': pipeline_version,
'content': {"text": "Test transcript"},
**kwargs
}
transcript = TranscriptionResult(**transcript_data)
db_session.add(transcript)
db_session.commit()
return transcript
@pytest.fixture(scope="class")
def test_db_engine(self):
"""Create test database engine."""
        # Derive the test database URL by swapping the database name;
        # assumes the configured URL points at a database named "trax"
        # and that a companion "trax_test" database exists.
        test_db_url = get_database_url().replace("/trax", "/trax_test")
engine = create_engine(test_db_url)
# Create all tables
Base.metadata.create_all(engine)
yield engine
# Cleanup
Base.metadata.drop_all(engine)
engine.dispose()
@pytest.fixture
def db_session(self, test_db_engine):
"""Create database session for tests."""
Session = sessionmaker(bind=test_db_engine)
session = Session()
yield session
session.rollback()
session.close()
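    # Note: the tests below call db_session.commit(), so the rollback in the
    # fixture teardown cannot undo committed rows, and data persists on the
    # class-scoped engine. A minimal sketch of stricter per-test isolation,
    # assuming SQLAlchemy 2.0's join_transaction_mode (the fixture name and
    # its use here are illustrative, not part of the current suite):
    @pytest.fixture
    def isolated_db_session(self, test_db_engine):
        """Sketch: session whose commits become savepoints in one outer txn."""
        connection = test_db_engine.connect()
        outer = connection.begin()
        # Each session.commit() releases a SAVEPOINT instead of the real txn
        Session = sessionmaker(
            bind=connection, join_transaction_mode="create_savepoint"
        )
        session = Session()
        yield session
        session.close()
        outer.rollback()  # discards everything, including committed savepoints
        connection.close()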
def test_speaker_profiles_table_structure(self, test_db_engine):
"""Test that speaker_profiles table has correct structure."""
with test_db_engine.connect() as conn:
# Check table exists
result = conn.execute(text("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = 'speaker_profiles'
"""))
assert result.fetchone() is not None, "speaker_profiles table should exist"
# Check columns
result = conn.execute(text("""
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = 'speaker_profiles'
ORDER BY ordinal_position
"""))
columns = {row[0]: row[1:] for row in result.fetchall()}
expected_columns = {
'id': ('integer', 'NO', None),
'name': ('character varying', 'NO', None),
'created_at': ('timestamp without time zone', 'NO', None),
'updated_at': ('timestamp without time zone', 'NO', None),
'characteristics': ('jsonb', 'YES', None),
'embedding': ('text', 'YES', None),
'sample_count': ('integer', 'YES', None),
'user_id': ('integer', 'YES', None)
}
for col_name, (data_type, nullable, default) in expected_columns.items():
assert col_name in columns, f"Column {col_name} should exist"
assert columns[col_name][0] == data_type, f"Column {col_name} should have type {data_type}"
assert columns[col_name][1] == nullable, f"Column {col_name} should be {nullable}"
def test_processing_jobs_table_structure(self, test_db_engine):
"""Test that v2_processing_jobs table has correct structure."""
with test_db_engine.connect() as conn:
# Check table exists
result = conn.execute(text("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = 'v2_processing_jobs'
"""))
assert result.fetchone() is not None, "v2_processing_jobs table should exist"
# Check columns
result = conn.execute(text("""
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_name = 'v2_processing_jobs'
ORDER BY ordinal_position
"""))
columns = {row[0]: row[1:] for row in result.fetchall()}
expected_columns = {
'id': ('integer', 'NO', None),
'status': ('character varying', 'NO', None),
'created_at': ('timestamp without time zone', 'NO', None),
'updated_at': ('timestamp without time zone', 'NO', None),
'completed_at': ('timestamp without time zone', 'YES', None),
'transcript_id': ('uuid', 'YES', None),
'job_type': ('character varying', 'NO', None),
'parameters': ('jsonb', 'YES', None),
'progress': ('double precision', 'YES', None),
'error_message': ('text', 'YES', None),
'result_data': ('jsonb', 'YES', None)
}
for col_name, (data_type, nullable, default) in expected_columns.items():
assert col_name in columns, f"Column {col_name} should exist"
assert columns[col_name][0] == data_type, f"Column {col_name} should have type {data_type}"
assert columns[col_name][1] == nullable, f"Column {col_name} should be {nullable}"
def test_transcripts_v2_columns_structure(self, test_db_engine):
"""Test that transcription_results table has new v2 columns."""
with test_db_engine.connect() as conn:
# Check new v2 columns exist
result = conn.execute(text("""
SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_name = 'transcription_results'
AND column_name IN (
'pipeline_version', 'enhanced_content', 'diarization_content',
'merged_content', 'model_used', 'domain_used', 'accuracy_estimate',
'speaker_count', 'quality_warnings', 'processing_metadata'
)
ORDER BY column_name
"""))
columns = {row[0]: row[1:] for row in result.fetchall()}
expected_v2_columns = {
'pipeline_version': ('character varying', 'NO'),
'enhanced_content': ('jsonb', 'YES'),
'diarization_content': ('jsonb', 'YES'),
'merged_content': ('jsonb', 'YES'),
'model_used': ('character varying', 'YES'),
'domain_used': ('character varying', 'YES'),
'accuracy_estimate': ('double precision', 'YES'),
'speaker_count': ('integer', 'YES'),
'quality_warnings': ('jsonb', 'YES'),
'processing_metadata': ('jsonb', 'YES')
}
for col_name, (data_type, nullable) in expected_v2_columns.items():
assert col_name in columns, f"V2 column {col_name} should exist"
assert columns[col_name][0] == data_type, f"Column {col_name} should have type {data_type}"
                assert columns[col_name][1] == nullable, f"Column {col_name} nullability should be {nullable}"
def test_foreign_key_constraints(self, test_db_engine):
"""Test foreign key constraints are properly set up."""
with test_db_engine.connect() as conn:
# Check foreign key constraints
result = conn.execute(text("""
SELECT
tc.table_name,
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_name IN ('v2_processing_jobs')
ORDER BY tc.table_name, kcu.column_name
"""))
            fk_map = {(row[0], row[1]): (row[2], row[3]) for row in result.fetchall()}
            # v2_processing_jobs.transcript_id should reference transcription_results.id
            assert ('v2_processing_jobs', 'transcript_id') in fk_map, \
                "v2_processing_jobs.transcript_id should have a foreign key constraint"
            assert fk_map[('v2_processing_jobs', 'transcript_id')] == ('transcription_results', 'id'), \
                "transcript_id should reference transcription_results.id"
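    # A short runtime check of the same constraint, assuming transcript_id
    # really references transcription_results.id as asserted above: inserting
    # a job that points at a nonexistent transcript should be rejected.
    # (The test name and scenario are illustrative additions, not part of Task 6.)
    def test_processing_job_fk_enforcement_sketch(self, db_session):
        """Sketch: FK on v2_processing_jobs.transcript_id is enforced."""
        from src.database.models import V2ProcessingJob
        job = V2ProcessingJob(
            status="pending",
            transcript_id=uuid.uuid4(),  # no such transcript row exists
            job_type="enhancement",
        )
        db_session.add(job)
        with pytest.raises(IntegrityError):
            db_session.commit()
        db_session.rollback()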
def test_speaker_profile_crud_operations(self, db_session):
"""Test CRUD operations on SpeakerProfile model."""
from src.database.models import SpeakerProfile
# Create
profile = SpeakerProfile(
name="Test Speaker",
user_id=1,
characteristics={"voice_type": "male", "accent": "american"},
sample_count=0
)
db_session.add(profile)
db_session.commit()
assert profile.id is not None, "Profile should have an ID after creation"
assert profile.name == "Test Speaker"
assert profile.characteristics["voice_type"] == "male"
# Read
retrieved = db_session.query(SpeakerProfile).filter_by(id=profile.id).first()
assert retrieved is not None, "Should be able to retrieve created profile"
assert retrieved.name == "Test Speaker"
# Update
retrieved.name = "Updated Speaker"
retrieved.sample_count = 5
db_session.commit()
updated = db_session.query(SpeakerProfile).filter_by(id=profile.id).first()
assert updated.name == "Updated Speaker"
assert updated.sample_count == 5
# Delete
db_session.delete(updated)
db_session.commit()
deleted = db_session.query(SpeakerProfile).filter_by(id=profile.id).first()
assert deleted is None, "Profile should be deleted"
def test_processing_job_crud_operations(self, db_session):
"""Test CRUD operations on V2ProcessingJob model."""
from src.database.models import V2ProcessingJob, TranscriptionResult, TranscriptionJob, MediaFile
# Create required records in dependency order
media_file = MediaFile(
id=uuid.uuid4(),
filename="test_audio.wav",
file_size=1024 * 1024,
source_path="/path/to/source.wav",
status="ready"
)
db_session.add(media_file)
db_session.commit()
transcription_job = TranscriptionJob(
id=uuid.uuid4(),
media_file_id=media_file.id,
status="completed"
)
db_session.add(transcription_job)
db_session.commit()
# Create a transcription result to reference
transcript = TranscriptionResult(
id=uuid.uuid4(),
job_id=transcription_job.id,
media_file_id=media_file.id,
pipeline_version="v2",
content={"text": "Test transcript"}
)
db_session.add(transcript)
db_session.commit()
# Create
job = V2ProcessingJob(
status="pending",
transcript_id=transcript.id,
job_type="enhancement",
parameters={"model": "gpt-4", "temperature": 0.7},
progress=0.0
)
db_session.add(job)
db_session.commit()
assert job.id is not None, "Job should have an ID after creation"
assert job.status == "pending"
assert job.job_type == "enhancement"
assert job.parameters["model"] == "gpt-4"
# Read
retrieved = db_session.query(V2ProcessingJob).filter_by(id=job.id).first()
assert retrieved is not None, "Should be able to retrieve created job"
assert retrieved.status == "pending"
# Update
retrieved.status = "processing"
retrieved.progress = 0.5
db_session.commit()
updated = db_session.query(V2ProcessingJob).filter_by(id=job.id).first()
assert updated.status == "processing"
assert updated.progress == 0.5
# Delete
db_session.delete(updated)
db_session.commit()
deleted = db_session.query(V2ProcessingJob).filter_by(id=job.id).first()
assert deleted is None, "Job should be deleted"
def test_transcript_v2_fields(self, db_session):
"""Test that Transcript model supports v2 fields."""
# Create transcript with v2 fields using helper method
transcript = self._create_test_transcript(
db_session,
pipeline_version="v2",
content={"text": "Test transcript", "segments": []},
enhanced_content={"enhanced_text": "Enhanced transcript"},
diarization_content={"speakers": [{"id": 1, "name": "Speaker 1"}]},
merged_content={"merged_text": "Merged transcript"},
model_used="whisper-large-v3",
domain_used="technical",
accuracy_estimate=0.95,
speaker_count=2,
quality_warnings=["low_confidence_segments"],
processing_metadata={"enhancement_applied": True}
)
assert transcript.id is not None, "Transcript should have an ID after creation"
assert transcript.pipeline_version == "v2"
assert transcript.enhanced_content["enhanced_text"] == "Enhanced transcript"
assert transcript.diarization_content["speakers"][0]["name"] == "Speaker 1"
assert transcript.accuracy_estimate == 0.95
assert transcript.speaker_count == 2
def test_backward_compatibility(self, db_session):
"""Test backward compatibility with v1 data."""
from src.database.models import TranscriptionResult
# Create v1-style transcript (without v2 fields)
v1_transcript = self._create_test_transcript(
db_session,
pipeline_version="v1",
content={"text": "V1 transcript content"},
accuracy=0.85
)
# Verify v1 transcript works
assert v1_transcript.id is not None
assert v1_transcript.pipeline_version == "v1"
assert v1_transcript.enhanced_content is None # Should be None for v1
assert v1_transcript.diarization_content is None # Should be None for v1
# Verify we can update v1 transcript to v2
v1_transcript.pipeline_version = "v2"
v1_transcript.enhanced_content = {"enhanced": True}
db_session.commit()
updated = db_session.query(TranscriptionResult).filter_by(id=v1_transcript.id).first()
assert updated.pipeline_version == "v2"
assert updated.enhanced_content["enhanced"] is True
    def test_data_migration_script(self, test_db_engine):
        """Test data migration script functionality."""
        # Verification depends on the actual migration script; skip explicitly
        # rather than passing silently.
        pytest.skip("Data migration script test not yet implemented")

    def test_alembic_migration_rollback(self, test_db_engine):
        """Test that Alembic migration can be rolled back."""
        # Verification depends on the downgrade path; skip explicitly rather
        # than passing silently.
        pytest.skip("Alembic rollback test not yet implemented")
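    # A minimal sketch of what the rollback test could look like once the
    # migration lands, assuming Alembic is driven by an alembic.ini at the
    # project root (the config path and relative revision steps are
    # assumptions about the project layout, not verified facts):
    @pytest.mark.skip(reason="Sketch only; requires a configured Alembic environment")
    def test_alembic_downgrade_upgrade_roundtrip_sketch(self):
        from alembic import command
        from alembic.config import Config

        cfg = Config("alembic.ini")   # assumed config location
        command.downgrade(cfg, "-1")  # roll back the most recent revision
        command.upgrade(cfg, "head")  # re-apply up to the latest revision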
    def test_performance_impact(self, test_db_engine):
        """Test that adding v2 columns doesn't significantly impact performance."""
        with test_db_engine.connect() as conn:
            # Time a simple count query with a monotonic clock
            start_time = time.perf_counter()
            result = conn.execute(text("SELECT COUNT(*) FROM transcription_results"))
            count = result.scalar()
            query_time = time.perf_counter() - start_time
            assert count is not None
            # Query should complete in reasonable time (< 1 second for small datasets)
            assert query_time < 1.0, f"Query took {query_time:.3f}s, should be under 1 second"
def test_jsonb_field_operations(self, db_session):
"""Test JSONB field operations for v2 content."""
from src.database.models import TranscriptionResult
# Test complex JSONB data
complex_content = {
"segments": [
{
"start": 0.0,
"end": 2.5,
"text": "Hello world",
"speaker": "speaker_1",
"confidence": 0.95
},
{
"start": 2.5,
"end": 5.0,
"text": "How are you?",
"speaker": "speaker_2",
"confidence": 0.88
}
],
"metadata": {
"language": "en",
"model_version": "whisper-large-v3",
"processing_time": 12.5
}
}
transcript = self._create_test_transcript(
db_session,
pipeline_version="v2",
content=complex_content,
enhanced_content={"enhanced_segments": complex_content["segments"]},
diarization_content={"speakers": ["speaker_1", "speaker_2"]}
)
# Verify JSONB data is stored correctly
retrieved = db_session.query(TranscriptionResult).filter_by(id=transcript.id).first()
assert retrieved.content["segments"][0]["text"] == "Hello world"
assert retrieved.content["metadata"]["language"] == "en"
assert len(retrieved.diarization_content["speakers"]) == 2
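    # A short sketch of querying inside the stored JSON from the ORM, assuming
    # `content` is mapped as a PostgreSQL JSON/JSONB column (path indexing and
    # `.astext` below are the standard operators for that case; the test data
    # is illustrative):
    def test_jsonb_path_query_sketch(self, db_session):
        """Sketch: filter transcripts by a value nested in JSONB content."""
        from src.database.models import TranscriptionResult
        transcript = self._create_test_transcript(
            db_session,
            content={"text": "hola mundo", "metadata": {"language": "es"}},
        )
        match = (
            db_session.query(TranscriptionResult)
            .filter(TranscriptionResult.content["metadata"]["language"].astext == "es")
            .first()
        )
        assert match is not None and match.id == transcript.id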
def test_indexes_and_constraints(self, test_db_engine):
"""Test that proper indexes and constraints are created."""
with test_db_engine.connect() as conn:
# Check indexes on speaker_profiles
result = conn.execute(text("""
SELECT indexname, indexdef
FROM pg_indexes
WHERE tablename = 'speaker_profiles'
"""))
speaker_indexes = [row[0] for row in result.fetchall()]
# Should have primary key index
assert any('pkey' in idx.lower() for idx in speaker_indexes), "Should have primary key index"
# Check indexes on v2_processing_jobs
result = conn.execute(text("""
SELECT indexname, indexdef
FROM pg_indexes
WHERE tablename = 'v2_processing_jobs'
"""))
job_indexes = [row[0] for row in result.fetchall()]
# Should have primary key index
assert any('pkey' in idx.lower() for idx in job_indexes), "Should have primary key index"
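    # Hypothetical follow-up: if containment or path queries against the JSONB
    # columns become hot, a GIN index would help. A sketch of asserting such an
    # index, assuming one were added (no GIN index is required by this
    # migration, so the test is skipped):
    @pytest.mark.skip(reason="Sketch only; GIN index is not part of the v2 migration")
    def test_enhanced_content_gin_index_sketch(self, test_db_engine):
        with test_db_engine.connect() as conn:
            result = conn.execute(text("""
                SELECT indexdef
                FROM pg_indexes
                WHERE tablename = 'transcription_results'
                AND indexdef ILIKE '%USING gin%'
            """))
            assert result.fetchone() is not None, "Expected a GIN index on transcription_results"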
def test_timestamp_auto_updating(self, db_session):
"""Test that timestamp fields auto-update correctly."""
from src.database.models import SpeakerProfile
# Create profile
profile = SpeakerProfile(name="Test Speaker", user_id=1)
db_session.add(profile)
db_session.commit()
original_updated_at = profile.updated_at
# Wait a moment
time.sleep(0.1)
# Update profile
profile.name = "Updated Speaker"
db_session.commit()
# Check that updated_at changed
assert profile.updated_at > original_updated_at, "updated_at should be updated on modification"
def test_null_handling_for_backward_compatibility(self, db_session):
"""Test that NULL values are handled correctly for backward compatibility."""
from src.database.models import TranscriptionResult
# Create transcript with minimal v1 data
minimal_transcript = self._create_test_transcript(
db_session,
content={"text": "Minimal content"}
# All v2 fields should be NULL by default
)
# Verify v2 fields are NULL
retrieved = db_session.query(TranscriptionResult).filter_by(id=minimal_transcript.id).first()
assert retrieved.enhanced_content is None
assert retrieved.diarization_content is None
assert retrieved.merged_content is None
assert retrieved.model_used is None
assert retrieved.domain_used is None
assert retrieved.accuracy_estimate is None
assert retrieved.confidence_scores is None
assert retrieved.speaker_count is None
assert retrieved.quality_warnings is None
assert retrieved.processing_metadata is None