"""Unit tests for the v2 repository layer (Task 6).

Exercises the repository classes for speaker profiles and processing jobs,
which provide a clean data-access layer over the new v2 schema, plus the
backward-compatibility layer that converts v2 transcripts to and from the v1
format.

The database-backed tests need a reachable database: the URL returned by
``get_database_url()`` is redirected from the ``trax`` database to
``trax_test`` before any tables are created or dropped.
"""
import json
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from src.database.models import Base, register_model
from src.database.connection import get_database_url


def _test_database_url() -> str:
    """Return the URL of the dedicated ``trax_test`` database.

    Only the *last* ``/trax`` occurrence (the database path) is rewritten, so a
    URL such as ``postgresql://trax:secret@localhost/trax`` keeps its ``trax``
    username -- ``"//trax"`` also contains ``"/trax"``, which a plain
    ``str.replace`` would clobber.

    Raises:
        RuntimeError: If the URL has no ``/trax`` segment. Falling back to the
            unmodified URL would make the fixtures below run
            ``Base.metadata.drop_all`` against a non-test database.
    """
    url = get_database_url()
    head, sep, tail = url.rpartition("/trax")
    if not sep:
        raise RuntimeError(
            "Refusing to run repository tests: the database URL has no '/trax' "
            "segment that can be redirected to the 'trax_test' database."
        )
    return f"{head}/trax_test{tail}"


@pytest.fixture(scope="class")
def test_db_engine():
    """Create the schema in the test database once per test class.

    Tables are dropped after the last test of the class, and the engine is
    disposed even if schema creation fails.
    """
    engine = create_engine(_test_database_url())
    try:
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
    finally:
        engine.dispose()


@pytest.fixture
def db_session(test_db_engine):
    """Yield a fresh session per test and roll back uncommitted work afterwards.

    NOTE: ``rollback`` only discards work that was never committed. Rows
    committed by a test (or by a repository, if it commits internally) survive
    until ``test_db_engine`` drops the schema at the end of the class, so tests
    must not assume empty tables.
    """
    session = sessionmaker(bind=test_db_engine)()
    try:
        yield session
    finally:
        session.rollback()
        session.close()


