"""Unit tests for v2 schema migration (Task 6).
|
|
|
|
Tests the database schema migration for v2 features including speaker profiles,
|
|
processing jobs, enhanced transcripts, and new v2-specific columns while maintaining
|
|
backward compatibility.
|
|
"""

import pytest
import time

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from src.database.models import Base
from src.database.connection import get_database_url


class TestV2SchemaMigration:
    """Test suite for v2 schema migration components."""

    def _create_test_transcript(self, db_session, pipeline_version="v2", **kwargs):
        """Helper method to create a test transcription result with required dependencies."""
        from src.database.models import TranscriptionResult, TranscriptionJob, MediaFile
        import uuid

        # Create required records in dependency order
        media_file = MediaFile(
            id=uuid.uuid4(),
            filename="test_audio.wav",
            file_size=1024 * 1024,
            source_path="/path/to/source.wav",
            status="ready",
        )
        db_session.add(media_file)
        db_session.commit()

        transcription_job = TranscriptionJob(
            id=uuid.uuid4(),
            media_file_id=media_file.id,
            status="completed",
        )
        db_session.add(transcription_job)
        db_session.commit()

        # Build the transcription result, letting callers override the defaults
        transcript_data = {
            'id': uuid.uuid4(),
            'job_id': transcription_job.id,
            'media_file_id': media_file.id,
            'pipeline_version': pipeline_version,
            'content': {"text": "Test transcript"},
            **kwargs,
        }

        transcript = TranscriptionResult(**transcript_data)
        db_session.add(transcript)
        db_session.commit()

        return transcript

    @pytest.fixture(scope="class")
    def test_db_engine(self):
        """Create test database engine."""
        # Point at the dedicated test database rather than the main one
        test_db_url = get_database_url().replace("/trax", "/trax_test")
        engine = create_engine(test_db_url)

        # Create all tables
        Base.metadata.create_all(engine)
        yield engine

        # Cleanup
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Create database session for tests."""
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()
    def test_speaker_profiles_table_structure(self, test_db_engine):
        """Test that speaker_profiles table has correct structure."""
        with test_db_engine.connect() as conn:
            # Check table exists
            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'speaker_profiles'
            """))
            assert result.fetchone() is not None, "speaker_profiles table should exist"

            # Check columns
            result = conn.execute(text("""
                SELECT column_name, data_type, is_nullable, column_default
                FROM information_schema.columns
                WHERE table_name = 'speaker_profiles'
                ORDER BY ordinal_position
            """))
            columns = {row[0]: row[1:] for row in result.fetchall()}

            expected_columns = {
                'id': ('integer', 'NO', None),
                'name': ('character varying', 'NO', None),
                'created_at': ('timestamp without time zone', 'NO', None),
                'updated_at': ('timestamp without time zone', 'NO', None),
                'characteristics': ('jsonb', 'YES', None),
                'embedding': ('text', 'YES', None),
                'sample_count': ('integer', 'YES', None),
                'user_id': ('integer', 'YES', None),
            }

            for col_name, (data_type, nullable, _default) in expected_columns.items():
                assert col_name in columns, f"Column {col_name} should exist"
                assert columns[col_name][0] == data_type, f"Column {col_name} should have type {data_type}"
                assert columns[col_name][1] == nullable, f"Column {col_name} nullability should be {nullable}"
    def test_processing_jobs_table_structure(self, test_db_engine):
        """Test that v2_processing_jobs table has correct structure."""
        with test_db_engine.connect() as conn:
            # Check table exists
            result = conn.execute(text("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_name = 'v2_processing_jobs'
            """))
            assert result.fetchone() is not None, "v2_processing_jobs table should exist"

            # Check columns
            result = conn.execute(text("""
                SELECT column_name, data_type, is_nullable, column_default
                FROM information_schema.columns
                WHERE table_name = 'v2_processing_jobs'
                ORDER BY ordinal_position
            """))
            columns = {row[0]: row[1:] for row in result.fetchall()}

            expected_columns = {
                'id': ('integer', 'NO', None),
                'status': ('character varying', 'NO', None),
                'created_at': ('timestamp without time zone', 'NO', None),
                'updated_at': ('timestamp without time zone', 'NO', None),
                'completed_at': ('timestamp without time zone', 'YES', None),
                'transcript_id': ('uuid', 'YES', None),
                'job_type': ('character varying', 'NO', None),
                'parameters': ('jsonb', 'YES', None),
                'progress': ('double precision', 'YES', None),
                'error_message': ('text', 'YES', None),
                'result_data': ('jsonb', 'YES', None),
            }

            for col_name, (data_type, nullable, _default) in expected_columns.items():
                assert col_name in columns, f"Column {col_name} should exist"
                assert columns[col_name][0] == data_type, f"Column {col_name} should have type {data_type}"
                assert columns[col_name][1] == nullable, f"Column {col_name} nullability should be {nullable}"
    def test_transcripts_v2_columns_structure(self, test_db_engine):
        """Test that transcription_results table has new v2 columns."""
        with test_db_engine.connect() as conn:
            # Check new v2 columns exist
            result = conn.execute(text("""
                SELECT column_name, data_type, is_nullable
                FROM information_schema.columns
                WHERE table_name = 'transcription_results'
                AND column_name IN (
                    'pipeline_version', 'enhanced_content', 'diarization_content',
                    'merged_content', 'model_used', 'domain_used', 'accuracy_estimate',
                    'speaker_count', 'quality_warnings', 'processing_metadata'
                )
                ORDER BY column_name
            """))
            columns = {row[0]: row[1:] for row in result.fetchall()}

            # All v2 columns except pipeline_version are nullable so that
            # existing v1 rows remain valid
            expected_v2_columns = {
                'pipeline_version': ('character varying', 'NO'),
                'enhanced_content': ('jsonb', 'YES'),
                'diarization_content': ('jsonb', 'YES'),
                'merged_content': ('jsonb', 'YES'),
                'model_used': ('character varying', 'YES'),
                'domain_used': ('character varying', 'YES'),
                'accuracy_estimate': ('double precision', 'YES'),
                'speaker_count': ('integer', 'YES'),
                'quality_warnings': ('jsonb', 'YES'),
                'processing_metadata': ('jsonb', 'YES'),
            }

            for col_name, (data_type, nullable) in expected_v2_columns.items():
                assert col_name in columns, f"V2 column {col_name} should exist"
                assert columns[col_name][0] == data_type, f"Column {col_name} should have type {data_type}"
                assert columns[col_name][1] == nullable, f"Column {col_name} nullability should be {nullable}"
    def test_foreign_key_constraints(self, test_db_engine):
        """Test foreign key constraints are properly set up."""
        with test_db_engine.connect() as conn:
            # Check foreign key constraints
            result = conn.execute(text("""
                SELECT
                    tc.table_name,
                    kcu.column_name,
                    ccu.table_name AS foreign_table_name,
                    ccu.column_name AS foreign_column_name
                FROM information_schema.table_constraints AS tc
                JOIN information_schema.key_column_usage AS kcu
                    ON tc.constraint_name = kcu.constraint_name
                JOIN information_schema.constraint_column_usage AS ccu
                    ON ccu.constraint_name = tc.constraint_name
                WHERE tc.constraint_type = 'FOREIGN KEY'
                    AND tc.table_name IN ('v2_processing_jobs')
                ORDER BY tc.table_name, kcu.column_name
            """))

            # Keep every row: a dict keyed by table name would silently
            # collapse multiple foreign keys on the same table
            constraints = [tuple(row) for row in result.fetchall()]

            # v2_processing_jobs.transcript_id should reference transcription_results(id)
            assert ('v2_processing_jobs', 'transcript_id',
                    'transcription_results', 'id') in constraints, \
                "v2_processing_jobs.transcript_id should reference transcription_results(id)"
    def test_speaker_profile_crud_operations(self, db_session):
        """Test CRUD operations on SpeakerProfile model."""
        from src.database.models import SpeakerProfile

        # Create
        profile = SpeakerProfile(
            name="Test Speaker",
            user_id=1,
            characteristics={"voice_type": "male", "accent": "american"},
            sample_count=0,
        )
        db_session.add(profile)
        db_session.commit()

        assert profile.id is not None, "Profile should have an ID after creation"
        assert profile.name == "Test Speaker"
        assert profile.characteristics["voice_type"] == "male"

        # Read
        retrieved = db_session.query(SpeakerProfile).filter_by(id=profile.id).first()
        assert retrieved is not None, "Should be able to retrieve created profile"
        assert retrieved.name == "Test Speaker"

        # Update
        retrieved.name = "Updated Speaker"
        retrieved.sample_count = 5
        db_session.commit()

        updated = db_session.query(SpeakerProfile).filter_by(id=profile.id).first()
        assert updated.name == "Updated Speaker"
        assert updated.sample_count == 5

        # Delete
        db_session.delete(updated)
        db_session.commit()

        deleted = db_session.query(SpeakerProfile).filter_by(id=profile.id).first()
        assert deleted is None, "Profile should be deleted"
    def test_processing_job_crud_operations(self, db_session):
        """Test CRUD operations on V2ProcessingJob model."""
        from src.database.models import V2ProcessingJob

        # Create a transcription result to reference; the helper also creates
        # the MediaFile and TranscriptionJob records it depends on
        transcript = self._create_test_transcript(db_session)

        # Create
        job = V2ProcessingJob(
            status="pending",
            transcript_id=transcript.id,
            job_type="enhancement",
            parameters={"model": "gpt-4", "temperature": 0.7},
            progress=0.0,
        )
        db_session.add(job)
        db_session.commit()

        assert job.id is not None, "Job should have an ID after creation"
        assert job.status == "pending"
        assert job.job_type == "enhancement"
        assert job.parameters["model"] == "gpt-4"

        # Read
        retrieved = db_session.query(V2ProcessingJob).filter_by(id=job.id).first()
        assert retrieved is not None, "Should be able to retrieve created job"
        assert retrieved.status == "pending"

        # Update
        retrieved.status = "processing"
        retrieved.progress = 0.5
        db_session.commit()

        updated = db_session.query(V2ProcessingJob).filter_by(id=job.id).first()
        assert updated.status == "processing"
        assert updated.progress == 0.5

        # Delete
        db_session.delete(updated)
        db_session.commit()

        deleted = db_session.query(V2ProcessingJob).filter_by(id=job.id).first()
        assert deleted is None, "Job should be deleted"
    def test_transcript_v2_fields(self, db_session):
        """Test that Transcript model supports v2 fields."""
        # Create transcript with v2 fields using the helper
        transcript = self._create_test_transcript(
            db_session,
            pipeline_version="v2",
            content={"text": "Test transcript", "segments": []},
            enhanced_content={"enhanced_text": "Enhanced transcript"},
            diarization_content={"speakers": [{"id": 1, "name": "Speaker 1"}]},
            merged_content={"merged_text": "Merged transcript"},
            model_used="whisper-large-v3",
            domain_used="technical",
            accuracy_estimate=0.95,
            speaker_count=2,
            quality_warnings=["low_confidence_segments"],
            processing_metadata={"enhancement_applied": True},
        )

        assert transcript.id is not None, "Transcript should have an ID after creation"
        assert transcript.pipeline_version == "v2"
        assert transcript.enhanced_content["enhanced_text"] == "Enhanced transcript"
        assert transcript.diarization_content["speakers"][0]["name"] == "Speaker 1"
        assert transcript.accuracy_estimate == 0.95
        assert transcript.speaker_count == 2
    def test_backward_compatibility(self, db_session):
        """Test backward compatibility with v1 data."""
        from src.database.models import TranscriptionResult

        # Create v1-style transcript (without v2 fields)
        v1_transcript = self._create_test_transcript(
            db_session,
            pipeline_version="v1",
            content={"text": "V1 transcript content"},
            accuracy=0.85,
        )

        # Verify v1 transcript works
        assert v1_transcript.id is not None
        assert v1_transcript.pipeline_version == "v1"
        assert v1_transcript.enhanced_content is None  # Should be None for v1
        assert v1_transcript.diarization_content is None  # Should be None for v1

        # Verify we can upgrade a v1 transcript to v2 in place
        v1_transcript.pipeline_version = "v2"
        v1_transcript.enhanced_content = {"enhanced": True}
        db_session.commit()

        updated = db_session.query(TranscriptionResult).filter_by(id=v1_transcript.id).first()
        assert updated.pipeline_version == "v2"
        assert updated.enhanced_content["enhanced"] is True
    def test_data_migration_script(self, test_db_engine):
        """Test data migration script functionality."""
        # Placeholder: the assertions depend on the actual migration script;
        # skip explicitly rather than silently pass.
        pytest.skip("Data migration script checks not implemented yet")
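
    # A minimal sketch (hypothetical test name) of what the data-migration
    # check could assert, assuming the migration backfills pipeline_version
    # for pre-existing rows -- the column is NOT NULL, so no row may be left
    # without a value after migrating:
    def test_pipeline_version_backfilled_sketch(self, test_db_engine):
        """Sketch: no transcript row should lack a pipeline_version."""
        with test_db_engine.connect() as conn:
            result = conn.execute(text(
                "SELECT COUNT(*) FROM transcription_results "
                "WHERE pipeline_version IS NULL"
            ))
            assert result.scalar() == 0, "Migration should backfill pipeline_version"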

    def test_alembic_migration_rollback(self, test_db_engine):
        """Test that Alembic migration can be rolled back."""
        # Placeholder: verifying the downgrade path depends on the actual
        # migration script; skip explicitly rather than silently pass.
        pytest.skip("Alembic rollback checks not implemented yet")
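
    # A sketch of exercising the downgrade path with Alembic's command API,
    # assuming an alembic.ini at the project root and that the v2 migration
    # is the current head. Skipped by default because this class creates its
    # schema via Base.metadata rather than Alembic:
    @pytest.mark.skip(reason="sketch; requires an Alembic-managed test database")
    def test_alembic_downgrade_upgrade_roundtrip(self, test_db_engine):
        from alembic import command
        from alembic.config import Config

        cfg = Config("alembic.ini")   # assumed location of the Alembic config
        command.downgrade(cfg, "-1")  # roll the v2 migration back one revision
        command.upgrade(cfg, "head")  # then re-apply it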

    def test_performance_impact(self, test_db_engine):
        """Test that adding v2 columns doesn't significantly impact performance."""
        with test_db_engine.connect() as conn:
            # Time a simple count query on the transcripts table; use a
            # monotonic clock rather than wall-clock datetimes
            start_time = time.perf_counter()
            result = conn.execute(text("SELECT COUNT(*) FROM transcription_results"))
            result.scalar()
            query_time = time.perf_counter() - start_time

            # Query should complete in reasonable time (< 1 second for small datasets)
            assert query_time < 1.0, f"Query took {query_time:.3f} seconds, should be under 1 second"
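
    # A sketch (hypothetical test name) of going one step further and
    # inspecting the query plan, assuming PostgreSQL (EXPLAIN output is
    # backend-specific):
    def test_count_query_plan_sketch(self, test_db_engine):
        """Sketch: inspect the plan PostgreSQL chooses for the count query."""
        with test_db_engine.connect() as conn:
            result = conn.execute(text(
                "EXPLAIN SELECT COUNT(*) FROM transcription_results"
            ))
            plan = "\n".join(row[0] for row in result.fetchall())
            # The plan should at least mention the table being scanned
            assert "transcription_results" in plan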

    def test_jsonb_field_operations(self, db_session):
        """Test JSONB field operations for v2 content."""
        from src.database.models import TranscriptionResult

        # Test complex JSONB data
        complex_content = {
            "segments": [
                {
                    "start": 0.0,
                    "end": 2.5,
                    "text": "Hello world",
                    "speaker": "speaker_1",
                    "confidence": 0.95
                },
                {
                    "start": 2.5,
                    "end": 5.0,
                    "text": "How are you?",
                    "speaker": "speaker_2",
                    "confidence": 0.88
                }
            ],
            "metadata": {
                "language": "en",
                "model_version": "whisper-large-v3",
                "processing_time": 12.5
            }
        }

        transcript = self._create_test_transcript(
            db_session,
            pipeline_version="v2",
            content=complex_content,
            enhanced_content={"enhanced_segments": complex_content["segments"]},
            diarization_content={"speakers": ["speaker_1", "speaker_2"]},
        )

        # Verify JSONB data round-trips correctly
        retrieved = db_session.query(TranscriptionResult).filter_by(id=transcript.id).first()
        assert retrieved.content["segments"][0]["text"] == "Hello world"
        assert retrieved.content["metadata"]["language"] == "en"
        assert len(retrieved.diarization_content["speakers"]) == 2
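
    # A sketch (hypothetical test name) of filtering on JSONB contents
    # server-side with SQLAlchemy's path operators, assuming the content
    # column is mapped as JSONB (the structure tests above confirm jsonb for
    # the v2 columns):
    def test_jsonb_path_query_sketch(self, db_session):
        """Sketch: query into nested JSONB rather than filtering in Python."""
        from src.database.models import TranscriptionResult

        transcript = self._create_test_transcript(
            db_session,
            content={"text": "hi", "metadata": {"language": "en"}},
        )

        # content['metadata']['language'] renders as chained -> operators;
        # .astext compares the JSON value as text
        match = (
            db_session.query(TranscriptionResult)
            .filter(TranscriptionResult.content['metadata']['language'].astext == 'en')
            .filter(TranscriptionResult.id == transcript.id)
            .first()
        )
        assert match is not None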

    def test_indexes_and_constraints(self, test_db_engine):
        """Test that proper indexes and constraints are created."""
        with test_db_engine.connect() as conn:
            # Check indexes on speaker_profiles
            result = conn.execute(text("""
                SELECT indexname, indexdef
                FROM pg_indexes
                WHERE tablename = 'speaker_profiles'
            """))
            speaker_indexes = [row[0] for row in result.fetchall()]

            # Should have primary key index
            assert any('pkey' in idx.lower() for idx in speaker_indexes), "Should have primary key index"

            # Check indexes on v2_processing_jobs
            result = conn.execute(text("""
                SELECT indexname, indexdef
                FROM pg_indexes
                WHERE tablename = 'v2_processing_jobs'
            """))
            job_indexes = [row[0] for row in result.fetchall()]

            # Should have primary key index
            assert any('pkey' in idx.lower() for idx in job_indexes), "Should have primary key index"

    def test_timestamp_auto_updating(self, db_session):
        """Test that timestamp fields auto-update correctly."""
        from src.database.models import SpeakerProfile

        # Create profile
        profile = SpeakerProfile(name="Test Speaker", user_id=1)
        db_session.add(profile)
        db_session.commit()

        original_updated_at = profile.updated_at

        # Wait long enough for the timestamp to move
        time.sleep(0.1)

        # Update profile
        profile.name = "Updated Speaker"
        db_session.commit()

        # Check that updated_at changed
        assert profile.updated_at > original_updated_at, "updated_at should be updated on modification"

    def test_null_handling_for_backward_compatibility(self, db_session):
        """Test that NULL values are handled correctly for backward compatibility."""
        from src.database.models import TranscriptionResult

        # Create transcript with minimal v1 data; all v2 fields should be
        # NULL by default
        minimal_transcript = self._create_test_transcript(
            db_session,
            pipeline_version="v1",
            content={"text": "Minimal content"},
        )

        # Verify v2 fields are NULL
        retrieved = db_session.query(TranscriptionResult).filter_by(id=minimal_transcript.id).first()
        assert retrieved.enhanced_content is None
        assert retrieved.diarization_content is None
        assert retrieved.merged_content is None
        assert retrieved.model_used is None
        assert retrieved.domain_used is None
        assert retrieved.accuracy_estimate is None
        assert retrieved.confidence_scores is None
        assert retrieved.speaker_count is None
        assert retrieved.quality_warnings is None
        assert retrieved.processing_metadata is None