158 lines
5.4 KiB
Python
158 lines
5.4 KiB
Python
#!/usr/bin/env python3
|
|
"""Domain Detection Integration Demo
|
|
|
|
This script demonstrates how domain detection is integrated into the transcription pipeline.
|
|
It shows both text-based and path-based domain detection, as well as the rule-based fallback.
|
|
"""
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from src.services.domain_adaptation import DomainDetector
|
|
from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
|
|
from src.services.domain_adaptation_manager import DomainAdaptationManager
|
|
|
|
# Give the root logger a console handler at INFO level so that records such as
# the "Demo failed" error emitted by the __main__ guard are actually shown.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def demo_domain_detection():
    """Walk through DomainDetector's detection modes and its pipeline wiring.

    Prints, in order:
      1. text-based detection (``detect_domain_from_text``) for sample sentences,
      2. path-based detection (``detect_domain_from_path``) for sample file paths,
      3. per-domain scores from ``get_domain_probabilities`` for one sentence,
      4. whether a ``MultiPassTranscriptionPipeline`` created with
         ``auto_detect_domain=True`` exposes a non-None ``domain_detector``,
      5. ``detect_domain`` results for one sentence at several thresholds.

    Lines marked ✅/❌ compare the detector's answer with a hand-picked
    expectation; a mismatch is printed, never raised.
    """
    print("🔍 Domain Detection Integration Demo")
    print("=" * 50)

    detector = DomainDetector()
    print(f"✅ Domain detector initialized with domains: {detector.domains}")
    print()

    _demo_text_detection(detector)
    _demo_path_detection(detector)
    _demo_domain_probabilities(detector)
    _demo_pipeline_integration()
    _demo_confidence_thresholds(detector)

    print("🎉 Domain Detection Integration Demo Complete!")


def _preview(text, width=50):
    """Return *text* clipped to *width* chars, appending "..." only if clipped.

    Appending the ellipsis unconditionally made short sentences look truncated.
    """
    if len(text) <= width:
        return text
    return text[:width] + "..."


def _status(detected, expected):
    """Return "✅" if *detected* matches *expected*, else "❌".

    An expected value of ``None`` is matched by identity, so only an actual
    ``None`` result counts as "no domain detected".
    """
    if expected is None:
        matched = detected is None
    else:
        matched = detected == expected
    return "✅" if matched else "❌"


def _print_section(title):
    """Print a sub-section heading followed by a 30-character rule."""
    print(title)
    print("-" * 30)


def _demo_text_detection(detector):
    """Classify sample sentences with ``detector.detect_domain_from_text``."""
    _print_section("📝 Text-based Domain Detection:")

    samples = [
        ("The patient shows symptoms of acute myocardial infarction", "medical"),
        ("Implement the algorithm for thread safety in the software system", "technical"),
        ("The research methodology follows a quantitative approach", "academic"),
        ("Hello world, how are you today?", "general"),
        ("The contract agreement requires legal compliance", "legal"),
    ]
    for text, expected in samples:
        detected = detector.detect_domain_from_text(text)
        status = _status(detected, expected)
        print(f"{status} '{_preview(text)}' -> {detected} (expected: {expected})")

    print()


def _demo_path_detection(detector):
    """Classify sample media paths with ``detector.detect_domain_from_path``.

    The last sample's file name carries no domain keyword, so the expected
    result for it is ``None``.
    """
    _print_section("📁 Path-based Domain Detection:")

    samples = [
        ("data/media/medical_interview_patient_123.wav", "medical"),
        ("data/media/tech_tutorial_python_programming.mp3", "technical"),
        ("data/media/research_presentation_university_lecture.wav", "academic"),
        ("data/media/legal_deposition_case_456.mp4", "legal"),
        ("data/media/recording_001.wav", None),  # No domain indicators
    ]
    for path_str, expected in samples:
        path = Path(path_str)
        detected = detector.detect_domain_from_path(path)
        status = _status(detected, expected)
        # f"{None}" renders as "None", matching the old dedicated None branch.
        print(f"{status} '{path.name}' -> {detected} (expected: {expected})")

    print()


