408 lines
14 KiB
Python
408 lines
14 KiB
Python
"""Unit tests for the template-based analysis system."""
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, AsyncMock
|
|
from datetime import datetime
|
|
|
|
from backend.models.analysis_templates import (
|
|
AnalysisTemplate,
|
|
TemplateSet,
|
|
TemplateRegistry,
|
|
TemplateType,
|
|
ComplexityLevel
|
|
)
|
|
from backend.services.template_driven_agent import (
|
|
TemplateDrivenAgent,
|
|
TemplateAnalysisRequest,
|
|
TemplateAnalysisResult
|
|
)
|
|
from backend.services.template_defaults import (
|
|
create_educational_templates,
|
|
create_domain_templates,
|
|
create_default_registry
|
|
)
|
|
|
|
|
|
class TestAnalysisTemplate:
    """Unit tests for constructing and rendering ``AnalysisTemplate`` objects."""

    def test_template_creation(self):
        """A fully specified template exposes the values it was built with."""
        tmpl = AnalysisTemplate(
            id="test_template",
            name="Test Template",
            description="A test template for unit testing",
            template_type=TemplateType.EDUCATIONAL,
            complexity_level=ComplexityLevel.BEGINNER,
            system_prompt="You are a test assistant analyzing {content}",
            analysis_focus=["testing", "validation", "quality"],
            output_format="## Test Results\n{results}",
            variables={"example_var": "test_value"},
        )

        expected_fields = {
            "id": "test_template",
            "name": "Test Template",
            "template_type": TemplateType.EDUCATIONAL,
            "complexity_level": ComplexityLevel.BEGINNER,
        }
        for field_name, expected in expected_fields.items():
            assert getattr(tmpl, field_name) == expected, field_name

        assert len(tmpl.analysis_focus) == 3
        assert tmpl.variables["example_var"] == "test_value"

    def test_template_prompt_rendering(self):
        """Rendering fills placeholders from template variables and call context."""
        tmpl = AnalysisTemplate(
            id="render_test",
            name="Render Test",
            description="Test template rendering",
            template_type=TemplateType.CUSTOM,
            system_prompt="Analyze this {content_type} about {topic} for {audience}",
            analysis_focus=["rendering"],
            output_format="Results: {output}",
            variables={"content_type": "article", "topic": "testing"},
        )

        # "audience" is supplied only by the per-call context, the rest by the template.
        rendered = tmpl.render_prompt({"audience": "developers"})

        for fragment in ("article", "testing", "developers"):
            assert fragment in rendered, fragment

    def test_template_validation(self):
        """A template whose ``variables`` hold mixed value types is accepted."""
        # Building the template must succeed without raising.
        tmpl = AnalysisTemplate(
            id="valid",
            name="Valid Template",
            description="Valid template for testing",
            template_type=TemplateType.CUSTOM,
            system_prompt="Test prompt",
            analysis_focus=["test"],
            output_format="Test output",
            variables={"key": "value", "number": 42},
        )

        stored = tmpl.variables
        assert stored["key"] == "value"
        assert stored["number"] == 42
class TestTemplateSet:
    """Unit tests for grouping templates into a ``TemplateSet``."""

    @staticmethod
    def _make_template(template_id, name, description, system_prompt, focus, output_format):
        """Build an EDUCATIONAL template with a single analysis focus entry."""
        return AnalysisTemplate(
            id=template_id,
            name=name,
            description=description,
            template_type=TemplateType.EDUCATIONAL,
            system_prompt=system_prompt,
            analysis_focus=[focus],
            output_format=output_format,
        )

    def test_template_set_creation(self):
        """A set keeps its members, exposes them by key and keeps the run order."""
        members = {
            "template1": self._make_template(
                "template1", "Template 1", "First template",
                "Test prompt 1", "test1", "Output 1",
            ),
            "template2": self._make_template(
                "template2", "Template 2", "Second template",
                "Test prompt 2", "test2", "Output 2",
            ),
        }

        template_set = TemplateSet(
            id="test_set",
            name="Test Set",
            description="Test template set",
            template_type=TemplateType.EDUCATIONAL,
            templates=dict(members),
            execution_order=list(members),
        )

        assert template_set.id == "test_set"
        assert len(template_set.templates) == 2
        for key, member in members.items():
            assert template_set.get_template(key) == member, key
        assert template_set.execution_order == ["template1", "template2"]

    def test_template_set_validation(self):
        """Dictionary keys must equal the id of the template stored under them."""
        member = self._make_template(
            "valid_template", "Valid Template", "Valid template",
            "Test prompt", "test", "Test output",
        )

        # Key matches the template's own id: accepted.
        accepted = TemplateSet(
            id="valid_set",
            name="Valid Set",
            description="Valid template set",
            template_type=TemplateType.EDUCATIONAL,
            templates={"valid_template": member},
        )
        assert len(accepted.templates) == 1

        # Key differs from the template's id: rejected during validation.
        with pytest.raises(ValueError, match="Template ID mismatch"):
            TemplateSet(
                id="invalid_set",
                name="Invalid Set",
                description="Invalid template set",
                template_type=TemplateType.EDUCATIONAL,
                templates={"wrong_id": member},
            )
