trax/examples/domain_detection_demo.py

158 lines
5.4 KiB
Python

#!/usr/bin/env python3
"""Domain Detection Integration Demo.

This script demonstrates how domain detection is integrated into the transcription pipeline.
It covers text-based and path-based domain detection, per-domain probability scoring,
confidence thresholds, pipeline integration, and the rule-based fallback.
"""
import logging
from pathlib import Path
from src.services.domain_adaptation import DomainDetector
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
from src.services.domain_adaptation_manager import DomainAdaptationManager
# Module-level logger for this demo script.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO so INFO-level records (e.g. from the
# imported services, if they log) are printed while the demo runs.
logging.basicConfig(level=logging.INFO)
def _status_marker(detected, expected):
    """Return "✅" when *detected* equals *expected*, otherwise "❌"."""
    return "✅" if detected == expected else "❌"


def _preview_text(text, limit=50):
    """Return *text* cut to *limit* characters, adding '...' only if it was cut."""
    return text if len(text) <= limit else f"{text[:limit]}..."


def _demo_text_detection(detector):
    """Print text-based detection results for a fixed set of labelled sentences."""
    print("📝 Text-based Domain Detection:")
    print("-" * 30)
    test_texts = [
        ("The patient shows symptoms of acute myocardial infarction", "medical"),
        ("Implement the algorithm for thread safety in the software system", "technical"),
        ("The research methodology follows a quantitative approach", "academic"),
        ("Hello world, how are you today?", "general"),
        ("The contract agreement requires legal compliance", "legal"),
    ]
    for text, expected_domain in test_texts:
        detected_domain = detector.detect_domain_from_text(text)
        marker = _status_marker(detected_domain, expected_domain)
        print(
            f"{marker} '{_preview_text(text)}' -> {detected_domain} "
            f"(expected: {expected_domain})"
        )
    print()


def _demo_path_detection(detector):
    """Print path-based detection results for a fixed set of labelled file paths."""
    print("📁 Path-based Domain Detection:")
    print("-" * 30)
    test_paths = [
        ("data/media/medical_interview_patient_123.wav", "medical"),
        ("data/media/tech_tutorial_python_programming.mp3", "technical"),
        ("data/media/research_presentation_university_lecture.wav", "academic"),
        ("data/media/legal_deposition_case_456.mp4", "legal"),
        ("data/media/recording_001.wav", None),  # No domain indicators
    ]
    for path_str, expected_domain in test_paths:
        path = Path(path_str)
        detected_domain = detector.detect_domain_from_path(path)
        # A None expectation is handled by the same equality check
        # (None == None), so no separate branch is needed.
        marker = _status_marker(detected_domain, expected_domain)
        print(f"{marker} '{path.name}' -> {detected_domain} (expected: {expected_domain})")
    print()


def _demo_probabilities(detector):
    """Print per-domain probabilities for one sample sentence, highest first."""
    print("📊 Domain Probability Scoring:")
    print("-" * 30)
    sample_text = "The patient requires immediate medical attention for diagnosis"
    probabilities = detector.get_domain_probabilities(sample_text)
    print(f"Text: '{sample_text}'")
    print("Domain probabilities:")
    for domain, prob in sorted(probabilities.items(), key=lambda item: item[1], reverse=True):
        print(f" {domain}: {prob:.3f}")
    print()


def _demo_pipeline_integration():
    """Build a pipeline with auto domain detection and report whether a detector is attached."""
    print("🔗 Pipeline Integration Demo:")
    print("-" * 30)
    domain_manager = DomainAdaptationManager()
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=domain_manager,
        auto_detect_domain=True,
    )
    # Use the same test for the printed flag and the branch so they never disagree.
    has_detector = pipeline.domain_detector is not None
    print(f"Pipeline auto-detect enabled: {pipeline.auto_detect_domain}")
    print(f"Domain detector initialized: {has_detector}")
    if has_detector:
        print(f"Available domains: {pipeline.domain_detector.domains}")
        print("✅ Domain detection is properly integrated into the pipeline")
    else:
        print("❌ Domain detection is not properly integrated")
    print()


def _demo_confidence_thresholds(detector):
    """Print the domain detected for one sentence at several threshold values."""
    print("🎯 Confidence Threshold Testing:")
    print("-" * 30)
    confidence_levels = [0.3, 0.5, 0.7, 0.9]
    test_text = "The patient shows symptoms of diabetes mellitus"
    for threshold in confidence_levels:
        detected_domain = detector.detect_domain(test_text, threshold=threshold)
        print(f"Threshold {threshold}: {detected_domain}")
    print()


def demo_domain_detection():
    """Run the domain detection demo and print every result to stdout.

    Sections, in order:

    1. Text-based detection (``detect_domain_from_text``).
    2. Path-based detection (``detect_domain_from_path``).
    3. Per-domain probability scoring (``get_domain_probabilities``).
    4. ``MultiPassTranscriptionPipeline`` built with ``auto_detect_domain=True``.
    5. ``detect_domain`` at several ``threshold`` values.

    Detection checks are marked ✅ or ❌ by comparing the detected domain with
    the expected one. Mismatches are only printed; nothing is raised.
    """
    print("🔍 Domain Detection Integration Demo")
    print("=" * 50)

    detector = DomainDetector()
    print(f"✅ Domain detector initialized with domains: {detector.domains}")
    print()

    _demo_text_detection(detector)
    _demo_path_detection(detector)
    _demo_probabilities(detector)
    _demo_pipeline_integration()
    _demo_confidence_thresholds(detector)

    print("🎉 Domain Detection Integration Demo Complete!")
def demo_rule_based_fallback():
    """Print the domain that ``DomainDetector.detect_domain`` returns for a few sentences.

    A fresh ``DomainDetector`` is created with no explicit training step here.
    The printed banner says it should fall back to rule-based detection
    (assumption taken from the banner text; confirm against ``DomainDetector``).
    Results are only printed, and the closing success line is printed
    unconditionally.
    """
    print("\n🔄 Rule-based Detection Fallback Demo")
    print("=" * 50)
    fresh_detector = DomainDetector()
    print("Testing with untrained detector (ML model not trained):")
    samples = (
        "The patient needs immediate medical attention",
        "Implement the singleton pattern for thread safety",
        "The research methodology follows quantitative analysis",
        "This is a general conversation about the weather",
    )
    for sample in samples:
        domain = fresh_detector.detect_domain(sample)
        print(f" '{sample[:40]}...' -> {domain}")
    print("\n✅ Rule-based fallback working correctly!")
if __name__ == "__main__":
    # Run both demos in order; on any failure, log the error and re-raise
    # so the traceback and non-zero exit status are preserved.
    try:
        for demo in (demo_domain_detection, demo_rule_based_fallback):
            demo()
    except Exception as exc:
        logger.error(f"Demo failed: {exc}")
        raise