# trax/examples/multi_pass_integration_demo.py

# 259 lines
# 8.6 KiB
# Python

#!/usr/bin/env python3
"""Demonstration of MultiPassTranscriptionPipeline integration with DomainEnhancementPipeline.
This script shows how the domain-specific enhancement pipeline integrates with
the multi-pass transcription pipeline for Task 8.3.
"""
import asyncio
import logging
from pathlib import Path
from typing import Dict, Any
# Configure logging at import time with an INFO threshold so the demos'
# failure reports are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; each demo coroutine below
# reports its failures through it.
logger = logging.getLogger(__name__)
async def demo_multi_pass_with_domain_enhancement():
    """Demonstrate the integration between MultiPassTranscriptionPipeline and DomainEnhancementPipeline.

    Builds a "technical" DomainEnhancementConfig, creates a
    MultiPassTranscriptionPipeline with it, runs the pipeline's enhancement
    pass over three hard-coded sample segments and prints the segments before
    and after, followed by a short summary of the pipeline's attributes.

    Returns:
        bool: True when every step completed, False when any step raised
        (including a failure to import the project modules). The failure is
        logged together with its traceback.
    """
    try:
        # Imported inside the try block so that an unimportable project
        # package is reported as a failed demo instead of aborting the script.
        from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
        from src.services.domain_enhancement import DomainEnhancementConfig

        print("🚀 MultiPassTranscriptionPipeline + DomainEnhancementPipeline Integration Demo")
        print("=" * 70)

        # Create domain enhancement configuration
        config = DomainEnhancementConfig(
            domain="technical",
            enable_terminology_enhancement=True,
            enable_citation_handling=True,
            enable_formatting_optimization=True,
            quality_threshold=0.7,
            max_enhancement_iterations=3,
        )
        print(f"📋 Domain Enhancement Config: {config}")
        print()

        # Create the integrated pipeline
        pipeline = MultiPassTranscriptionPipeline(
            auto_detect_domain=True,
            domain_enhancement_config=config,
        )
        print(f"🔧 Pipeline created with domain enhancement: {pipeline.domain_enhancement_config is not None}")
        print()

        # Test segments for enhancement
        test_segments = [
            {
                "text": "The algorithm uses machine learning to process data",
                "start": 0.0,
                "end": 1.0,
                "confidence": 0.9,
            },
            {
                "text": "We need to optimize the neural network architecture",
                "start": 1.0,
                "end": 2.0,
                "confidence": 0.85,
            },
            {
                "text": "The API endpoints should follow RESTful principles",
                "start": 2.0,
                "end": 3.0,
                "confidence": 0.88,
            },
        ]
        print("📝 Test Segments:")
        for i, seg in enumerate(test_segments, 1):
            print(f" {i}. [{seg['start']:.1f}s - {seg['end']:.1f}s] {seg['text']}")
        print()

        # Perform domain enhancement
        print("🔄 Performing domain enhancement...")
        enhanced_segments = await pipeline._perform_enhancement_pass(
            test_segments,
            domain="technical",
        )
        print("✅ Enhancement completed!")
        print()

        # Display results; the enhancement_* keys are optional and only
        # printed when the pipeline attached them to the segment.
        print("📊 Enhanced Segments:")
        for i, seg in enumerate(enhanced_segments, 1):
            print(f" {i}. [{seg['start']:.1f}s - {seg['end']:.1f}s]")
            print(f" Text: {seg['text']}")
            print(f" Domain: {seg.get('domain', 'unknown')}")
            if 'enhancement_confidence' in seg:
                print(f" Enhancement Confidence: {seg['enhancement_confidence']:.3f}")
            if 'enhancement_improvements' in seg:
                # str() each item so a non-string entry cannot make join() raise.
                print(f" Improvements: {', '.join(map(str, seg['enhancement_improvements']))}")
            if 'enhancement_terminology_corrections' in seg:
                print(f" Terminology Corrections: {', '.join(map(str, seg['enhancement_terminology_corrections']))}")
            if 'enhancement_quality_metrics' in seg:
                print(f" Quality Metrics: {seg['enhancement_quality_metrics']}")
            print()

        # Show pipeline state
        print("🔍 Pipeline State:")
        print(f" Domain Enhancement Pipeline: {pipeline.domain_enhancement_pipeline is not None}")
        print(f" Auto-detect Domain: {pipeline.auto_detect_domain}")
        print(f" Domain Enhancement Config: {pipeline.domain_enhancement_config is not None}")
        return True
    except Exception as e:
        # Top-level boundary of the demo: report the failure with its
        # traceback (logger.exception) and signal it through the return value.
        logger.exception("Demo failed: %s", e)
        return False