class TestTemplateRegistry:
    """Unit tests for registering, looking up and filtering templates."""

    def test_registry_operations(self):
        """A registered template can be fetched by id; unknown ids yield None."""
        registry = TemplateRegistry()
        registry.register_template(
            AnalysisTemplate(
                id="registry_test",
                name="Registry Test",
                description="Test template for registry",
                template_type=TemplateType.CUSTOM,
                system_prompt="Test prompt",
                analysis_focus=["registry"],
                output_format="Test output",
            )
        )

        found = registry.get_template("registry_test")
        assert found is not None
        assert (found.id, found.name) == ("registry_test", "Registry Test")

        # Ids that were never registered are reported as missing.
        assert registry.get_template("non_existent") is None

    def test_registry_filtering(self):
        """``list_templates`` filters by type and returns everything without one."""
        registry = TemplateRegistry()

        # id -> (name, description, type, single analysis focus entry)
        specs = {
            "educational": ("Educational", "Educational template", TemplateType.EDUCATIONAL, "education"),
            "domain": ("Domain", "Domain template", TemplateType.DOMAIN, "domain"),
        }
        for template_id, (label, description, kind, focus) in specs.items():
            registry.register_template(
                AnalysisTemplate(
                    id=template_id,
                    name=label,
                    description=description,
                    template_type=kind,
                    system_prompt="Test",
                    analysis_focus=[focus],
                    output_format="Output",
                )
            )

        for kind, expected_id in (
            (TemplateType.EDUCATIONAL, "educational"),
            (TemplateType.DOMAIN, "domain"),
        ):
            matches = registry.list_templates(kind)
            assert len(matches) == 1
            assert matches[0].id == expected_id

        assert len(registry.list_templates()) == 2