class TestSpeakerProfileRepository:
    """Test suite for SpeakerProfileRepository."""

    @pytest.fixture
    def speaker_repo(self, db_session):
        """Create SpeakerProfileRepository instance."""
        from src.repositories.speaker_profile_repository import SpeakerProfileRepository

        return SpeakerProfileRepository(db_session)

    def test_create_speaker_profile(self, speaker_repo):
        """Test creating a new speaker profile."""
        # Test data
        name = "John Doe"
        user_id = 1
        characteristics = {
            "voice_type": "male",
            "accent": "american",
            "speaking_rate": "normal",
            "pitch": "medium",
        }
        embedding = b"test_embedding_data"

        # Create profile
        profile = speaker_repo.create(
            name=name,
            user_id=user_id,
            characteristics=characteristics,
            embedding=embedding,
        )

        # Verify creation
        assert profile.id is not None
        assert profile.name == name
        assert profile.user_id == user_id
        assert profile.characteristics == characteristics
        assert profile.embedding == embedding
        assert profile.sample_count == 0
        assert profile.created_at is not None
        assert profile.updated_at is not None

    def test_create_speaker_profile_minimal(self, speaker_repo):
        """Test creating a speaker profile with minimal data."""
        profile = speaker_repo.create(name="Jane Smith", user_id=2)

        assert profile.id is not None
        assert profile.name == "Jane Smith"
        assert profile.user_id == 2
        assert profile.characteristics is None
        assert profile.embedding is None
        assert profile.sample_count == 0

    def test_get_speaker_profile_by_id(self, speaker_repo):
        """Test retrieving a speaker profile by ID."""
        # Create profile
        created_profile = speaker_repo.create(
            name="Test Speaker",
            user_id=1,
            characteristics={"voice_type": "female"},
        )

        # Retrieve by ID
        retrieved_profile = speaker_repo.get_by_id(created_profile.id)

        assert retrieved_profile is not None
        assert retrieved_profile.id == created_profile.id
        assert retrieved_profile.name == "Test Speaker"
        assert retrieved_profile.characteristics["voice_type"] == "female"

    def test_get_speaker_profile_by_id_not_found(self, speaker_repo):
        """Test retrieving a non-existent speaker profile."""
        profile = speaker_repo.get_by_id(99999)
        assert profile is None

    def test_get_speaker_profiles_by_user(self, speaker_repo):
        """Test retrieving all speaker profiles for a user."""
        user_id = 1

        # Create multiple profiles for the same user
        profile1 = speaker_repo.create(name="Speaker 1", user_id=user_id)
        profile2 = speaker_repo.create(name="Speaker 2", user_id=user_id)
        profile3 = speaker_repo.create(name="Speaker 3", user_id=2)  # Different user

        # Get profiles for user_id = 1
        user_profiles = speaker_repo.get_by_user(user_id)
        profile_ids = {p.id for p in user_profiles}

        # Other tests in this class also create profiles for user 1; if those
        # rows were committed they are still present, so check the filter and
        # membership instead of an exact row count.
        assert len(profile_ids) == len(user_profiles)  # no duplicate rows
        assert all(p.user_id == user_id for p in user_profiles)
        assert profile1.id in profile_ids
        assert profile2.id in profile_ids
        assert profile3.id not in profile_ids

    def test_update_speaker_profile(self, speaker_repo):
        """Test updating a speaker profile."""
        # Create profile
        profile = speaker_repo.create(
            name="Original Name",
            user_id=1,
            characteristics={"voice_type": "male"},
        )
        # Snapshot the timestamp before updating. If ``update`` loads the row
        # through the same session it returns the same identity-mapped object
        # as ``profile``, and comparing ``profile.updated_at`` afterwards would
        # compare the value with itself.
        original_updated_at = profile.updated_at

        # Update profile
        updated_profile = speaker_repo.update(
            profile.id,
            name="Updated Name",
            characteristics={"voice_type": "female", "accent": "british"},
            sample_count=5,
        )

        assert updated_profile is not None
        assert updated_profile.name == "Updated Name"
        assert updated_profile.characteristics["voice_type"] == "female"
        assert updated_profile.characteristics["accent"] == "british"
        assert updated_profile.sample_count == 5
        assert updated_profile.updated_at > original_updated_at

    def test_update_speaker_profile_not_found(self, speaker_repo):
        """Test updating a non-existent speaker profile."""
        result = speaker_repo.update(99999, name="New Name")
        assert result is None

    def test_delete_speaker_profile(self, speaker_repo):
        """Test deleting a speaker profile."""
        # Create profile
        profile = speaker_repo.create(name="To Delete", user_id=1)
        profile_id = profile.id

        # Delete profile
        success = speaker_repo.delete(profile_id)
        assert success is True

        # Verify deletion
        deleted_profile = speaker_repo.get_by_id(profile_id)
        assert deleted_profile is None

    def test_delete_speaker_profile_not_found(self, speaker_repo):
        """Test deleting a non-existent speaker profile."""
        success = speaker_repo.delete(99999)
        assert success is False

    @pytest.mark.skip(reason="Speaker profile relationship tests are not implemented yet")
    def test_speaker_profile_relationships(self, speaker_repo, db_session):
        """Test speaker profile relationships with other entities."""
        # This test would verify relationships with users and other entities.
        # Implementation depends on the actual relationship structure; it is
        # skipped (rather than passing vacuously) until that exists.


