trax/tests/test_v2_repositories.py

517 lines
19 KiB
Python

"""Unit tests for v2 repository layer (Task 6).
Tests the repository classes for speaker profiles and processing jobs,
which provide a clean data access layer for the new v2 schema, and the
backward-compatibility helpers that convert between v1 and v2 formats.
"""
import pytest
from datetime import datetime, timezone
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from typing import List, Optional, Dict, Any
import json
from src.database.models import Base, register_model
from src.database.connection import get_database_url
class TestSpeakerProfileRepository:
    """Test suite for SpeakerProfileRepository.

    Exercises ``create``, ``get_by_id``, ``get_by_user``, ``update`` and
    ``delete`` through a session bound to a dedicated test database.
    """

    @pytest.fixture(scope="class")
    def test_db_engine(self):
        """Create the test database engine and schema for this class.

        Yields:
            An engine pointing at the ``trax_test`` variant of the configured
            database URL. Every table in ``Base.metadata`` is created before
            the first test and dropped again after the last one.
        """
        base_url = get_database_url()
        # Rewrite only the last "/trax" (presumably the database name).
        # A plain str.replace would also rewrite earlier matches, e.g. a
        # "trax" username sitting directly after "://".
        head, sep, tail = base_url.rpartition("/trax")
        test_db_url = (head + "/trax_test" + tail) if sep else base_url
        engine = create_engine(test_db_url)
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Provide a session for a single test and roll it back afterwards.

        NOTE(review): the rollback only discards uncommitted work. If the
        repository commits internally, rows persist across the tests of this
        class until the schema is dropped -- confirm against the repository.
        """
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()

    @pytest.fixture
    def speaker_repo(self, db_session):
        """Create a SpeakerProfileRepository bound to the test session."""
        # Imported lazily so collecting this module does not require the
        # repository package to be importable.
        from src.repositories.speaker_profile_repository import SpeakerProfileRepository
        return SpeakerProfileRepository(db_session)

    def test_create_speaker_profile(self, speaker_repo):
        """Test creating a new speaker profile with every optional field."""
        # Test data
        name = "John Doe"
        user_id = 1
        characteristics = {
            "voice_type": "male",
            "accent": "american",
            "speaking_rate": "normal",
            "pitch": "medium"
        }
        embedding = b"test_embedding_data"

        # Create profile
        profile = speaker_repo.create(
            name=name,
            user_id=user_id,
            characteristics=characteristics,
            embedding=embedding
        )

        # Verify creation
        assert profile.id is not None
        assert profile.name == name
        assert profile.user_id == user_id
        assert profile.characteristics == characteristics
        assert profile.embedding == embedding
        assert profile.sample_count == 0
        assert profile.created_at is not None
        assert profile.updated_at is not None

    def test_create_speaker_profile_minimal(self, speaker_repo):
        """Test creating a speaker profile with only name and user_id."""
        profile = speaker_repo.create(name="Jane Smith", user_id=2)

        assert profile.id is not None
        assert profile.name == "Jane Smith"
        assert profile.user_id == 2
        assert profile.characteristics is None
        assert profile.embedding is None
        assert profile.sample_count == 0

    def test_get_speaker_profile_by_id(self, speaker_repo):
        """Test retrieving a speaker profile by ID."""
        # Create profile
        created_profile = speaker_repo.create(
            name="Test Speaker",
            user_id=1,
            characteristics={"voice_type": "female"}
        )

        # Retrieve by ID
        retrieved_profile = speaker_repo.get_by_id(created_profile.id)

        assert retrieved_profile is not None
        assert retrieved_profile.id == created_profile.id
        assert retrieved_profile.name == "Test Speaker"
        assert retrieved_profile.characteristics["voice_type"] == "female"

    def test_get_speaker_profile_by_id_not_found(self, speaker_repo):
        """Test that looking up an unknown ID returns None."""
        profile = speaker_repo.get_by_id(99999)
        assert profile is None

    def test_get_speaker_profiles_by_user(self, speaker_repo):
        """Test retrieving all speaker profiles for a user."""
        user_id = 1

        # Create multiple profiles for the same user
        profile1 = speaker_repo.create(name="Speaker 1", user_id=user_id)
        profile2 = speaker_repo.create(name="Speaker 2", user_id=user_id)
        profile3 = speaker_repo.create(name="Speaker 3", user_id=2)  # Different user

        # Get profiles for user_id = 1
        user_profiles = speaker_repo.get_by_user(user_id)

        assert len(user_profiles) == 2
        profile_ids = [p.id for p in user_profiles]
        assert profile1.id in profile_ids
        assert profile2.id in profile_ids
        assert profile3.id not in profile_ids

    def test_update_speaker_profile(self, speaker_repo):
        """Test updating a speaker profile and bumping its updated_at."""
        # Create profile
        profile = speaker_repo.create(
            name="Original Name",
            user_id=1,
            characteristics={"voice_type": "male"}
        )
        # Snapshot the timestamp before updating: the repository may return
        # the very same ORM instance, in which case reading
        # profile.updated_at afterwards would compare the value with itself.
        original_updated_at = profile.updated_at

        # Update profile
        updated_profile = speaker_repo.update(
            profile.id,
            name="Updated Name",
            characteristics={"voice_type": "female", "accent": "british"},
            sample_count=5
        )

        assert updated_profile is not None
        assert updated_profile.name == "Updated Name"
        assert updated_profile.characteristics["voice_type"] == "female"
        assert updated_profile.characteristics["accent"] == "british"
        assert updated_profile.sample_count == 5
        assert updated_profile.updated_at > original_updated_at

    def test_update_speaker_profile_not_found(self, speaker_repo):
        """Test that updating an unknown ID returns None."""
        result = speaker_repo.update(99999, name="New Name")
        assert result is None

    def test_delete_speaker_profile(self, speaker_repo):
        """Test deleting a speaker profile."""
        # Create profile
        profile = speaker_repo.create(name="To Delete", user_id=1)
        profile_id = profile.id

        # Delete profile
        success = speaker_repo.delete(profile_id)
        assert success is True

        # Verify deletion
        deleted_profile = speaker_repo.get_by_id(profile_id)
        assert deleted_profile is None

    def test_delete_speaker_profile_not_found(self, speaker_repo):
        """Test that deleting an unknown ID reports failure."""
        success = speaker_repo.delete(99999)
        assert success is False

    # Reported as skipped instead of silently passing, so the missing
    # coverage stays visible in test reports.
    @pytest.mark.skip(reason="Relationship tests not implemented yet")
    def test_speaker_profile_relationships(self, speaker_repo, db_session):
        """Test speaker profile relationships with other entities."""
        # This test would verify relationships with users and other entities
        # Implementation depends on the actual relationship structure
class TestProcessingJobRepository:
    """Test suite for ProcessingJobRepository.

    Exercises ``create``, ``get_by_id``, ``get_by_transcript``,
    ``update_status`` and ``delete`` through a session bound to a dedicated
    test database.
    """

    @pytest.fixture(scope="class")
    def test_db_engine(self):
        """Create the test database engine and schema for this class.

        Yields:
            An engine pointing at the ``trax_test`` variant of the configured
            database URL. Every table in ``Base.metadata`` is created before
            the first test and dropped again after the last one.
        """
        base_url = get_database_url()
        # Rewrite only the last "/trax" (presumably the database name).
        # A plain str.replace would also rewrite earlier matches, e.g. a
        # "trax" username sitting directly after "://".
        head, sep, tail = base_url.rpartition("/trax")
        test_db_url = (head + "/trax_test" + tail) if sep else base_url
        engine = create_engine(test_db_url)
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Provide a session for a single test and roll it back afterwards.

        NOTE(review): the rollback only discards uncommitted work. If the
        repository commits internally, rows persist across the tests of this
        class until the schema is dropped -- confirm against the repository.
        """
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()

    @pytest.fixture
    def job_repo(self, db_session):
        """Create a ProcessingJobRepository bound to the test session."""
        # Imported lazily so collecting this module does not require the
        # repository package to be importable.
        from src.repositories.processing_job_repository import ProcessingJobRepository
        return ProcessingJobRepository(db_session)

    def test_create_processing_job(self, job_repo):
        """Test creating a new processing job with parameters."""
        transcript_id = 1
        job_type = "enhancement"
        parameters = {
            "model": "gpt-4",
            "temperature": 0.7,
            "max_tokens": 1000
        }

        # Create job
        job = job_repo.create(
            transcript_id=transcript_id,
            job_type=job_type,
            parameters=parameters
        )

        # Verify creation
        assert job.id is not None
        assert job.transcript_id == transcript_id
        assert job.job_type == job_type
        assert job.parameters == parameters
        assert job.status == "pending"
        assert job.progress == 0
        assert job.created_at is not None
        assert job.updated_at is not None
        assert job.completed_at is None
        assert job.error_message is None
        assert job.result_data is None

    def test_create_processing_job_minimal(self, job_repo):
        """Test creating a processing job with only the required fields."""
        job = job_repo.create(transcript_id=1, job_type="transcription")

        assert job.id is not None
        assert job.transcript_id == 1
        assert job.job_type == "transcription"
        assert job.parameters is None
        assert job.status == "pending"
        assert job.progress == 0

    def test_get_processing_job_by_id(self, job_repo):
        """Test retrieving a processing job by ID."""
        # Create job
        created_job = job_repo.create(
            transcript_id=1,
            job_type="enhancement",
            parameters={"model": "gpt-4"}
        )

        # Retrieve by ID
        retrieved_job = job_repo.get_by_id(created_job.id)

        assert retrieved_job is not None
        assert retrieved_job.id == created_job.id
        assert retrieved_job.transcript_id == 1
        assert retrieved_job.job_type == "enhancement"
        assert retrieved_job.parameters["model"] == "gpt-4"

    def test_get_processing_job_by_id_not_found(self, job_repo):
        """Test that looking up an unknown ID returns None."""
        job = job_repo.get_by_id(99999)
        assert job is None

    def test_get_processing_jobs_by_transcript(self, job_repo):
        """Test retrieving all processing jobs for a transcript."""
        transcript_id = 1

        # Create multiple jobs for the same transcript
        job1 = job_repo.create(transcript_id=transcript_id, job_type="transcription")
        job2 = job_repo.create(transcript_id=transcript_id, job_type="enhancement")
        job3 = job_repo.create(transcript_id=2, job_type="transcription")  # Different transcript

        # Get jobs for transcript_id = 1
        transcript_jobs = job_repo.get_by_transcript(transcript_id)

        assert len(transcript_jobs) == 2
        job_ids = [j.id for j in transcript_jobs]
        assert job1.id in job_ids
        assert job2.id in job_ids
        assert job3.id not in job_ids

    def test_update_processing_job_status(self, job_repo):
        """Test updating processing job status and progress."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")
        # Snapshot the timestamp before updating: the repository may return
        # the very same ORM instance, in which case reading job.updated_at
        # afterwards would compare the value with itself.
        original_updated_at = job.updated_at

        # Update status
        updated_job = job_repo.update_status(
            job.id,
            status="processing",
            progress=0.5
        )

        assert updated_job is not None
        assert updated_job.status == "processing"
        assert updated_job.progress == 0.5
        assert updated_job.updated_at > original_updated_at
        assert updated_job.completed_at is None

    def test_update_processing_job_completed(self, job_repo):
        """Test updating processing job to completed status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update to completed
        result_data = {"enhanced_text": "Enhanced content", "confidence": 0.95}
        updated_job = job_repo.update_status(
            job.id,
            status="completed",
            progress=1.0,
            result_data=result_data
        )

        assert updated_job is not None
        assert updated_job.status == "completed"
        assert updated_job.progress == 1.0
        assert updated_job.result_data == result_data
        assert updated_job.completed_at is not None

    def test_update_processing_job_failed(self, job_repo):
        """Test updating processing job to failed status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update to failed
        error_message = "Model API rate limit exceeded"
        updated_job = job_repo.update_status(
            job.id,
            status="failed",
            error_message=error_message
        )

        assert updated_job is not None
        assert updated_job.status == "failed"
        assert updated_job.error_message == error_message
        assert updated_job.completed_at is None

    def test_update_processing_job_not_found(self, job_repo):
        """Test that updating an unknown ID returns None."""
        result = job_repo.update_status(99999, status="processing")
        assert result is None

    def test_delete_processing_job(self, job_repo):
        """Test deleting a processing job."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")
        job_id = job.id

        # Delete job
        success = job_repo.delete(job_id)
        assert success is True

        # Verify deletion
        deleted_job = job_repo.get_by_id(job_id)
        assert deleted_job is None

    def test_delete_processing_job_not_found(self, job_repo):
        """Test that deleting an unknown ID reports failure."""
        success = job_repo.delete(99999)
        assert success is False

    def test_processing_job_status_transitions(self, job_repo):
        """Test stepping a job through pending, processing and completed."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Test status transitions
        statuses = ["pending", "processing", "completed"]
        for status in statuses:
            updated_job = job_repo.update_status(job.id, status=status)
            assert updated_job.status == status

    def test_processing_job_progress_tracking(self, job_repo):
        """Test progress tracking for processing jobs."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update progress in steps
        progress_steps = [0.25, 0.5, 0.75, 1.0]
        for progress in progress_steps:
            updated_job = job_repo.update_status(job.id, progress=progress)
            assert updated_job.progress == progress

    # Reported as skipped instead of silently passing, so the missing
    # coverage stays visible in test reports.
    @pytest.mark.skip(reason="Relationship tests not implemented yet")
    def test_processing_job_relationships(self, job_repo, db_session):
        """Test processing job relationships with transcripts."""
        # This test would verify relationships with transcripts and other entities
        # Implementation depends on the actual relationship structure