def _demo_domain_probabilities(detector):
    """Print ``get_domain_probabilities`` scores for one sentence, highest first.

    Assumes the detector returns a mapping of domain name -> numeric score
    (it is iterated with ``.items()`` and formatted with ``:.3f``).
    """
    _print_section("📊 Domain Probability Scoring:")

    sample_text = "The patient requires immediate medical attention for diagnosis"
    probabilities = detector.get_domain_probabilities(sample_text)

    print(f"Text: '{sample_text}'")
    print("Domain probabilities:")
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    for domain, prob in ranked:
        print(f" {domain}: {prob:.3f}")

    print()


def _demo_pipeline_integration():
    """Report whether an auto-detecting pipeline exposes a domain detector."""
    _print_section("🔗 Pipeline Integration Demo:")

    domain_manager = DomainAdaptationManager()
    pipeline = MultiPassTranscriptionPipeline(
        domain_adapter=domain_manager,
        auto_detect_domain=True,
    )

    has_detector = pipeline.domain_detector is not None
    print(f"Pipeline auto-detect enabled: {pipeline.auto_detect_domain}")
    print(f"Domain detector initialized: {has_detector}")

    # Use the same `is not None` test as the line above so the two reports
    # can never contradict each other (a plain truthiness check could).
    if has_detector:
        print(f"Available domains: {pipeline.domain_detector.domains}")
        print("✅ Domain detection is properly integrated into the pipeline")
    else:
        print("❌ Domain detection is not properly integrated")

    print()


def _demo_confidence_thresholds(detector):
    """Show ``detector.detect_domain`` output for one sentence at several thresholds.

    NOTE(review): ``threshold`` is presumably a minimum confidence below which
    no domain is returned — confirm against DomainDetector.detect_domain.
    """
    _print_section("🎯 Confidence Threshold Testing:")

    test_text = "The patient shows symptoms of diabetes mellitus"
    for threshold in (0.3, 0.5, 0.7, 0.9):
        detected = detector.detect_domain(test_text, threshold=threshold)
        print(f"Threshold {threshold}: {detected}")

    print()
|
|
|
|
|
|
def demo_rule_based_fallback():
    """Print ``DomainDetector.detect_domain`` results from a freshly built detector.

    The detector is constructed and used with no training step in between.
    The demo's premise is that such an instance falls back to rule-based
    detection. That behaviour lives inside DomainDetector and is not checked
    here (TODO confirm). The closing "✅" line is printed unconditionally; it
    is a banner, not an assertion.
    """
    print("\n🔄 Rule-based Detection Fallback Demo")
    print("=" * 50)

    detector = DomainDetector()

    # Nothing in this function trains the detector, so per the demo's premise
    # detect_domain() should be answered by the rule-based path.
    print("Testing with untrained detector (ML model not trained):")

    test_cases = [
        "The patient needs immediate medical attention",
        "Implement the singleton pattern for thread safety",
        "The research methodology follows quantitative analysis",
        "This is a general conversation about the weather",
    ]

    preview_width = 40
    for text in test_cases:
        detected_domain = detector.detect_domain(text)
        # Only mark the preview with "..." when it was actually clipped.
        if len(text) <= preview_width:
            preview = text
        else:
            preview = text[:preview_width] + "..."
        print(f" '{preview}' -> {detected_domain}")

    print("\n✅ Rule-based fallback working correctly!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the demos in sequence. Any failure is logged as a one-line summary
    # and then re-raised, so the full traceback and non-zero exit status
    # are preserved.
    try:
        for run_demo in (demo_domain_detection, demo_rule_based_fallback):
            run_demo()
    except Exception as exc:
        logger.error(f"Demo failed: {exc}")
        raise
|
|
|