async def demo_domain_switching():
    """Demonstrate how the pipeline handles different domains.

    Creates a MultiPassTranscriptionPipeline configured with a "medical"
    DomainEnhancementConfig, runs the enhancement pass over two hard-coded
    medical sample segments and prints the domain reported on the first
    enhanced segment.

    Returns:
        bool: True when the enhancement pass completed and returned at least
        one segment, False otherwise. Failures are logged.
    """
    try:
        # Imported inside the try block so an unimportable project package is
        # reported as a failed demo rather than aborting the script.
        from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
        from src.services.domain_enhancement import DomainEnhancementConfig

        print("\n🔄 Domain Switching Demo")
        print("=" * 40)

        # Create pipeline with medical domain config
        medical_config = DomainEnhancementConfig(
            domain="medical",
            enable_terminology_enhancement=True,
            enable_citation_handling=False,
            enable_formatting_optimization=True,
            quality_threshold=0.8,
            max_enhancement_iterations=2,
        )
        pipeline = MultiPassTranscriptionPipeline(
            auto_detect_domain=True,
            domain_enhancement_config=medical_config,
        )

        # Test medical content
        medical_segments = [
            {
                "text": "The patient exhibits symptoms of hypertension",
                "start": 0.0,
                "end": 1.0,
            },
            {
                "text": "We need to monitor blood pressure regularly",
                "start": 1.0,
                "end": 2.0,
            },
        ]
        print("🏥 Processing medical content...")
        enhanced_medical = await pipeline._perform_enhancement_pass(
            medical_segments,
            domain="medical",
        )

        # An empty result would otherwise surface as an opaque IndexError on
        # the [0] lookup below; report it explicitly as a failed demo.
        if not enhanced_medical:
            logger.error("Domain switching demo failed: enhancement pass returned no segments")
            return False

        print(f"✅ Medical enhancement completed. Domain: {enhanced_medical[0].get('domain', 'unknown')}")
        return True
    except Exception as e:
        # Top-level boundary of the demo: keep the traceback in the log and
        # signal the failure through the return value.
        logger.exception("Domain switching demo failed: %s", e)
        return False
async def demo_fallback_behavior():
    """Demonstrate the enhancement pass on a pipeline built without a domain config.

    Creates a MultiPassTranscriptionPipeline with ``auto_detect_domain=False``
    and ``domain_enhancement_config=None``, runs the enhancement pass over one
    hard-coded "general" segment and prints the domain and text of the first
    returned segment.

    Returns:
        bool: True when the enhancement pass completed and returned at least
        one segment, False otherwise. Failures are logged.
    """
    try:
        # Imported inside the try block so an unimportable project package is
        # reported as a failed demo rather than aborting the script.
        from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline

        print("\n🛡️ Fallback Behavior Demo")
        print("=" * 40)

        # Create pipeline without domain enhancement config
        pipeline = MultiPassTranscriptionPipeline(
            auto_detect_domain=False,
            domain_enhancement_config=None,
        )

        # Test segments
        test_segments = [
            {
                "text": "General content without specific domain",
                "start": 0.0,
                "end": 1.0,
            },
        ]
        print("📝 Processing general content...")
        enhanced_segments = await pipeline._perform_enhancement_pass(
            test_segments,
            domain="general",
        )

        # An empty result would otherwise surface as an opaque IndexError on
        # the [0] lookups below; report it explicitly as a failed demo.
        if not enhanced_segments:
            logger.error("Fallback demo failed: enhancement pass returned no segments")
            return False

        first_segment = enhanced_segments[0]
        print(f"✅ Fallback enhancement completed. Domain: {first_segment.get('domain', 'unknown')}")
        print(f" Text: {first_segment['text']}")
        return True
    except Exception as e:
        # Top-level boundary of the demo: keep the traceback in the log and
        # signal the failure through the return value.
        logger.exception("Fallback demo failed: %s", e)
        return False
async def main():
    """Run all demonstrations and print a pass/fail summary.

    Each demo coroutine is awaited in turn. A demo counts as passed when it
    returns a truthy value; a demo that raises is recorded as failed and the
    remaining demos still run.

    Returns:
        bool: True when every demo passed, False otherwise.
    """
    print("🎯 MultiPassTranscriptionPipeline + DomainEnhancementPipeline Integration")
    print("=" * 70)
    print()

    # Run demonstrations
    demos = [
        ("Basic Integration", demo_multi_pass_with_domain_enhancement),
        ("Domain Switching", demo_domain_switching),
        ("Fallback Behavior", demo_fallback_behavior),
    ]
    results = []
    for name, demo_func in demos:
        print(f"🎬 Running: {name}")
        print("-" * 40)
        try:
            result = await demo_func()
            results.append((name, result))
            print(f"{name}: {'SUCCESS' if result else 'FAILED'}")
        except Exception as e:
            # One failing demo must not prevent the others from running.
            print(f"{name}: FAILED - {e}")
            results.append((name, False))
        print()

    # Summary
    print("📋 Demo Summary")
    print("=" * 40)
    for name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f" {name}: {status}")

    success_count = sum(1 for _, result in results if result)
    total_count = len(results)
    print(f"\n🎉 Overall: {success_count}/{total_count} demos passed")

    all_passed = success_count == total_count
    if all_passed:
        print("🚀 All integrations working correctly!")
    else:
        print("⚠️ Some integrations need attention.")
    return all_passed


if __name__ == "__main__":
    # Exit non-zero when any demo failed so shells and CI jobs can detect it.
    raise SystemExit(0 if asyncio.run(main()) else 1)