# trax/tests/test_multi_pass_enhancement.py  (file-viewer header: 126 lines, 4.6 KiB, Python)

"""Tests for enhancement pass (subtask 7.4)."""
from __future__ import annotations
import pytest
from typing import List, Dict, Any
from unittest.mock import patch, MagicMock
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
def _segments() -> List[Dict[str, Any]]:
return [
{"start": 0.0, "end": 1.0, "text": "BP normal, check ECG"},
{"start": 1.2, "end": 2.4, "text": "tachycardia suspected"},
]
def _technical_segments() -> List[Dict[str, Any]]:
return [
{"start": 0.0, "end": 1.0, "text": "The algorithm implements a singleton pattern"},
{"start": 1.2, "end": 2.4, "text": "for thread safety in the software system"},
]
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_tags_segments_with_domain(mock_mgr_cls: MagicMock) -> None:
    """Check that an explicit ``"medical"`` domain tags every segment.

    The enhancement pass is run with ``domain="medical"`` against a mocked
    manager that lists a ``"medical"`` adapter. Every returned segment must
    start with ``[MEDICAL]`` and still contain the input text.

    Args:
        mock_mgr_cls: Mock substituted by ``@patch`` for
            ``DomainAdaptationManager`` in the pipeline module.
    """
    # Create a proper mock domain adaptation manager
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {"medical": "mock_adapter"}
    mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model"
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "medical"
    # Any construction of the patched class yields the same mock manager.
    mock_mgr_cls.return_value = mock_mgr

    # Create pipeline with domain adaptation enabled
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )

    segments = _segments()
    enhanced = await pipeline._perform_enhancement_pass(segments, domain="medical")

    # Check the length first so the zip() below cannot silently drop segments.
    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        assert e["text"].startswith("[MEDICAL]")
        assert s["text"] in e["text"]
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_tags_technical_segments(mock_mgr_cls: MagicMock) -> None:
    """Check that an explicit ``"technical"`` domain tags every segment.

    The enhancement pass is run with ``domain="technical"`` against a mocked
    manager that lists a ``"technical"`` adapter. Every returned segment must
    start with ``[TECHNICAL]`` and still contain the input text.

    Args:
        mock_mgr_cls: Mock substituted by ``@patch`` for
            ``DomainAdaptationManager`` in the pipeline module.
    """
    # Create a proper mock domain adaptation manager
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {"technical": "mock_adapter"}
    mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model"
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "technical"
    # Any construction of the patched class yields the same mock manager.
    mock_mgr_cls.return_value = mock_mgr

    # Create pipeline with domain adaptation enabled
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )

    segments = _technical_segments()
    enhanced = await pipeline._perform_enhancement_pass(segments, domain="technical")

    # Check the length first so the zip() below cannot silently drop segments.
    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        assert e["text"].startswith("[TECHNICAL]")
        assert s["text"] in e["text"]
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_auto_detect_domain(mock_mgr_cls: MagicMock) -> None:
    """Check that ``domain=None`` still yields ``[MEDICAL]``-tagged segments.

    The mocked detector is configured to return ``"medical"`` and the pipeline
    is built with ``auto_detect_domain=True``. The enhancement pass is then
    called without a domain, and every returned segment must start with
    ``[MEDICAL]`` and still contain the input text.

    Args:
        mock_mgr_cls: Mock substituted by ``@patch`` for
            ``DomainAdaptationManager`` in the pipeline module.
    """
    # Create a proper mock domain adaptation manager
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {"medical": "mock_adapter"}
    mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model"
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "medical"
    # Any construction of the patched class yields the same mock manager.
    mock_mgr_cls.return_value = mock_mgr

    # Create pipeline with domain adaptation enabled
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )

    segments = _segments()
    # Test without specifying domain - should auto-detect
    enhanced = await pipeline._perform_enhancement_pass(segments, domain=None)

    # Check the length first so the zip() below cannot silently drop segments.
    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        assert e["text"].startswith("[MEDICAL]")
        assert s["text"] in e["text"]
@pytest.mark.asyncio
@patch("src.services.multi_pass_transcription.DomainAdaptationManager")
async def test_enhancement_fallback_to_general(mock_mgr_cls: MagicMock) -> None:
    """Check that a requested domain with no adapter is tagged ``[GENERAL]``.

    The mocked manager exposes an empty ``domain_adapters`` mapping. The
    enhancement pass is called with ``domain="medical"``, and every returned
    segment must start with ``[GENERAL]`` and still contain the input text.

    Args:
        mock_mgr_cls: Mock substituted by ``@patch`` for
            ``DomainAdaptationManager`` in the pipeline module.
    """
    # Create a mock domain adaptation manager with no adapters
    mock_mgr = MagicMock()
    mock_mgr.domain_adapter.domain_adapters = {}
    mock_mgr.domain_detector.detect_domain_from_text.return_value = "general"
    # Any construction of the patched class yields the same mock manager.
    mock_mgr_cls.return_value = mock_mgr

    # Create pipeline with domain adaptation enabled
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=mock_mgr,
        auto_detect_domain=True,
    )

    segments = _segments()
    # Test with medical domain but no adapter available
    enhanced = await pipeline._perform_enhancement_pass(segments, domain="medical")

    # Check the length first so the zip() below cannot silently drop segments.
    assert len(enhanced) == len(segments)
    for e, s in zip(enhanced, segments):
        # Should fall back to general domain prefix
        assert e["text"].startswith("[GENERAL]")
        assert s["text"] in e["text"]