"""Unit tests for v2 repository layer (Task 6).
|
|
|
|
Tests the repository classes for speaker profiles and processing jobs
|
|
that provide clean data access layer for the new v2 schema.
|
|
"""

import pytest
from datetime import datetime, timezone
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from typing import List, Optional, Dict, Any
import json

from src.database.models import Base, register_model
from src.database.connection import get_database_url
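

# The class below is a hedged sketch of the repository surface these tests
# exercise, reconstructed from the calls made in TestSpeakerProfileRepository.
# It is for reader orientation only; the real implementation lives in
# src/repositories/speaker_profile_repository.py and may differ in signatures
# and return types.
class _SpeakerProfileRepositoryInterface:
    """Illustrative interface sketch; not imported or used by the tests."""

    def create(
        self,
        name: str,
        user_id: int,
        characteristics: Optional[Dict[str, Any]] = None,
        embedding: Optional[bytes] = None,
    ) -> Any: ...

    def get_by_id(self, profile_id: int) -> Optional[Any]: ...

    def get_by_user(self, user_id: int) -> List[Any]: ...

    def update(self, profile_id: int, **fields: Any) -> Optional[Any]: ...

    def delete(self, profile_id: int) -> bool: ...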


class TestSpeakerProfileRepository:
    """Test suite for SpeakerProfileRepository."""

    @pytest.fixture(scope="class")
    def test_db_engine(self):
        """Create test database engine."""
        test_db_url = get_database_url().replace("/trax", "/trax_test")
        engine = create_engine(test_db_url)
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Create database session for tests."""
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()

    @pytest.fixture
    def speaker_repo(self, db_session):
        """Create SpeakerProfileRepository instance."""
        from src.repositories.speaker_profile_repository import SpeakerProfileRepository
        return SpeakerProfileRepository(db_session)

    def test_create_speaker_profile(self, speaker_repo):
        """Test creating a new speaker profile."""
        # Test data
        name = "John Doe"
        user_id = 1
        characteristics = {
            "voice_type": "male",
            "accent": "american",
            "speaking_rate": "normal",
            "pitch": "medium"
        }
        embedding = b"test_embedding_data"

        # Create profile
        profile = speaker_repo.create(
            name=name,
            user_id=user_id,
            characteristics=characteristics,
            embedding=embedding
        )

        # Verify creation
        assert profile.id is not None
        assert profile.name == name
        assert profile.user_id == user_id
        assert profile.characteristics == characteristics
        assert profile.embedding == embedding
        assert profile.sample_count == 0
        assert profile.created_at is not None
        assert profile.updated_at is not None

    def test_create_speaker_profile_minimal(self, speaker_repo):
        """Test creating a speaker profile with minimal data."""
        profile = speaker_repo.create(name="Jane Smith", user_id=2)

        assert profile.id is not None
        assert profile.name == "Jane Smith"
        assert profile.user_id == 2
        assert profile.characteristics is None
        assert profile.embedding is None
        assert profile.sample_count == 0

    def test_get_speaker_profile_by_id(self, speaker_repo):
        """Test retrieving a speaker profile by ID."""
        # Create profile
        created_profile = speaker_repo.create(
            name="Test Speaker",
            user_id=1,
            characteristics={"voice_type": "female"}
        )

        # Retrieve by ID
        retrieved_profile = speaker_repo.get_by_id(created_profile.id)

        assert retrieved_profile is not None
        assert retrieved_profile.id == created_profile.id
        assert retrieved_profile.name == "Test Speaker"
        assert retrieved_profile.characteristics["voice_type"] == "female"

    def test_get_speaker_profile_by_id_not_found(self, speaker_repo):
        """Test retrieving a non-existent speaker profile."""
        profile = speaker_repo.get_by_id(99999)
        assert profile is None

    def test_get_speaker_profiles_by_user(self, speaker_repo):
        """Test retrieving all speaker profiles for a user."""
        user_id = 1

        # Create multiple profiles for the same user
        profile1 = speaker_repo.create(name="Speaker 1", user_id=user_id)
        profile2 = speaker_repo.create(name="Speaker 2", user_id=user_id)
        profile3 = speaker_repo.create(name="Speaker 3", user_id=2)  # Different user

        # Get profiles for user_id = 1
        user_profiles = speaker_repo.get_by_user(user_id)

        assert len(user_profiles) == 2
        profile_ids = [p.id for p in user_profiles]
        assert profile1.id in profile_ids
        assert profile2.id in profile_ids
        assert profile3.id not in profile_ids

    def test_update_speaker_profile(self, speaker_repo):
        """Test updating a speaker profile."""
        # Create profile
        profile = speaker_repo.create(
            name="Original Name",
            user_id=1,
            characteristics={"voice_type": "male"}
        )

        # Update profile
        updated_profile = speaker_repo.update(
            profile.id,
            name="Updated Name",
            characteristics={"voice_type": "female", "accent": "british"},
            sample_count=5
        )

        assert updated_profile is not None
        assert updated_profile.name == "Updated Name"
        assert updated_profile.characteristics["voice_type"] == "female"
        assert updated_profile.characteristics["accent"] == "british"
        assert updated_profile.sample_count == 5
        assert updated_profile.updated_at > profile.updated_at

    def test_update_speaker_profile_not_found(self, speaker_repo):
        """Test updating a non-existent speaker profile."""
        result = speaker_repo.update(99999, name="New Name")
        assert result is None

    def test_delete_speaker_profile(self, speaker_repo):
        """Test deleting a speaker profile."""
        # Create profile
        profile = speaker_repo.create(name="To Delete", user_id=1)
        profile_id = profile.id

        # Delete profile
        success = speaker_repo.delete(profile_id)

        assert success is True

        # Verify deletion
        deleted_profile = speaker_repo.get_by_id(profile_id)
        assert deleted_profile is None

    def test_delete_speaker_profile_not_found(self, speaker_repo):
        """Test deleting a non-existent speaker profile."""
        success = speaker_repo.delete(99999)
        assert success is False

    def test_speaker_profile_relationships(self, speaker_repo, db_session):
        """Test speaker profile relationships with other entities."""
        # This test would verify relationships with users and other entities;
        # the implementation depends on the actual relationship structure.
        pytest.skip("Speaker profile relationship structure not implemented yet")
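

# As above, a hedged sketch of the processing-job repository surface implied by
# TestProcessingJobRepository; the real implementation lives in
# src/repositories/processing_job_repository.py and may differ in details.
class _ProcessingJobRepositoryInterface:
    """Illustrative interface sketch; not imported or used by the tests."""

    def create(
        self,
        transcript_id: int,
        job_type: str,
        parameters: Optional[Dict[str, Any]] = None,
    ) -> Any: ...

    def get_by_id(self, job_id: int) -> Optional[Any]: ...

    def get_by_transcript(self, transcript_id: int) -> List[Any]: ...

    def update_status(
        self,
        job_id: int,
        status: Optional[str] = None,
        progress: Optional[float] = None,
        result_data: Optional[Dict[str, Any]] = None,
        error_message: Optional[str] = None,
    ) -> Optional[Any]: ...

    def delete(self, job_id: int) -> bool: ...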


class TestProcessingJobRepository:
    """Test suite for ProcessingJobRepository."""

    @pytest.fixture(scope="class")
    def test_db_engine(self):
        """Create test database engine."""
        test_db_url = get_database_url().replace("/trax", "/trax_test")
        engine = create_engine(test_db_url)
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Create database session for tests."""
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()

    @pytest.fixture
    def job_repo(self, db_session):
        """Create ProcessingJobRepository instance."""
        from src.repositories.processing_job_repository import ProcessingJobRepository
        return ProcessingJobRepository(db_session)

    def test_create_processing_job(self, job_repo):
        """Test creating a new processing job."""
        transcript_id = 1
        job_type = "enhancement"
        parameters = {
            "model": "gpt-4",
            "temperature": 0.7,
            "max_tokens": 1000
        }

        # Create job
        job = job_repo.create(
            transcript_id=transcript_id,
            job_type=job_type,
            parameters=parameters
        )

        # Verify creation
        assert job.id is not None
        assert job.transcript_id == transcript_id
        assert job.job_type == job_type
        assert job.parameters == parameters
        assert job.status == "pending"
        assert job.progress == 0
        assert job.created_at is not None
        assert job.updated_at is not None
        assert job.completed_at is None
        assert job.error_message is None
        assert job.result_data is None

    def test_create_processing_job_minimal(self, job_repo):
        """Test creating a processing job with minimal data."""
        job = job_repo.create(transcript_id=1, job_type="transcription")

        assert job.id is not None
        assert job.transcript_id == 1
        assert job.job_type == "transcription"
        assert job.parameters is None
        assert job.status == "pending"
        assert job.progress == 0

    def test_get_processing_job_by_id(self, job_repo):
        """Test retrieving a processing job by ID."""
        # Create job
        created_job = job_repo.create(
            transcript_id=1,
            job_type="enhancement",
            parameters={"model": "gpt-4"}
        )

        # Retrieve by ID
        retrieved_job = job_repo.get_by_id(created_job.id)

        assert retrieved_job is not None
        assert retrieved_job.id == created_job.id
        assert retrieved_job.transcript_id == 1
        assert retrieved_job.job_type == "enhancement"
        assert retrieved_job.parameters["model"] == "gpt-4"

    def test_get_processing_job_by_id_not_found(self, job_repo):
        """Test retrieving a non-existent processing job."""
        job = job_repo.get_by_id(99999)
        assert job is None

    def test_get_processing_jobs_by_transcript(self, job_repo):
        """Test retrieving all processing jobs for a transcript."""
        transcript_id = 1

        # Create multiple jobs for the same transcript
        job1 = job_repo.create(transcript_id=transcript_id, job_type="transcription")
        job2 = job_repo.create(transcript_id=transcript_id, job_type="enhancement")
        job3 = job_repo.create(transcript_id=2, job_type="transcription")  # Different transcript

        # Get jobs for transcript_id = 1
        transcript_jobs = job_repo.get_by_transcript(transcript_id)

        assert len(transcript_jobs) == 2
        job_ids = [j.id for j in transcript_jobs]
        assert job1.id in job_ids
        assert job2.id in job_ids
        assert job3.id not in job_ids

    def test_update_processing_job_status(self, job_repo):
        """Test updating processing job status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update status
        updated_job = job_repo.update_status(
            job.id,
            status="processing",
            progress=0.5
        )

        assert updated_job is not None
        assert updated_job.status == "processing"
        assert updated_job.progress == 0.5
        assert updated_job.updated_at > job.updated_at
        assert updated_job.completed_at is None

    def test_update_processing_job_completed(self, job_repo):
        """Test updating processing job to completed status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update to completed
        result_data = {"enhanced_text": "Enhanced content", "confidence": 0.95}
        updated_job = job_repo.update_status(
            job.id,
            status="completed",
            progress=1.0,
            result_data=result_data
        )

        assert updated_job is not None
        assert updated_job.status == "completed"
        assert updated_job.progress == 1.0
        assert updated_job.result_data == result_data
        assert updated_job.completed_at is not None

    def test_update_processing_job_failed(self, job_repo):
        """Test updating processing job to failed status."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update to failed
        error_message = "Model API rate limit exceeded"
        updated_job = job_repo.update_status(
            job.id,
            status="failed",
            error_message=error_message
        )

        assert updated_job is not None
        assert updated_job.status == "failed"
        assert updated_job.error_message == error_message
        assert updated_job.completed_at is None

    def test_update_processing_job_not_found(self, job_repo):
        """Test updating a non-existent processing job."""
        result = job_repo.update_status(99999, status="processing")
        assert result is None

    def test_delete_processing_job(self, job_repo):
        """Test deleting a processing job."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")
        job_id = job.id

        # Delete job
        success = job_repo.delete(job_id)

        assert success is True

        # Verify deletion
        deleted_job = job_repo.get_by_id(job_id)
        assert deleted_job is None

    def test_delete_processing_job_not_found(self, job_repo):
        """Test deleting a non-existent processing job."""
        success = job_repo.delete(99999)
        assert success is False

    def test_processing_job_status_transitions(self, job_repo):
        """Test valid status transitions for processing jobs."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Test status transitions
        statuses = ["pending", "processing", "completed"]

        for status in statuses:
            updated_job = job_repo.update_status(job.id, status=status)
            assert updated_job.status == status

    def test_processing_job_progress_tracking(self, job_repo):
        """Test progress tracking for processing jobs."""
        # Create job
        job = job_repo.create(transcript_id=1, job_type="enhancement")

        # Update progress in steps
        progress_steps = [0.25, 0.5, 0.75, 1.0]

        for progress in progress_steps:
            updated_job = job_repo.update_status(job.id, progress=progress)
            assert updated_job.progress == progress

    def test_processing_job_relationships(self, job_repo, db_session):
        """Test processing job relationships with transcripts."""
        # This test would verify relationships with transcripts and other entities;
        # the implementation depends on the actual relationship structure.
        pytest.skip("Processing job relationship structure not implemented yet")
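

# Hedged reference model of the extraction behaviour asserted in
# TestBackwardCompatibility.test_extract_merged_content below. The function
# name is invented here for illustration; the real logic lives in
# TranscriptBackwardCompatibility._extract_merged_content and may handle more
# cases than the tests cover.
def _reference_extract_merged_content(merged_content: Optional[Dict[str, Any]]) -> str:
    """Return plain text from merged content, mirroring the asserted behaviour."""
    if not merged_content:
        # Empty or missing content collapses to an empty string.
        return ""
    if "text" in merged_content:
        # A top-level "text" field is returned as-is.
        return merged_content["text"]
    if "segments" in merged_content:
        # Segment texts are joined in order with single spaces.
        return " ".join(segment["text"] for segment in merged_content["segments"])
    # Anything else falls back to the dict's string representation.
    return str(merged_content)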


class TestBackwardCompatibility:
    """Test suite for backward compatibility layer."""

    @pytest.fixture(scope="class")
    def test_db_engine(self):
        """Create test database engine."""
        test_db_url = get_database_url().replace("/trax", "/trax_test")
        engine = create_engine(test_db_url)
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
        engine.dispose()

    @pytest.fixture
    def db_session(self, test_db_engine):
        """Create database session for tests."""
        Session = sessionmaker(bind=test_db_engine)
        session = Session()
        yield session
        session.rollback()
        session.close()

    def test_v2_to_v1_format_conversion(self, db_session):
        """Test converting v2 transcript to v1 format."""
        from src.database.models import TranscriptionResult
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Create v2 transcript with complex data
        v2_transcript = TranscriptionResult(
            pipeline_version="v2",
            content={"text": "Original transcript text"},
            enhanced_content={"enhanced_text": "Enhanced version"},
            diarization_content={
                "speakers": [
                    {"id": 1, "name": "Speaker 1", "segments": [{"start": 0, "end": 5, "text": "Hello"}]},
                    {"id": 2, "name": "Speaker 2", "segments": [{"start": 5, "end": 10, "text": "World"}]}
                ]
            },
            merged_content={"text": "Hello World", "speakers": ["Speaker 1", "Speaker 2"]},
            model_used="whisper-large-v3",
            accuracy_estimate=0.95
        )
        db_session.add(v2_transcript)
        db_session.commit()

        # Convert to v1 format
        v1_format = TranscriptBackwardCompatibility.to_v1_format(v2_transcript)

        # Verify v1 format structure
        assert "id" in v1_format
        assert "content" in v1_format
        assert "created_at" in v1_format
        assert "updated_at" in v1_format

        # Verify content is merged appropriately
        assert "Hello World" in v1_format["content"]

    def test_v1_to_v2_update(self, db_session):
        """Test updating v2 transcript from v1 format request."""
        from src.database.models import TranscriptionResult
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Create v2 transcript
        v2_transcript = TranscriptionResult(
            pipeline_version="v2",
            content={"text": "Original content"},
            enhanced_content={"enhanced": True}
        )
        db_session.add(v2_transcript)
        db_session.commit()

        # Update from v1 request
        v1_data = {
            "title": "Updated Title",
            "content": "Updated content from v1 client"
        }

        TranscriptBackwardCompatibility.update_from_v1_request(v2_transcript, v1_data)
        db_session.commit()

        # Verify updates
        assert v2_transcript.content["text"] == "Updated content from v1 client"
        assert v2_transcript.processing_metadata["v1_update"] is True

    def test_extract_merged_content(self, db_session):
        """Test extracting plain text from merged content."""
        from src.compatibility.backward_compatibility import TranscriptBackwardCompatibility

        # Test with text field
        merged_content_with_text = {"text": "Simple text content"}
        extracted = TranscriptBackwardCompatibility._extract_merged_content(merged_content_with_text)
        assert extracted == "Simple text content"

        # Test with segments
        merged_content_with_segments = {
            "segments": [
                {"text": "Hello", "start": 0, "end": 2},
                {"text": "World", "start": 2, "end": 4}
            ]
        }
        extracted = TranscriptBackwardCompatibility._extract_merged_content(merged_content_with_segments)
        assert extracted == "Hello World"

        # Test with empty content
        extracted = TranscriptBackwardCompatibility._extract_merged_content(None)
        assert extracted == ""

        # Test with complex structure
        complex_content = {"metadata": {"text": "Nested text"}}
        extracted = TranscriptBackwardCompatibility._extract_merged_content(complex_content)
        assert extracted == str(complex_content)
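

# For orientation: based only on the assertions in
# test_v2_to_v1_format_conversion above, to_v1_format() is expected to return a
# mapping shaped roughly like the following (values are illustrative, not taken
# from the real implementation):
#
#     {
#         "id": 1,
#         "content": "Hello World",
#         "created_at": "2024-01-01T00:00:00+00:00",
#         "updated_at": "2024-01-01T00:00:00+00:00",
#     }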