126 lines
4.6 KiB
Python
126 lines
4.6 KiB
Python
"""Tests for enhancement pass (subtask 7.4)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
from typing import List, Dict, Any
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
|
|
|
|
|
|
def _segments() -> List[Dict[str, Any]]:
|
|
return [
|
|
{"start": 0.0, "end": 1.0, "text": "BP normal, check ECG"},
|
|
{"start": 1.2, "end": 2.4, "text": "tachycardia suspected"},
|
|
]
|
|
|
|
|
|
def _technical_segments() -> List[Dict[str, Any]]:
|
|
return [
|
|
{"start": 0.0, "end": 1.0, "text": "The algorithm implements a singleton pattern"},
|
|
{"start": 1.2, "end": 2.4, "text": "for thread safety in the software system"},
|
|
]
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_tags_segments_with_domain(mock_mgr_cls: MagicMock):
    """An explicit ``domain="medical"`` yields ``[MEDICAL]``-prefixed segments.

    ``mock_mgr_cls`` is the patched ``DomainAdaptationManager`` class injected
    by ``@patch``; it is configured to hand back the same mock manager that is
    passed to the pipeline directly.
    """
    # Mock manager exposing a "medical" adapter and a detector that agrees.
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {"medical": "mock_adapter"}
    mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model"
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "medical"
    mock_mgr_cls.return_value = mock_mgr

    # Pipeline with domain adaptation enabled.
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )
    segments = _segments()

    enhanced = await pipeline._perform_enhancement_pass(segments, domain="medical")

    # One output per input; each output is tagged and still contains the
    # input text.
    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        assert e["text"].startswith("[MEDICAL]")
        assert s["text"] in e["text"]
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_tags_technical_segments(mock_mgr_cls: MagicMock):
    """An explicit ``domain="technical"`` yields ``[TECHNICAL]``-prefixed segments.

    ``mock_mgr_cls`` is the patched ``DomainAdaptationManager`` class injected
    by ``@patch``; it is configured to hand back the same mock manager that is
    passed to the pipeline directly.
    """
    # Mock manager exposing a "technical" adapter and a detector that agrees.
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {"technical": "mock_adapter"}
    mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model"
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "technical"
    mock_mgr_cls.return_value = mock_mgr

    # Pipeline with domain adaptation enabled.
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )
    segments = _technical_segments()

    enhanced = await pipeline._perform_enhancement_pass(segments, domain="technical")

    # One output per input; each output is tagged and still contains the
    # input text.
    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        assert e["text"].startswith("[TECHNICAL]")
        assert s["text"] in e["text"]
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_auto_detect_domain(mock_mgr_cls: MagicMock):
    """With ``domain=None`` the segments are still tagged ``[MEDICAL]``.

    The mocked detector reports "medical", so the expected prefix comes from
    auto-detection rather than from an explicit argument. ``mock_mgr_cls`` is
    the patched ``DomainAdaptationManager`` class injected by ``@patch``.
    """
    # Mock manager exposing a "medical" adapter; detector reports "medical".
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {"medical": "mock_adapter"}
    mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model"
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "medical"
    mock_mgr_cls.return_value = mock_mgr

    # Pipeline with domain adaptation enabled.
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )
    segments = _segments()

    # No domain specified - should auto-detect.
    enhanced = await pipeline._perform_enhancement_pass(segments, domain=None)

    # One output per input; each output is tagged and still contains the
    # input text.
    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        assert e["text"].startswith("[MEDICAL]")
        assert s["text"] in e["text"]
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_fallback_to_general(mock_mgr_cls: MagicMock):
    """A requested domain with no registered adapter is tagged ``[GENERAL]``.

    The mock manager has an empty adapter registry and its detector reports
    "general"; the pass is invoked with ``domain="medical"``. ``mock_mgr_cls``
    is the patched ``DomainAdaptationManager`` class injected by ``@patch``.
    """
    # Mock manager with no adapters registered.
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {}
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "general"
    mock_mgr_cls.return_value = mock_mgr

    # Pipeline with domain adaptation enabled.
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )
    segments = _segments()

    # Medical domain requested, but no adapter is available for it.
    enhanced = await pipeline._perform_enhancement_pass(segments, domain="medical")

    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        # Should fall back to the general domain prefix.
        assert e["text"].startswith("[GENERAL]")
        assert s["text"] in e["text"]
|