270 lines
10 KiB
Python
270 lines
10 KiB
Python
"""Unit tests for the DeepSeek Enhancement Service (v2).

Tests the AI-powered transcript enhancement service, which is intended to
improve transcription accuracy from 95% to 99% through intelligent corrections.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from uuid import uuid4
|
|
|
|
from src.services.enhancement import (
|
|
DeepSeekEnhancementService,
|
|
EnhancementConfig,
|
|
EnhancementResult,
|
|
EnhancementError,
|
|
create_enhancement_service,
|
|
)
|
|
|
|
|
|
class TestEnhancementConfig:
    """Tests for ``EnhancementConfig`` defaults, overrides and validation."""

    @staticmethod
    def _check_fields(cfg, expected):
        """Assert every ``expected`` attribute on ``cfg``.

        Boolean expectations are compared by identity (``is``), everything
        else by equality, so ``enable_caching`` must be a real bool.
        """
        for name, wanted in expected.items():
            actual = getattr(cfg, name)
            if isinstance(wanted, bool):
                assert actual is wanted, name
            else:
                assert actual == wanted, name

    def test_default_config(self):
        """A config built without arguments exposes the documented defaults."""
        defaults = {
            "model": "deepseek-chat",
            "temperature": 0.0,
            "max_tokens": 4096,
            "quality_threshold": 0.7,
            "enable_caching": True,
            "cache_ttl": 86400,  # 24 hours, in seconds
        }
        self._check_fields(EnhancementConfig(), defaults)

    def test_custom_config(self):
        """Every constructor override is reflected on the resulting config."""
        overrides = {
            "model": "deepseek-coder",
            "temperature": 0.1,
            "max_tokens": 8192,
            "quality_threshold": 0.8,
            "enable_caching": False,
            "cache_ttl": 3600,
        }
        self._check_fields(EnhancementConfig(**overrides), overrides)

    def test_config_validation(self):
        """validate() accepts the defaults and rejects out-of-range values."""
        # The default configuration must validate cleanly.
        EnhancementConfig().validate()

        # Each entry: constructor kwargs that should fail, and the expected message regex.
        rejected = [
            ({"temperature": 1.5}, "Temperature must be between 0 and 1"),
            ({"quality_threshold": 2.0}, "Quality threshold must be between 0 and 1"),
        ]
        for bad_kwargs, message in rejected:
            with pytest.raises(ValueError, match=message):
                EnhancementConfig(**bad_kwargs).validate()
|
|
|
|
|
|
class TestEnhancementResult:
    """Tests for the ``EnhancementResult`` value object."""

    # Fields whose stored / serialized value must equal the constructor input.
    _COMPARED_FIELDS = (
        "original_text",
        "enhanced_text",
        "confidence_score",
        "improvements",
        "processing_time",
        "model_used",
    )

    @staticmethod
    def _creation_kwargs():
        """Return a fresh set of constructor arguments for the attribute test."""
        return {
            "original_text": "hello world",
            "enhanced_text": "Hello, world!",
            "confidence_score": 0.95,
            "improvements": ["punctuation", "capitalization"],
            "processing_time": 2.5,
            "model_used": "deepseek-chat",
            "metadata": {"tokens_used": 150},
        }

    @staticmethod
    def _to_dict_kwargs():
        """Return a fresh set of constructor arguments for the serialization test."""
        return {
            "original_text": "test",
            "enhanced_text": "Test!",
            "confidence_score": 0.9,
            "improvements": ["capitalization"],
            "processing_time": 1.0,
            "model_used": "deepseek-chat",
            "metadata": {"test": "value"},
        }

    def test_enhancement_result_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        result = EnhancementResult(**self._creation_kwargs())
        # A separate copy, so the expectations are independent of the objects
        # handed to the constructor.
        expected = self._creation_kwargs()

        for name in self._COMPARED_FIELDS:
            assert getattr(result, name) == expected[name], name
        assert result.metadata["tokens_used"] == 150

    def test_enhancement_result_to_dict(self):
        """to_dict() mirrors the fields and adds a ``created_at`` entry."""
        serialized = EnhancementResult(**self._to_dict_kwargs()).to_dict()
        expected = self._to_dict_kwargs()

        for name in self._COMPARED_FIELDS:
            assert serialized[name] == expected[name], name
        assert serialized["metadata"]["test"] == "value"
        assert "created_at" in serialized
|
|
|
|
|
|
class TestDeepSeekEnhancementService:
    """Behavioural tests for ``DeepSeekEnhancementService`` with the DeepSeek API patched out."""

    # Patch targets shared by every test that initializes the service.
    _API_CLASS = "src.services.enhancement.api.deepseek.DeepSeekAPI"
    _API_KEY = "src.config.config.DEEPSEEK_API_KEY"

    @staticmethod
    def _completion(content):
        """Build a fake chat-completion response whose first choice carries ``content``."""
        message = MagicMock(content=content)
        return MagicMock(choices=[MagicMock(message=message)])

    @staticmethod
    def _wire_client(api_cls, **create_kwargs):
        """Make the patched ``api_cls`` return an AsyncMock client and stub its completions call.

        ``create_kwargs`` (e.g. ``return_value=`` or ``side_effect=``) configure
        the ``AsyncMock`` installed as ``client.chat.completions.create``.
        Returns the client so tests can inspect call counts.
        """
        client = AsyncMock()
        api_cls.return_value = client
        client.chat.completions.create = AsyncMock(**create_kwargs)
        return client

    @pytest.fixture
    def enhancement_service(self):
        """Service under test, built with a lowered quality threshold.

        The DeepSeek API itself is patched inside each test, not here.
        """
        cfg = EnhancementConfig(
            model="deepseek-chat",
            temperature=0.0,
            max_tokens=4096,
            quality_threshold=0.1,  # lowered from the 0.7 default for testing
        )
        return DeepSeekEnhancementService(cfg)

    @pytest.fixture
    def sample_transcript(self):
        """Lower-case, unpunctuated transcript containing tech terms and numbers."""
        return (
            "hello world this is a test transcript it needs punctuation and capitalization\n"
            "there are some technical terms like python javascript and react that should be properly formatted\n"
            "also there are some numbers like 42 and 3.14 that should be preserved"
        )

    @pytest.mark.asyncio
    async def test_service_initialization(self, enhancement_service):
        """initialize() marks the service ready and creates an API client."""
        with patch(self._API_CLASS), patch(self._API_KEY, "test-key"):
            await enhancement_service.initialize()

            assert enhancement_service.is_initialized is True
            assert enhancement_service.api_client is not None

    @pytest.mark.asyncio
    async def test_enhance_transcript_success(self, enhancement_service, sample_transcript):
        """A successful API call produces an enhanced result with metadata."""
        with patch(self._API_CLASS) as api_cls, patch(self._API_KEY, "test-key"):
            self._wire_client(
                api_cls,
                return_value=self._completion("Hello, world! This is a test transcript."),
            )
            await enhancement_service.initialize()

            outcome = await enhancement_service.enhance_transcript(sample_transcript)

            assert outcome.enhanced_text != sample_transcript
            # Compared against the 0.1 threshold configured by the fixture.
            assert outcome.confidence_score > 0.1
            assert outcome.processing_time > 0
            assert outcome.model_used == "deepseek-chat"

    @pytest.mark.asyncio
    async def test_enhance_transcript_api_error(self, enhancement_service, sample_transcript):
        """An exception from the API surfaces as an EnhancementError."""
        with patch(self._API_CLASS) as api_cls, patch(self._API_KEY, "test-key"):
            self._wire_client(api_cls, side_effect=Exception("API Error"))
            await enhancement_service.initialize()

            with pytest.raises(EnhancementError, match="Failed to enhance transcript"):
                await enhancement_service.enhance_transcript(sample_transcript)

    @pytest.mark.asyncio
    async def test_enhance_transcript_caching(self, enhancement_service, sample_transcript):
        """Enhancing the same transcript twice calls the API only once."""
        with patch(self._API_CLASS) as api_cls, patch(self._API_KEY, "test-key"):
            client = self._wire_client(
                api_cls,
                return_value=self._completion("Enhanced transcript"),
            )
            await enhancement_service.initialize()

            first = await enhancement_service.enhance_transcript(sample_transcript)
            # Identical input; expected to be answered from the cache.
            second = await enhancement_service.enhance_transcript(sample_transcript)

            assert client.chat.completions.create.call_count == 1
            assert first.enhanced_text == second.enhanced_text
|
|
|
|
|
|
class TestEnhancementServiceFactory:
    """Tests for the ``create_enhancement_service`` factory."""

    def test_create_enhancement_service_default(self):
        """Without a config the factory uses the default model and temperature."""
        svc = create_enhancement_service()
        assert isinstance(svc, DeepSeekEnhancementService)

        cfg = svc.config
        assert (cfg.model, cfg.temperature) == ("deepseek-chat", 0.0)

    def test_create_enhancement_service_custom(self):
        """A supplied config is adopted by the created service."""
        requested = EnhancementConfig(
            model="deepseek-coder",
            temperature=0.1,
            quality_threshold=0.8,
        )
        svc = create_enhancement_service(requested)
        assert isinstance(svc, DeepSeekEnhancementService)

        cfg = svc.config
        assert (cfg.model, cfg.temperature, cfg.quality_threshold) == (
            "deepseek-coder",
            0.1,
            0.8,
        )
|
|
|
|
|
|
class TestEnhancementErrorHandling:
    """Tests for the context carried by ``EnhancementError`` and its serialization."""

    def test_enhancement_error_with_details(self):
        """The message becomes str(error); extra keyword context is kept as attributes."""
        err = EnhancementError(
            "API call failed",
            original_text="test",
            error_type="api_error",
            retry_count=3,
        )

        assert str(err) == "API call failed"
        details = {"original_text": "test", "error_type": "api_error", "retry_count": 3}
        for attr, value in details.items():
            assert getattr(err, attr) == value, attr

    def test_enhancement_error_serialization(self):
        """to_dict() includes the message, the context fields and a timestamp."""
        payload = EnhancementError(
            "Test error",
            original_text="test",
            error_type="test_error",
            retry_count=1,
        ).to_dict()

        expected = {
            "message": "Test error",
            "original_text": "test",
            "error_type": "test_error",
            "retry_count": 1,
        }
        assert {key: payload[key] for key in expected} == expected
        assert "timestamp" in payload
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns an exit code rather than exiting; propagate it so
    # running this file directly reports test failures to the shell / CI.
    raise SystemExit(pytest.main([__file__]))
|