class TestTemplateDrivenAgent:
    """Test TemplateDrivenAgent functionality against a mocked AI service."""

    @pytest.fixture
    def mock_ai_service(self):
        """Mock AI service whose ``generate_summary`` is an awaitable ``AsyncMock``."""
        service = Mock()
        service.generate_summary = AsyncMock()
        return service

    @pytest.fixture
    def test_registry(self):
        """Registry containing the single CUSTOM template ``test_agent_template``.

        The template requests between 2 and 4 insights
        (``min_insights`` / ``max_insights``).
        """
        registry = TemplateRegistry()

        template = AnalysisTemplate(
            id="test_agent_template",
            name="Test Agent Template",
            description="Template for agent testing",
            template_type=TemplateType.CUSTOM,
            system_prompt="Analyze this content: {content}",
            analysis_focus=["testing", "analysis"],
            output_format="## Analysis\n{analysis}\n\n## Key Points\n- {points}",
            min_insights=2,
            max_insights=4,
        )

        registry.register_template(template)
        return registry

    @pytest.mark.asyncio
    async def test_single_template_analysis(self, mock_ai_service, test_registry):
        """Test analyzing content with a single template."""
        # Canned AI response: an analysis paragraph followed by three bullet points.
        # Built line by line so the text carries no source-indentation whitespace.
        mock_ai_service.generate_summary.return_value = "\n".join([
            "",
            "## Analysis",
            "This is a test analysis of the provided content.",
            "",
            "## Key Points",
            "- First important insight about the content",
            "- Second valuable observation",
            "- Third key finding",
            "",
        ])

        agent = TemplateDrivenAgent(
            ai_service=mock_ai_service,
            template_registry=test_registry,
        )

        request = TemplateAnalysisRequest(
            content="Test content for analysis",
            template_id="test_agent_template",
            context={"additional": "context"},
        )

        result = await agent.analyze_with_template(request)

        # Verify result
        assert result.template_id == "test_agent_template"
        assert result.template_name == "Test Agent Template"
        assert "test analysis" in result.analysis.lower()
        assert len(result.key_insights) >= 2  # the template's min_insights
        assert result.confidence_score > 0
        assert result.processing_time_seconds > 0

        # generate_summary is async: assert_awaited_once also fails when the
        # coroutine was created but never awaited, which assert_called_once misses.
        mock_ai_service.generate_summary.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_nonexistent_template(self, mock_ai_service, test_registry):
        """An unknown template id raises "Template not found" without an AI call."""
        agent = TemplateDrivenAgent(
            ai_service=mock_ai_service,
            template_registry=test_registry,
        )

        request = TemplateAnalysisRequest(
            content="Test content",
            template_id="nonexistent_template",
        )

        with pytest.raises(Exception, match="Template not found"):
            await agent.analyze_with_template(request)

        # A failed template lookup should not spend a request on the AI service.
        mock_ai_service.generate_summary.assert_not_called()

    def test_usage_statistics(self, mock_ai_service, test_registry):
        """Test usage statistics tracking."""
        agent = TemplateDrivenAgent(
            ai_service=mock_ai_service,
            template_registry=test_registry,
        )

        # Initially no usage
        stats = agent.get_usage_stats()
        assert len(stats) == 0

        # Update stats manually (simulating usage)
        agent._update_usage_stats("test_template")
        agent._update_usage_stats("test_template")
        agent._update_usage_stats("other_template")

        stats = agent.get_usage_stats()
        assert stats["test_template"] == 2
        assert stats["other_template"] == 1
class TestDefaultTemplates:
    """Test default template creation."""

    def test_educational_templates_creation(self):
        """The educational set has three leveled templates plus a synthesis step."""
        educational_set = create_educational_templates()

        assert educational_set.id == "educational_perspectives"
        assert educational_set.template_type == TemplateType.EDUCATIONAL
        assert len(educational_set.templates) == 3

        # Check individual templates
        beginner = educational_set.get_template("educational_beginner")
        expert = educational_set.get_template("educational_expert")
        scholarly = educational_set.get_template("educational_scholarly")

        assert beginner is not None
        assert expert is not None
        assert scholarly is not None

        assert beginner.complexity_level == ComplexityLevel.BEGINNER
        assert expert.complexity_level == ComplexityLevel.EXPERT
        assert scholarly.complexity_level == ComplexityLevel.SCHOLARLY

        # Check synthesis template
        assert educational_set.synthesis_template is not None
        assert educational_set.synthesis_template.id == "educational_synthesis"

    def test_domain_templates_creation(self):
        """The domain set has technical, business and UX perspective templates."""
        domain_set = create_domain_templates()

        assert domain_set.id == "domain_perspectives"
        assert domain_set.template_type == TemplateType.DOMAIN
        assert len(domain_set.templates) == 3

        # Check individual templates
        technical = domain_set.get_template("domain_technical")
        business = domain_set.get_template("domain_business")
        ux = domain_set.get_template("domain_ux")

        assert technical is not None
        assert business is not None
        assert ux is not None

        # Each template's first focus entry names its perspective.
        assert "technical" in technical.analysis_focus[0].lower()
        assert "business" in business.analysis_focus[0].lower()
        assert "user" in ux.analysis_focus[0].lower()

    def test_default_registry_creation(self):
        """The default registry exposes every bundled perspective template and set."""
        registry = create_default_registry()

        # At least 3 educational + 3 domain perspective templates. Synthesis
        # templates may or may not be registered individually, hence ``>=``.
        all_templates = registry.list_templates()
        assert len(all_templates) >= 6

        # At least the educational and domain template sets.
        all_sets = registry.list_template_sets()
        assert len(all_sets) >= 2

        # Every perspective template from both default sets is retrievable by id.
        for template_id in (
            "educational_beginner",
            "educational_expert",
            "educational_scholarly",
            "domain_technical",
            "domain_business",
            "domain_ux",
        ):
            assert registry.get_template(template_id) is not None, template_id