class TestBackwardCompatibility:
    """Test suite for backward compatibility layer.

    Covers ``TranscriptBackwardCompatibility``: converting a v2 transcript to
    the v1 format, applying a v1-style update to a v2 transcript, and
    extracting plain text from merged content.
    """

    @pytest.fixture(scope="class")
    def test_db_engine(self):
        """Create the test database engine and schema for this class.

        Yields:
            An engine pointing at the ``trax_test`` variant of the configured
            database URL. Every table in ``Base.metadata`` is created before
            the first test and dropped again after the last one.
        """
        base_url = get_database_url()
        # Rewrite only the last "/trax" (presumably the database name).
        # A plain str.replace would also rewrite earlier matches, e.g. a
        # "trax" username sitting directly after "://".
        head, sep, tail = base_url.rpartition("/trax")
        test_db_url = (head + "/trax_test" + tail) if sep else base_url
        engine = create_engine(test_db_url)
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Provide a session for a single test and roll it back afterwards.

        The rollback only discards uncommitted work; tests in this class that
        call ``db_session.commit()`` leave their rows in place until the
        class-scoped engine fixture drops the schema.
        """
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()

    def test_v2_to_v1_format_conversion(self, db_session):
        """Test converting v2 transcript to v1 format."""
        from src.database.models import TranscriptionResult
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Create v2 transcript with complex data
        v2_transcript = TranscriptionResult(
            pipeline_version="v2",
            content={"text": "Original transcript text"},
            enhanced_content={"enhanced_text": "Enhanced version"},
            diarization_content={
                "speakers": [
                    {"id": 1, "name": "Speaker 1", "segments": [{"start": 0, "end": 5, "text": "Hello"}]},
                    {"id": 2, "name": "Speaker 2", "segments": [{"start": 5, "end": 10, "text": "World"}]}
                ]
            },
            merged_content={"text": "Hello World", "speakers": ["Speaker 1", "Speaker 2"]},
            model_used="whisper-large-v3",
            accuracy_estimate=0.95
        )
        db_session.add(v2_transcript)
        db_session.commit()

        # Convert to v1 format
        v1_format = TranscriptBackwardCompatibility.to_v1_format(v2_transcript)

        # Verify v1 format structure
        assert "id" in v1_format
        assert "content" in v1_format
        assert "created_at" in v1_format
        assert "updated_at" in v1_format

        # Verify content is merged appropriately
        assert "Hello World" in v1_format["content"]

    def test_v1_to_v2_update(self, db_session):
        """Test updating v2 transcript from v1 format request."""
        from src.database.models import TranscriptionResult
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Create v2 transcript
        v2_transcript = TranscriptionResult(
            pipeline_version="v2",
            content={"text": "Original content"},
            enhanced_content={"enhanced": True}
        )
        db_session.add(v2_transcript)
        db_session.commit()

        # Update from v1 request
        v1_data = {
            "title": "Updated Title",
            "content": "Updated content from v1 client"
        }
        TranscriptBackwardCompatibility.update_from_v1_request(v2_transcript, v1_data)
        db_session.commit()

        # Verify updates
        assert v2_transcript.content["text"] == "Updated content from v1 client"
        assert v2_transcript.processing_metadata["v1_update"] is True

    def test_extract_merged_content(self, db_session):
        """Test extracting plain text from merged content.

        Covers four input shapes: a ``text`` field, a list of ``segments``,
        ``None``, and a dict with neither key.
        """
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Test with text field
        merged_content_with_text = {"text": "Simple text content"}
        extracted = TranscriptBackwardCompatibility._extract_merged_content(merged_content_with_text)
        assert extracted == "Simple text content"

        # Test with segments
        merged_content_with_segments = {
            "segments": [
                {"text": "Hello", "start": 0, "end": 2},
                {"text": "World", "start": 2, "end": 4}
            ]
        }
        extracted = TranscriptBackwardCompatibility._extract_merged_content(merged_content_with_segments)
        assert extracted == "Hello World"

        # Test with empty content
        extracted = TranscriptBackwardCompatibility._extract_merged_content(None)
        assert extracted == ""

        # Test with complex structure: expected to fall back to str() of the dict
        complex_content = {"metadata": {"text": "Nested text"}}
        extracted = TranscriptBackwardCompatibility._extract_merged_content(complex_content)
        assert extracted == str(complex_content)