#!/usr/bin/env python3
"""Demonstration of MultiPassTranscriptionPipeline integration with DomainEnhancementPipeline.

This script shows how the domain-specific enhancement pipeline integrates with
the multi-pass transcription pipeline for Task 8.3.
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Dict, Any
|
|
|
|
# Configure logging: set the root logger to INFO when this module is loaded so
# that the error reports emitted by the demos below are shown.
logging.basicConfig(level=logging.INFO)

# Module-level logger the demo coroutines use to report failures.
logger = logging.getLogger(__name__)
|
|
|
|
async def demo_multi_pass_with_domain_enhancement():
    """Demonstrate MultiPassTranscriptionPipeline with a DomainEnhancementConfig.

    Builds a "technical" domain configuration, constructs the pipeline with
    it, runs the enhancement pass over three sample segments, and prints the
    enhanced segments followed by the pipeline's enhancement-related state.

    Returns:
        bool: True when every step completed, False when any step raised
        (including a failed import of the project modules). Failures are
        logged together with their traceback instead of being propagated.
    """
    try:
        # Imported inside the ``try`` so that an unavailable project package is
        # reported as a failed demo rather than aborting the whole script.
        from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
        from src.services.domain_enhancement import DomainEnhancementConfig

        print("🚀 MultiPassTranscriptionPipeline + DomainEnhancementPipeline Integration Demo")
        print("=" * 70)

        # Create domain enhancement configuration
        config = DomainEnhancementConfig(
            domain="technical",
            enable_terminology_enhancement=True,
            enable_citation_handling=True,
            enable_formatting_optimization=True,
            quality_threshold=0.7,
            max_enhancement_iterations=3,
        )

        print(f"📋 Domain Enhancement Config: {config}")
        print()

        # Create the integrated pipeline
        pipeline = MultiPassTranscriptionPipeline(
            auto_detect_domain=True,
            domain_enhancement_config=config,
        )

        print(f"🔧 Pipeline created with domain enhancement: {pipeline.domain_enhancement_config is not None}")
        print()

        # Test segments for enhancement
        test_segments = [
            {
                "text": "The algorithm uses machine learning to process data",
                "start": 0.0,
                "end": 1.0,
                "confidence": 0.9,
            },
            {
                "text": "We need to optimize the neural network architecture",
                "start": 1.0,
                "end": 2.0,
                "confidence": 0.85,
            },
            {
                "text": "The API endpoints should follow RESTful principles",
                "start": 2.0,
                "end": 3.0,
                "confidence": 0.88,
            },
        ]

        print("📝 Test Segments:")
        for i, seg in enumerate(test_segments, 1):
            print(f" {i}. [{seg['start']:.1f}s - {seg['end']:.1f}s] {seg['text']}")
        print()

        # Perform domain enhancement
        # NOTE(review): this calls an underscore-prefixed pipeline method
        # directly; confirm there is no public entry point the demo should use.
        print("🔄 Performing domain enhancement...")
        enhanced_segments = await pipeline._perform_enhancement_pass(
            test_segments,
            domain="technical",
        )

        print("✅ Enhancement completed!")
        print()

        # Optional per-segment fields, printed only when present:
        # (segment key, printed label, value renderer), in display order.
        optional_fields = (
            ("enhancement_confidence", "Enhancement Confidence", lambda v: f"{v:.3f}"),
            ("enhancement_improvements", "Improvements", ", ".join),
            ("enhancement_terminology_corrections", "Terminology Corrections", ", ".join),
            ("enhancement_quality_metrics", "Quality Metrics", lambda v: f"{v}"),
        )

        # Display results
        print("📊 Enhanced Segments:")
        for i, seg in enumerate(enhanced_segments, 1):
            print(f" {i}. [{seg['start']:.1f}s - {seg['end']:.1f}s]")
            print(f" Text: {seg['text']}")
            print(f" Domain: {seg.get('domain', 'unknown')}")

            for key, label, render in optional_fields:
                if key in seg:
                    print(f" {label}: {render(seg[key])}")

            print()

        # Show pipeline state
        print("🔍 Pipeline State:")
        print(f" Domain Enhancement Pipeline: {pipeline.domain_enhancement_pipeline is not None}")
        print(f" Auto-detect Domain: {pipeline.auto_detect_domain}")
        print(f" Domain Enhancement Config: {pipeline.domain_enhancement_config is not None}")

        return True

    except Exception as e:
        # Demo boundary: any failure becomes a False result. logger.exception
        # keeps the traceback so the failing step can be identified.
        logger.exception("Demo failed: %s", e)
        return False
|
|
|
|
async def demo_domain_switching():
    """Run the enhancement pass with a medical-domain configuration.

    Constructs a pipeline configured for the "medical" domain, runs the
    enhancement pass over two sample medical segments, and prints the domain
    reported on the first enhanced segment.

    Returns:
        bool: True when the pass completed and returned at least one segment,
        False otherwise. Failures are logged (with traceback when an
        exception was raised) instead of being propagated.
    """
    try:
        # Imported inside the ``try`` so that an unavailable project package is
        # reported as a failed demo rather than aborting the whole script.
        from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline
        from src.services.domain_enhancement import DomainEnhancementConfig

        print("\n🔄 Domain Switching Demo")
        print("=" * 40)

        # Create pipeline with medical domain config
        medical_config = DomainEnhancementConfig(
            domain="medical",
            enable_terminology_enhancement=True,
            enable_citation_handling=False,
            enable_formatting_optimization=True,
            quality_threshold=0.8,
            max_enhancement_iterations=2,
        )

        pipeline = MultiPassTranscriptionPipeline(
            auto_detect_domain=True,
            domain_enhancement_config=medical_config,
        )

        # Test medical content
        medical_segments = [
            {
                "text": "The patient exhibits symptoms of hypertension",
                "start": 0.0,
                "end": 1.0,
            },
            {
                "text": "We need to monitor blood pressure regularly",
                "start": 1.0,
                "end": 2.0,
            },
        ]

        print("🏥 Processing medical content...")
        enhanced_medical = await pipeline._perform_enhancement_pass(
            medical_segments,
            domain="medical",
        )

        # The summary line reads the first segment; report an empty result
        # explicitly instead of surfacing it as an opaque IndexError.
        if not enhanced_medical:
            logger.error("Domain switching demo failed: enhancement pass returned no segments")
            return False

        print(f"✅ Medical enhancement completed. Domain: {enhanced_medical[0].get('domain', 'unknown')}")

        return True

    except Exception as e:
        # Demo boundary: any failure becomes a False result. logger.exception
        # keeps the traceback so the failing step can be identified.
        logger.exception("Domain switching demo failed: %s", e)
        return False
|
|
|
|
async def demo_fallback_behavior():
    """Run the enhancement pass on a pipeline built without an enhancement config.

    Constructs the pipeline with ``auto_detect_domain=False`` and
    ``domain_enhancement_config=None``, runs the enhancement pass over one
    generic segment with ``domain="general"``, and prints the domain and text
    of the first returned segment.

    Returns:
        bool: True when the pass completed and returned at least one segment,
        False otherwise. Failures are logged (with traceback when an
        exception was raised) instead of being propagated.
    """
    try:
        # Imported inside the ``try`` so that an unavailable project package is
        # reported as a failed demo rather than aborting the whole script.
        from src.services.multi_pass_transcription import MultiPassTranscriptionPipeline

        print("\n🛡️ Fallback Behavior Demo")
        print("=" * 40)

        # Create pipeline without domain enhancement config
        pipeline = MultiPassTranscriptionPipeline(
            auto_detect_domain=False,
            domain_enhancement_config=None,
        )

        # Test segments
        test_segments = [
            {
                "text": "General content without specific domain",
                "start": 0.0,
                "end": 1.0,
            },
        ]

        print("📝 Processing general content...")
        enhanced_segments = await pipeline._perform_enhancement_pass(
            test_segments,
            domain="general",
        )

        # Both summary lines read the first segment; report an empty result
        # explicitly instead of surfacing it as an opaque IndexError.
        if not enhanced_segments:
            logger.error("Fallback demo failed: enhancement pass returned no segments")
            return False

        first = enhanced_segments[0]
        print(f"✅ Fallback enhancement completed. Domain: {first.get('domain', 'unknown')}")
        print(f" Text: {first['text']}")

        return True

    except Exception as e:
        # Demo boundary: any failure becomes a False result. logger.exception
        # keeps the traceback so the failing step can be identified.
        logger.exception("Fallback demo failed: %s", e)
        return False
|
|
|
|
async def main():
    """Run every demonstration in sequence and print a pass/fail summary.

    Each demo coroutine is awaited in turn; a demo counts as passed when it
    returns a truthy value and as failed when it returns a falsy value or
    raises.

    Returns:
        bool: True when all demos passed, False otherwise.
    """
    print("🎯 MultiPassTranscriptionPipeline + DomainEnhancementPipeline Integration")
    print("=" * 70)
    print()

    # Run demonstrations
    demos = [
        ("Basic Integration", demo_multi_pass_with_domain_enhancement),
        ("Domain Switching", demo_domain_switching),
        ("Fallback Behavior", demo_fallback_behavior),
    ]

    results = []
    for name, demo_func in demos:
        print(f"🎬 Running: {name}")
        print("-" * 40)

        # Only the awaited call is guarded, so one failing demo cannot prevent
        # the remaining demos from running.
        try:
            result = await demo_func()
        except Exception as e:
            print(f"❌ {name}: FAILED - {e}")
            results.append((name, False))
        else:
            results.append((name, result))
            print(f"✅ {name}: {'SUCCESS' if result else 'FAILED'}")

        print()

    # Summary
    print("📋 Demo Summary")
    print("=" * 40)
    for name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f" {name}: {status}")

    success_count = sum(1 for _, result in results if result)
    total_count = len(results)
    all_passed = success_count == total_count

    print(f"\n🎉 Overall: {success_count}/{total_count} demos passed")

    if all_passed:
        print("🚀 All integrations working correctly!")
    else:
        print("⚠️ Some integrations need attention.")

    return all_passed


if __name__ == "__main__":
    # Propagate the overall outcome as the process exit status so shells and
    # CI jobs can detect a failed demo run (0 = all passed, 1 = any failure).
    raise SystemExit(0 if asyncio.run(main()) else 1)
|