# Source: trax/tests/test_enhancement_service.py
# Listing metadata: 270 lines, 10 KiB, Python

"""Unit tests for DeepSeek Enhancement Service (v2).
Tests the AI-powered transcript enhancement service that improves
transcription accuracy from 95% to 99% through intelligent corrections.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
from src.services.enhancement import (
DeepSeekEnhancementService,
EnhancementConfig,
EnhancementResult,
EnhancementError,
create_enhancement_service,
)
class TestEnhancementConfig:
    """Checks for EnhancementConfig defaults, overrides and validation."""

    def test_default_config(self):
        """A config built without arguments exposes the expected defaults."""
        defaults = EnhancementConfig()

        assert defaults.model == "deepseek-chat"
        assert defaults.temperature == 0.0
        assert defaults.max_tokens == 4096
        assert defaults.quality_threshold == 0.7
        assert defaults.enable_caching is True
        assert defaults.cache_ttl == 86400  # 24 hours

    def test_custom_config(self):
        """Every keyword passed to the constructor is stored unchanged."""
        overrides = {
            "model": "deepseek-coder",
            "temperature": 0.1,
            "max_tokens": 8192,
            "quality_threshold": 0.8,
            "enable_caching": False,
            "cache_ttl": 3600,
        }
        custom = EnhancementConfig(**overrides)

        assert custom.model == "deepseek-coder"
        assert custom.temperature == 0.1
        assert custom.max_tokens == 8192
        assert custom.quality_threshold == 0.8
        assert custom.enable_caching is False
        assert custom.cache_ttl == 3600

    def test_config_validation(self):
        """validate() accepts the defaults and rejects out-of-range values."""
        # A default config must pass validation without raising.
        EnhancementConfig().validate()

        # Temperature of 1.5 is rejected with the range message.
        temperature_message = "Temperature must be between 0 and 1"
        with pytest.raises(ValueError, match=temperature_message):
            too_hot = EnhancementConfig(temperature=1.5)
            too_hot.validate()

        # Quality threshold of 2.0 is rejected with the range message.
        threshold_message = "Quality threshold must be between 0 and 1"
        with pytest.raises(ValueError, match=threshold_message):
            too_strict = EnhancementConfig(quality_threshold=2.0)
            too_strict.validate()
class TestEnhancementResult:
    """Checks for the EnhancementResult data structure."""

    def test_enhancement_result_creation(self):
        """Constructor arguments are readable back as attributes."""
        fields = {
            "original_text": "hello world",
            "enhanced_text": "Hello, world!",
            "confidence_score": 0.95,
            "improvements": ["punctuation", "capitalization"],
            "processing_time": 2.5,
            "model_used": "deepseek-chat",
            "metadata": {"tokens_used": 150},
        }
        outcome = EnhancementResult(**fields)

        assert outcome.original_text == "hello world"
        assert outcome.enhanced_text == "Hello, world!"
        assert outcome.confidence_score == 0.95
        assert outcome.improvements == ["punctuation", "capitalization"]
        assert outcome.processing_time == 2.5
        assert outcome.model_used == "deepseek-chat"
        assert outcome.metadata["tokens_used"] == 150

    def test_enhancement_result_to_dict(self):
        """to_dict() echoes every field and adds a "created_at" entry."""
        outcome = EnhancementResult(
            original_text="test",
            enhanced_text="Test!",
            confidence_score=0.9,
            improvements=["capitalization"],
            processing_time=1.0,
            model_used="deepseek-chat",
            metadata={"test": "value"},
        )

        serialized = outcome.to_dict()

        assert serialized["original_text"] == "test"
        assert serialized["enhanced_text"] == "Test!"
        assert serialized["confidence_score"] == 0.9
        assert serialized["improvements"] == ["capitalization"]
        assert serialized["processing_time"] == 1.0
        assert serialized["model_used"] == "deepseek-chat"
        assert serialized["metadata"]["test"] == "value"
        # Only the key's presence is checked; its value is not inspected.
        assert "created_at" in serialized
class TestDeepSeekEnhancementService:
    """Exercise DeepSeekEnhancementService against a mocked DeepSeek client."""

    @pytest.fixture
    def enhancement_service(self):
        """Return a service built from a config with a 0.1 quality threshold."""
        # Lower threshold for testing
        test_config = EnhancementConfig(
            model="deepseek-chat",
            temperature=0.0,
            max_tokens=4096,
            quality_threshold=0.1,
        )
        return DeepSeekEnhancementService(test_config)

    @pytest.fixture
    def sample_transcript(self):
        """Return a three-line, unpunctuated, lower-case transcript."""
        # Continuation lines stay flush-left so no indentation leaks into the text.
        return """hello world this is a test transcript it needs punctuation and capitalization