class TestProcessingJobRepository:
    """Test suite for ProcessingJobRepository."""

    @pytest.fixture
    def job_repo(self, db_session):
        """Create ProcessingJobRepository instance."""
        from src.repositories.processing_job_repository import ProcessingJobRepository

        return ProcessingJobRepository(db_session)

    def test_create_processing_job(self, job_repo):
        """Test creating a new processing job."""
        transcript_id = 1
        job_type = "enhancement"
        parameters = {
            "model": "gpt-4",
            "temperature": 0.7,
            "max_tokens": 1000,
        }

        # Create job
        job = job_repo.create(
            transcript_id=transcript_id,
            job_type=job_type,
            parameters=parameters,
        )

        # Verify creation
        assert job.id is not None
        assert job.transcript_id == transcript_id
        assert job.job_type == job_type
        assert job.parameters == parameters
        assert job.status == "pending"
        assert job.progress == 0
        assert job.created_at is not None
        assert job.updated_at is not None
        assert job.completed_at is None
        assert job.error_message is None
        assert job.result_data is None

    def test_create_processing_job_minimal(self, job_repo):
        """Test creating a processing job with minimal data."""
        job = job_repo.create(transcript_id=1, job_type="transcription")

        assert job.id is not None
        assert job.transcript_id == 1
        assert job.job_type == "transcription"
        assert job.parameters is None
        assert job.status == "pending"
        assert job.progress == 0

    def test_get_processing_job_by_id(self, job_repo):
        """Test retrieving a processing job by ID."""
        # Create job
        created_job = job_repo.create(
            transcript_id=1,
            job_type="enhancement",
            parameters={"model": "gpt-4"},
        )

        # Retrieve by ID
        retrieved_job = job_repo.get_by_id(created_job.id)

        assert retrieved_job is not None
        assert retrieved_job.id == created_job.id
        assert retrieved_job.transcript_id == 1
        assert retrieved_job.job_type == "enhancement"
        assert retrieved_job.parameters["model"] == "gpt-4"

    def test_get_processing_job_by_id_not_found(self, job_repo):
        """Test retrieving a non-existent processing job."""
        job = job_repo.get_by_id(99999)
        assert job is None

    def test_get_processing_jobs_by_transcript(self, job_repo):
        """Test retrieving all processing jobs for a transcript."""
        transcript_id = 1

        # Create multiple jobs for the same transcript
        job1 = job_repo.create(transcript_id=transcript_id, job_type="transcription")
        job2 = job_repo.create(transcript_id=transcript_id, job_type="enhancement")
        job3 = job_repo.create(transcript_id=2, job_type="transcription")  # Different transcript

        # Get jobs for transcript_id = 1
        transcript_jobs = job_repo.get_by_transcript(transcript_id)
        job_ids = {j.id for j in transcript_jobs}

        # Other tests in this class also create jobs for transcript 1; if those
        # rows were committed they are still present, so check the filter and
        # membership instead of an exact row count.
        assert len(job_ids) == len(transcript_jobs)  # no duplicate rows
        assert all(j.transcript_id == transcript_id for j in transcript_jobs)
        assert job1.id in job_ids
        assert job2.id in job_ids
        assert job3.id not in job_ids

    def test_update_processing_job_status(self, job_repo):
        """Test updating processing job status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")
        # Snapshot before updating: ``update_status`` may hand back the same
        # identity-mapped object as ``job``, which would make a later
        # ``job.updated_at`` comparison compare the value with itself.
        original_updated_at = job.updated_at

        # Update status
        updated_job = job_repo.update_status(
            job.id,
            status="processing",
            progress=0.5,
        )

        assert updated_job is not None
        assert updated_job.status == "processing"
        assert updated_job.progress == 0.5
        assert updated_job.updated_at > original_updated_at
        assert updated_job.completed_at is None

    def test_update_processing_job_completed(self, job_repo):
        """Test updating processing job to completed status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update to completed
        result_data = {"enhanced_text": "Enhanced content", "confidence": 0.95}
        updated_job = job_repo.update_status(
            job.id,
            status="completed",
            progress=1.0,
            result_data=result_data,
        )

        assert updated_job is not None
        assert updated_job.status == "completed"
        assert updated_job.progress == 1.0
        assert updated_job.result_data == result_data
        assert updated_job.completed_at is not None

    def test_update_processing_job_failed(self, job_repo):
        """Test updating processing job to failed status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update to failed
        error_message = "Model API rate limit exceeded"
        updated_job = job_repo.update_status(
            job.id,
            status="failed",
            error_message=error_message,
        )

        assert updated_job is not None
        assert updated_job.status == "failed"
        assert updated_job.error_message == error_message
        assert updated_job.completed_at is None

    def test_update_processing_job_not_found(self, job_repo):
        """Test updating a non-existent processing job."""
        result = job_repo.update_status(99999, status="processing")
        assert result is None

    def test_delete_processing_job(self, job_repo):
        """Test deleting a processing job."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")
        job_id = job.id

        # Delete job
        success = job_repo.delete(job_id)
        assert success is True

        # Verify deletion
        deleted_job = job_repo.get_by_id(job_id)
        assert deleted_job is None

    def test_delete_processing_job_not_found(self, job_repo):
        """Test deleting a non-existent processing job."""
        success = job_repo.delete(99999)
        assert success is False

    def test_processing_job_status_transitions(self, job_repo):
        """Test valid status transitions for processing jobs."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Test status transitions
        for status in ["pending", "processing", "completed"]:
            updated_job = job_repo.update_status(job.id, status=status)
            assert updated_job is not None
            assert updated_job.status == status

    def test_processing_job_progress_tracking(self, job_repo):
        """Test progress tracking for processing jobs."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update progress in steps
        for progress in [0.25, 0.5, 0.75, 1.0]:
            updated_job = job_repo.update_status(job.id, progress=progress)
            assert updated_job is not None
            assert updated_job.progress == progress

    @pytest.mark.skip(reason="Processing job relationship tests are not implemented yet")
    def test_processing_job_relationships(self, job_repo, db_session):
        """Test processing job relationships with transcripts."""
        # This test would verify relationships with transcripts and other
        # entities. Implementation depends on the actual relationship
        # structure; it is skipped (rather than passing vacuously) until then.


