trax/tests/test_service_integration.py

"""Integration tests for service interactions using mock implementations."""
from pathlib import Path

import pytest

from src.services.mocks import (
    create_mock_service_container,
    create_mock_youtube_service,
    create_mock_media_service,
    create_mock_transcription_service,
)
from src.services.protocols import (
    YouTubeServiceProtocol,
    MediaServiceProtocol,
    TranscriptionServiceProtocol,
    ExportFormat,
)


class TestServiceIntegration:
    """Test service interactions and workflows."""

    @pytest.fixture
    def mock_services(self):
        """Create mock service container for testing."""
        return create_mock_service_container()

    @pytest.mark.asyncio
    async def test_youtube_to_transcription_workflow(self, mock_services):
        """Test complete workflow from YouTube URL to transcription."""
        youtube_service = mock_services["youtube_service"]
        media_service = mock_services["media_service"]
        transcription_service = mock_services["transcription_service"]

        # Extract YouTube metadata
        url = "https://youtube.com/watch?v=mock123"
        metadata = await youtube_service.extract_metadata(url)
        assert metadata["title"] == "Mock YouTube Video"
        assert metadata["duration"] == 120

        # Process media pipeline
        output_dir = Path("/tmp/test_output")
        media_file = await media_service.process_media_pipeline(url, output_dir)
        assert media_file is not None
        assert hasattr(media_file, 'id')

        # Transcribe the media file
        transcription_result = await transcription_service.transcribe_file(media_file)
        assert transcription_result.raw_content is not None
        assert transcription_result.word_count > 0
        assert transcription_result.accuracy_estimate > 0.8

    @pytest.mark.asyncio
    async def test_batch_processing_workflow(self, mock_services):
        """Test batch processing workflow."""
        batch_processor = mock_services["batch_processor"]

        # Add multiple tasks
        task_ids = []
        for i in range(3):
            task_id = await batch_processor.add_task(
                "transcription",
                {"url": f"https://youtube.com/watch?v=video{i}", "priority": "high"}
            )
            task_ids.append(task_id)

        # Process tasks
        await batch_processor.process_tasks(max_workers=2)

        # Check progress
        progress = await batch_processor.get_progress()
        assert progress.total_tasks == 3
        assert progress.completed_tasks > 0
        assert progress.overall_progress > 0

    @pytest.mark.asyncio
    async def test_service_protocol_compliance(self, mock_services):
        """Test that all mock services properly implement their protocols."""
        from src.services.protocols import validate_protocol_implementation

        # Test each service individually
        assert validate_protocol_implementation(
            mock_services["youtube_service"],
            YouTubeServiceProtocol
        )
        assert validate_protocol_implementation(
            mock_services["media_service"],
            MediaServiceProtocol
        )
        assert validate_protocol_implementation(
            mock_services["transcription_service"],
            TranscriptionServiceProtocol
        )

    @pytest.mark.asyncio
    async def test_media_service_operations(self, mock_services):
        """Test media service operations."""
        media_service = mock_services["media_service"]

        # Test file validation
        is_valid = await media_service.validate_file_size(Path("/tmp/test.wav"), max_size_mb=1)
        assert is_valid is True

        # Test audio quality check
        quality_ok = await media_service.check_audio_quality(Path("/tmp/test.wav"))
        assert quality_ok is True

        # Test media info extraction
        info = await media_service.get_media_info(Path("/tmp/test.wav"))
        assert info["format"] == "wav"
        assert info["sample_rate"] == 16000

    @pytest.mark.asyncio
    async def test_transcription_service_operations(self, mock_services):
        """Test transcription service operations."""
        transcription_service = mock_services["transcription_service"]

        # Test audio transcription
        audio_path = Path("/tmp/test_audio.wav")
        result = await transcription_service.transcribe_audio(audio_path)
        assert result.raw_content is not None
        assert result.segments is not None
        assert result.confidence_scores is not None

        # Test job creation and status
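        # Ad-hoc stand-in for a media file record; only the `id` attribute is provided.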
        mock_media_file = type('MockMediaFile', (), {'id': 'test-id'})()
        job = await transcription_service.create_transcription_job(mock_media_file)
        assert job is not None
        assert hasattr(job, 'id')
        status = await transcription_service.get_job_status(job.id)
        assert status == "completed"

    @pytest.mark.asyncio
    async def test_enhancement_service_operations(self, mock_services):
        """Test enhancement service operations."""
        enhancement_service = mock_services["enhancement_service"]

        # Initialize service
        await enhancement_service.initialize()

        # Test transcript enhancement
        original_text = "this is a test transcript with some issues"
        enhanced = await enhancement_service.enhance_transcript(original_text)
        assert enhanced.original_text == original_text
        assert enhanced.enhanced_text != original_text
        assert enhanced.improvements is not None
        assert enhanced.confidence_score > 0.5

    @pytest.mark.asyncio
    async def test_export_service_operations(self, mock_services):
        """Test export service operations."""
        export_service = mock_services["export_service"]

        # Test supported formats
        formats = export_service.get_supported_formats()
        assert ExportFormat.JSON in formats
        assert ExportFormat.TXT in formats

        # Test transcript export
        from src.services.protocols import TranscriptionResult
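        # Build a minimal TranscriptionResult by hand; the field values are representative placeholders.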
        mock_transcript = TranscriptionResult(
            raw_content="Test transcript content",
            segments=[],
            confidence_scores=[],
            accuracy_estimate=0.9,
            word_count=3,
            processing_time_ms=1000,
            model_used="whisper-1"
        )
        output_path = Path("/tmp/test_export.txt")
        result = await export_service.export_transcript(
            mock_transcript, output_path, ExportFormat.TXT
        )
        assert result.success is True
        assert result.file_path == output_path

    @pytest.mark.asyncio
    async def test_error_handling_in_workflows(self, mock_services):
        """Test error handling in service workflows."""
        youtube_service = mock_services["youtube_service"]

        # Test batch extraction with mixed success/failure
        urls = [
            "https://youtube.com/watch?v=valid1",
            "https://youtube.com/watch?v=invalid",
            "https://youtube.com/watch?v=valid2"
        ]
        results = await youtube_service.batch_extract(urls)
        assert len(results) == 3

        # Check that some succeeded and some failed
        success_count = sum(1 for r in results if r["success"])
        assert success_count > 0
        assert success_count < len(results)

    @pytest.mark.asyncio
    async def test_progress_callback_functionality(self, mock_services):
        """Test progress callback functionality in services."""
        media_service = mock_services["media_service"]
        progress_updates = []

        def progress_callback(progress: float, message: str):
            progress_updates.append((progress, message))

        # Test download with progress callback
        await media_service.download_media(
            "https://example.com/test.mp3",
            Path("/tmp"),
            progress_callback
        )

        # Verify the progress callback was invoked
        assert len(progress_updates) > 0
        assert progress_updates[0][0] == 0.25
        assert progress_updates[-1][0] == 1.0
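

# Allow running this module directly; pytest.main() runs the tests in this file.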
if __name__ == "__main__":
    pytest.main([__file__])