there are some technical terms like python javascript and react that should be properly formatted
also there are some numbers like 42 and 3.14 that should be preserved"""

    @pytest.mark.asyncio
    async def test_service_initialization(self, enhancement_service):
        """initialize() marks the service ready and sets an API client."""
        with patch("src.services.enhancement.api.deepseek.DeepSeekAPI"), \
                patch("src.config.config.DEEPSEEK_API_KEY", "test-key"):
            await enhancement_service.initialize()

            assert enhancement_service.is_initialized is True
            assert enhancement_service.api_client is not None

    @pytest.mark.asyncio
    async def test_enhance_transcript_success(self, enhancement_service, sample_transcript):
        """A successful completion yields a populated enhancement result."""
        with patch("src.services.enhancement.api.deepseek.DeepSeekAPI") as mock_deepseek, \
                patch("src.config.config.DEEPSEEK_API_KEY", "test-key"):
            # Build the fake chat-completion response from the inside out.
            message = MagicMock(content="Hello, world! This is a test transcript.")
            choice = MagicMock(message=message)
            completion = MagicMock(choices=[choice])

            # The patched DeepSeekAPI class hands back this client when called.
            mock_client = AsyncMock()
            mock_deepseek.return_value = mock_client
            mock_client.chat.completions.create = AsyncMock(return_value=completion)

            await enhancement_service.initialize()
            result = await enhancement_service.enhance_transcript(sample_transcript)

            assert result.enhanced_text != sample_transcript
            assert result.confidence_score > 0.1  # Adjusted for test threshold
            assert result.processing_time > 0
            assert result.model_used == "deepseek-chat"

    @pytest.mark.asyncio
    async def test_enhance_transcript_api_error(self, enhancement_service, sample_transcript):
        """An exception from the API client surfaces as EnhancementError."""
        with patch("src.services.enhancement.api.deepseek.DeepSeekAPI") as mock_deepseek, \
                patch("src.config.config.DEEPSEEK_API_KEY", "test-key"):
            # The completion call raises instead of returning a response.
            failing_create = AsyncMock(side_effect=Exception("API Error"))
            mock_client = AsyncMock()
            mock_deepseek.return_value = mock_client
            mock_client.chat.completions.create = failing_create

            await enhancement_service.initialize()

            with pytest.raises(EnhancementError, match="Failed to enhance transcript"):
                await enhancement_service.enhance_transcript(sample_transcript)

    @pytest.mark.asyncio
    async def test_enhance_transcript_caching(self, enhancement_service, sample_transcript):
        """Enhancing the same transcript twice calls the API only once."""
        with patch("src.services.enhancement.api.deepseek.DeepSeekAPI") as mock_deepseek, \
                patch("src.config.config.DEEPSEEK_API_KEY", "test-key"):
            # Build the fake chat-completion response from the inside out.
            message = MagicMock(content="Enhanced transcript")
            choice = MagicMock(message=message)
            completion = MagicMock(choices=[choice])

            mock_client = AsyncMock()
            mock_deepseek.return_value = mock_client
            mock_client.chat.completions.create = AsyncMock(return_value=completion)

            await enhancement_service.initialize()

            # First request goes to the API; the repeat is expected to be cached.
            first = await enhancement_service.enhance_transcript(sample_transcript)
            second = await enhancement_service.enhance_transcript(sample_transcript)

            # A single recorded call shows the second request skipped the API.
            assert mock_client.chat.completions.create.call_count == 1
            assert first.enhanced_text == second.enhanced_text
class TestEnhancementServiceFactory:
    """Checks for the create_enhancement_service factory function."""

    def test_create_enhancement_service_default(self):
        """Without arguments the factory returns a service with default config."""
        built = create_enhancement_service()

        assert isinstance(built, DeepSeekEnhancementService)
        assert built.config.model == "deepseek-chat"
        assert built.config.temperature == 0.0

    def test_create_enhancement_service_custom(self):
        """A config passed to the factory is the one the service exposes."""
        overrides = {
            "model": "deepseek-coder",
            "temperature": 0.1,
            "quality_threshold": 0.8,
        }
        custom_config = EnhancementConfig(**overrides)

        built = create_enhancement_service(custom_config)

        assert isinstance(built, DeepSeekEnhancementService)
        assert built.config.model == "deepseek-coder"
        assert built.config.temperature == 0.1
        assert built.config.quality_threshold == 0.8
class TestEnhancementErrorHandling:
    """Checks for the EnhancementError exception type."""

    def test_enhancement_error_with_details(self):
        """The message and extra keyword details are kept on the exception."""
        details = {
            "original_text": "test",
            "error_type": "api_error",
            "retry_count": 3,
        }
        failure = EnhancementError("API call failed", **details)

        assert str(failure) == "API call failed"
        assert failure.original_text == "test"
        assert failure.error_type == "api_error"
        assert failure.retry_count == 3

    def test_enhancement_error_serialization(self):
        """to_dict() carries the message, the details and a "timestamp" key."""
        failure = EnhancementError(
            "Test error",
            original_text="test",
            error_type="test_error",
            retry_count=1,
        )

        serialized = failure.to_dict()

        assert serialized["message"] == "Test error"
        assert serialized["original_text"] == "test"
        assert serialized["error_type"] == "test_error"
        assert serialized["retry_count"] == 1
        # Only the key's presence is checked; its value is not inspected.
        assert "timestamp" in serialized
if __name__ == "__main__":
    # Allow running this test module directly as a script. pytest.main()
    # returns the run's exit code; re-raise it through SystemExit so the
    # process status reflects test failures instead of always being 0.
    raise SystemExit(pytest.main([__file__]))