class TestBackwardCompatibility:
    """Test suite for backward compatibility layer."""

    def test_v2_to_v1_format_conversion(self, db_session):
        """Test converting v2 transcript to v1 format."""
        from src.database.models import TranscriptionResult
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Create v2 transcript with complex data
        v2_transcript = TranscriptionResult(
            pipeline_version="v2",
            content={"text": "Original transcript text"},
            enhanced_content={"enhanced_text": "Enhanced version"},
            diarization_content={
                "speakers": [
                    {"id": 1, "name": "Speaker 1", "segments": [{"start": 0, "end": 5, "text": "Hello"}]},
                    {"id": 2, "name": "Speaker 2", "segments": [{"start": 5, "end": 10, "text": "World"}]},
                ]
            },
            merged_content={"text": "Hello World", "speakers": ["Speaker 1", "Speaker 2"]},
            model_used="whisper-large-v3",
            accuracy_estimate=0.95,
        )
        db_session.add(v2_transcript)
        db_session.commit()

        # Convert to v1 format
        v1_format = TranscriptBackwardCompatibility.to_v1_format(v2_transcript)

        # Verify v1 format structure
        assert "id" in v1_format
        assert "content" in v1_format
        assert "created_at" in v1_format
        assert "updated_at" in v1_format

        # Verify content is merged appropriately
        assert "Hello World" in v1_format["content"]

    def test_v1_to_v2_update(self, db_session):
        """Test updating v2 transcript from v1 format request."""
        from src.database.models import TranscriptionResult
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Create v2 transcript
        v2_transcript = TranscriptionResult(
            pipeline_version="v2",
            content={"text": "Original content"},
            enhanced_content={"enhanced": True},
        )
        db_session.add(v2_transcript)
        db_session.commit()

        # Update from v1 request
        v1_data = {
            "title": "Updated Title",
            "content": "Updated content from v1 client",
        }

        TranscriptBackwardCompatibility.update_from_v1_request(v2_transcript, v1_data)
        db_session.commit()

        # Verify updates
        assert v2_transcript.content["text"] == "Updated content from v1 client"
        assert v2_transcript.processing_metadata["v1_update"] is True

    def test_extract_merged_content(self):
        """Test extracting plain text from merged content.

        Only static helpers are exercised here, so no database session is
        requested.
        """
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Test with text field
        merged_content_with_text = {"text": "Simple text content"}
        extracted = TranscriptBackwardCompatibility._extract_merged_content(merged_content_with_text)
        assert extracted == "Simple text content"

        # Test with segments
        merged_content_with_segments = {
            "segments": [
                {"text": "Hello", "start": 0, "end": 2},
                {"text": "World", "start": 2, "end": 4},
            ]
        }
        extracted = TranscriptBackwardCompatibility._extract_merged_content(merged_content_with_segments)
        assert extracted == "Hello World"

        # Test with empty content
        extracted = TranscriptBackwardCompatibility._extract_merged_content(None)
        assert extracted == ""

        # Test with complex structure
        complex_content = {"metadata": {"text": "Nested text"}}
        extracted = TranscriptBackwardCompatibility._extract_merged_content(complex_content)
        assert extracted == str(complex_content)