"""Tests for enhancement pass (subtask 7.4).""" from __future__ import annotations import pytest from typing import List, Dict, Any from unittest.mock import patch, MagicMock from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline def _segments() -> List[Dict[str, Any]]: return [ {"start": 0.0, "end": 1.0, "text": "BP normal, check ECG"}, {"start": 1.2, "end": 2.4, "text": "tachycardia suspected"}, ] def _technical_segments() -> List[Dict[str, Any]]: return [ {"start": 0.0, "end": 1.0, "text": "The algorithm implements a singleton pattern"}, {"start": 1.2, "end": 2.4, "text": "for thread safety in the software system"}, ] @pytest.mark.asyncio @patch("src.services.multi_pass_transcription.DomainAdaptationManager") async def test_enhancement_tags_segments_with_domain(mock_mgr_cls: MagicMock): # Create a proper mock domain adaptation manager mock_mgr = MagicMock() mock_mgr.domain_adapter.domain_adapters = {"medical": "mock_adapter"} mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model" mock_mgr.domain_detector.detect_domain_from_text.return_value = "medical" mock_mgr_cls.return_value = mock_mgr # Create pipeline with domain adaptation enabled pipeline = MultiPassTranscriptionPipeline( domain_adapter=mock_mgr, auto_detect_domain=True ) segments = _segments() enhanced = await pipeline._perform_enhancement_pass(segments, domain="medical") assert len(enhanced) == len(segments) for e, s in zip(enhanced, segments): assert e["text"].startswith("[MEDICAL]") assert s["text"] in e["text"] @pytest.mark.asyncio @patch("src.services.multi_pass_transcription.DomainAdaptationManager") async def test_enhancement_tags_technical_segments(mock_mgr_cls: MagicMock): # Create a proper mock domain adaptation manager mock_mgr = MagicMock() mock_mgr.domain_adapter.domain_adapters = {"technical": "mock_adapter"} mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model" mock_mgr.domain_detector.detect_domain_from_text.return_value = "technical" mock_mgr_cls.return_value = mock_mgr # Create pipeline with domain adaptation enabled pipeline = MultiPassTranscriptionPipeline( domain_adapter=mock_mgr, auto_detect_domain=True ) segments = _technical_segments() enhanced = await pipeline._perform_enhancement_pass(segments, domain="technical") assert len(enhanced) == len(segments) for e, s in zip(enhanced, segments): assert e["text"].startswith("[TECHNICAL]") assert s["text"] in e["text"] @pytest.mark.asyncio @patch("src.services.multi_pass_transcription.DomainAdaptationManager") async def test_enhancement_auto_detect_domain(mock_mgr_cls: MagicMock): # Create a proper mock domain adaptation manager mock_mgr = MagicMock() mock_mgr.domain_adapter.domain_adapters = {"medical": "mock_adapter"} mock_mgr.domain_adapter.switch_adapter.return_value = "mock_adapted_model" mock_mgr.domain_detector.detect_domain_from_text.return_value = "medical" mock_mgr_cls.return_value = mock_mgr # Create pipeline with domain adaptation enabled pipeline = MultiPassTranscriptionPipeline( domain_adapter=mock_mgr, auto_detect_domain=True ) segments = _segments() # Test without specifying domain - should auto-detect enhanced = await pipeline._perform_enhancement_pass(segments, domain=None) assert len(enhanced) == len(segments) for e, s in zip(enhanced, segments): assert e["text"].startswith("[MEDICAL]") assert s["text"] in e["text"] @pytest.mark.asyncio @patch("src.services.multi_pass_transcription.DomainAdaptationManager") async def test_enhancement_fallback_to_general(mock_mgr_cls: MagicMock): # Create a mock domain adaptation manager with no adapters mock_mgr = MagicMock() mock_mgr.domain_adapter.domain_adapters = {} mock_mgr.domain_detector.detect_domain_from_text.return_value = "general" mock_mgr_cls.return_value = mock_mgr # Create pipeline with domain adaptation enabled pipeline = MultiPassTranscriptionPipeline( domain_adapter=mock_mgr, auto_detect_domain=True ) segments = _segments() # Test with medical domain but no adapter available enhanced = await pipeline._perform_enhancement_pass(segments, domain="medical") assert len(enhanced) == len(segments) for e, s in zip(enhanced, segments): # Should fall back to general domain prefix assert e["text"].startswith("[GENERAL]") assert s["text"] in e["text"]