trax/examples/diarization_pipeline_exampl...

322 lines
11 KiB
Python

"""Example usage of the diarization pipeline components.
This script demonstrates how to use the DiarizationManager, SpeakerProfileManager,
and ParallelProcessor for speaker diarization and profile management.
"""
import logging
import time
from pathlib import Path
from typing import List
from src.services.diarization_types import (
DiarizationConfig, ParallelProcessingConfig
)
from src.services.diarization_service import DiarizationManager
from src.services.speaker_profile_manager import SpeakerProfileManager
from src.services.parallel_processor import ParallelProcessor
# Module-wide logging: configure the root logger at INFO level (a no-op if the
# root logger already has handlers) and create this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def setup_services():
    """Build the three pipeline services from the demo configuration.

    Returns:
        A ``(diarization_manager, profile_manager, parallel_processor)``
        tuple, constructed in that order.
    """
    # Diarization settings: pyannote 3.0 pipeline, device="auto",
    # memory optimization on, 0.5 for both min_duration and threshold.
    diarization_cfg = DiarizationConfig(
        model_path="pyannote/speaker-diarization-3.0",
        device="auto",
        memory_optimization=True,
        min_duration=0.5,
        threshold=0.5,
    )

    # Parallel settings: 2 workers, 300 s timeout, 6000 MB memory limit.
    parallel_cfg = ParallelProcessingConfig(
        max_workers=2,
        timeout_seconds=300,
        memory_limit_mb=6000,
    )

    # Tuple elements are evaluated left to right, so construction order matches
    # the returned order.
    return (
        DiarizationManager(diarization_cfg),
        SpeakerProfileManager(),
        ParallelProcessor(parallel_cfg),
    )
def process_single_file(
    audio_path: Path,
    diarization_manager: DiarizationManager,
    profile_manager: SpeakerProfileManager
):
    """Diarize one audio file and register mock profiles for its speakers.

    Runs ``diarization_manager.process_audio`` on *audio_path*, logs the
    elapsed time, speaker count and confidence score, then creates a profile
    (with a random 512-element placeholder embedding) for each of the first
    three *distinct* speaker ids found in the result's segments.

    Args:
        audio_path: Audio file to diarize.
        diarization_manager: Service that performs the diarization.
        profile_manager: Store in which the speaker profiles are created.

    Returns:
        The diarization result, or ``None`` if any step raised. The error is
        logged with its traceback and not re-raised.
    """
    max_profiles = 3  # demo cap on how many speakers get a profile
    logger.info(f"Processing: {audio_path.name}")
    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = diarization_manager.process_audio(audio_path)
        processing_time = time.perf_counter() - start_time
        logger.info(f"Diarization completed in {processing_time:.2f}s")
        logger.info(f"Found {result.speaker_count} speakers")
        logger.info(f"Confidence: {result.confidence_score:.2f}")

        # Create speaker profiles (mock embeddings for demonstration)
        import numpy as np

        # Several segments may share a speaker_id. Collapse them to distinct
        # ids (first-appearance order) so each speaker is added exactly once
        # and "Speaker N" numbering refers to speakers, not segments.
        speaker_ids = list(dict.fromkeys(seg.speaker_id for seg in result.segments))
        for i, speaker_id in enumerate(speaker_ids[:max_profiles]):
            embedding = np.random.rand(512)  # Mock embedding
            profile = profile_manager.add_speaker(
                speaker_id,
                embedding,
                name=f"Speaker {i+1}"
            )
            logger.info(f"Created profile for {profile.speaker_id}")
        return result
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Failed to process {audio_path}: {e}")
        return None
def process_with_profiles(
    audio_path: Path,
    diarization_manager: DiarizationManager,
    profile_manager: SpeakerProfileManager
):
    """Diarize *audio_path* and try to match each speaker to a stored profile.

    For every distinct speaker id in the diarization segments, a random
    512-element mock embedding is passed to
    ``profile_manager.find_similar_speakers`` (threshold 0.7). The
    highest-scoring candidate, if any, is logged.

    Args:
        audio_path: Audio file to diarize.
        diarization_manager: Service that performs the diarization.
        profile_manager: Store searched for similar speakers.

    Returns:
        The diarization result, or ``None`` if any step raised. The error is
        logged with its traceback and not re-raised.
    """
    logger.info(f"Processing with profiles: {audio_path.name}")
    try:
        result = diarization_manager.process_audio(audio_path)

        import numpy as np

        # Match once per speaker, not once per segment. Several segments may
        # share a speaker_id, and re-matching them with fresh random
        # embeddings only yields duplicate, contradictory log lines.
        for speaker_id in dict.fromkeys(seg.speaker_id for seg in result.segments):
            embedding = np.random.rand(512)  # Mock embedding for demonstration
            matches = profile_manager.find_similar_speakers(embedding, threshold=0.7)
            if not matches:
                logger.info(f"No match found for {speaker_id}")
                continue
            # Pick the best score explicitly rather than assuming
            # find_similar_speakers returns matches sorted by similarity.
            best_match = max(matches, key=lambda m: m.similarity_score)
            # NOTE(review): demonstrate_speaker_profiles reads match.profile.name;
            # confirm the match type also exposes speaker_id directly.
            logger.info(f"Matched {speaker_id} to {best_match.speaker_id} "
                        f"(similarity: {best_match.similarity_score:.2f})")
        return result
    except Exception as e:
        logger.exception(f"Failed to process with profiles: {e}")
        return None
def process_parallel(
    audio_path: Path,
    parallel_processor: ParallelProcessor
):
    """Process audio using parallel diarization and transcription.

    Logs the caller-measured elapsed time, the processor-reported
    ``result.processing_time`` and, when ``result.merged_result`` is present,
    its speaker count and number of segments.

    Args:
        audio_path: Audio file to process.
        parallel_processor: Processor whose ``process_file`` is invoked.

    Returns:
        The processor's result object (also when ``result.success`` is false),
        or ``None`` if an exception was raised. The exception is logged with
        its traceback and not re-raised.
    """
    logger.info(f"Parallel processing: {audio_path.name}")
    try:
        # Monotonic clock: elapsed time is immune to system clock changes.
        start_time = time.perf_counter()
        result = parallel_processor.process_file(audio_path)
        processing_time = time.perf_counter() - start_time

        if not result.success:
            logger.error(f"Parallel processing failed: {result.error_message}")
            return result

        logger.info(f"Parallel processing completed in {processing_time:.2f}s")
        logger.info(f"Total processing time: {result.processing_time:.2f}s")
        if result.merged_result:
            # merged_result is read as a mapping (.get); keys may be absent.
            logger.info(f"Speaker count: {result.merged_result.get('speaker_count', 0)}")
            logger.info(f"Segments: {len(result.merged_result.get('segments', []))}")
        return result
    except Exception as e:
        logger.exception(f"Failed to process in parallel: {e}")
        return None
def batch_process(
    audio_paths: List[Path],
    parallel_processor: ParallelProcessor
):
    """Process multiple audio files in one ``process_batch`` call.

    Logs the elapsed time, the number of successful and failed results, and
    the ``success_rate`` / ``average_processing_time`` values reported by
    ``parallel_processor.get_processing_stats()`` (0 when a key is missing).

    Args:
        audio_paths: Files to process.
        parallel_processor: Processor that runs the batch.

    Returns:
        The results returned by ``process_batch``, or ``[]`` if anything
        raised. The error is logged with its traceback and not re-raised.
    """
    logger.info(f"Batch processing {len(audio_paths)} files")
    try:
        # Monotonic clock: elapsed time is immune to system clock changes.
        start_time = time.perf_counter()
        results = parallel_processor.process_batch(audio_paths)
        total_time = time.perf_counter() - start_time

        # Analyze results
        successful = sum(1 for r in results if r.success)
        failed = len(results) - successful
        logger.info(f"Batch processing completed in {total_time:.2f}s")
        logger.info(f"Successful: {successful}, Failed: {failed}")

        # Get processing statistics
        stats = parallel_processor.get_processing_stats()
        logger.info(f"Success rate: {stats.get('success_rate', 0):.2f}")
        logger.info(f"Average processing time: {stats.get('average_processing_time', 0):.2f}s")
        return results
    except Exception as e:
        logger.exception(f"Batch processing failed: {e}")
        return []
def demonstrate_speaker_profiles(profile_manager: SpeakerProfileManager):
    """Exercise profile creation, similarity search and profile statistics.

    Adds three named profiles with random 512-element mock embeddings and
    searches for profiles similar to another random embedding (threshold 0.5).
    It then logs the two highest-scoring matches and the ``total_profiles`` /
    ``profiles_with_embeddings`` statistics. Errors are logged with their
    traceback and not re-raised.

    Args:
        profile_manager: Store that receives the demo profiles.
    """
    logger.info("Demonstrating speaker profile features")
    try:
        import numpy as np

        # Add some test profiles
        speakers = [
            ("alice", "Alice Johnson"),
            ("bob", "Bob Smith"),
            ("charlie", "Charlie Brown")
        ]
        for speaker_id, name in speakers:
            embedding = np.random.rand(512)
            profile = profile_manager.add_speaker(speaker_id, embedding, name=name)
            logger.info(f"Added profile: {profile.name} ({profile.speaker_id})")

        # Test similarity matching
        test_embedding = np.random.rand(512)
        matches = profile_manager.find_similar_speakers(test_embedding, threshold=0.5)
        logger.info(f"Found {len(matches)} similar speakers")
        # Rank explicitly so "top 2" holds even if find_similar_speakers does
        # not return its matches sorted by similarity (sorted() is stable,
        # so ties keep their original order).
        top_matches = sorted(matches, key=lambda m: m.similarity_score, reverse=True)[:2]
        for match in top_matches:
            logger.info(f"Match: {match.profile.name} (similarity: {match.similarity_score:.2f})")

        # Get profile statistics
        stats = profile_manager.get_profile_stats()
        logger.info(f"Total profiles: {stats['total_profiles']}")
        logger.info(f"Profiles with embeddings: {stats['profiles_with_embeddings']}")
    except Exception as e:
        logger.exception(f"Speaker profile demonstration failed: {e}")
def performance_comparison(
    audio_path: Path,
    diarization_manager: DiarizationManager,
    parallel_processor: ParallelProcessor
):
    """Compare sequential vs parallel processing time on one file.

    Times ``diarization_manager.process_audio`` and then
    ``parallel_processor.process_file`` on *audio_path*. If the parallel run
    succeeded, it logs both timings and their ratio and passes the timings to
    ``parallel_processor.estimate_speedup``; otherwise it logs a warning.
    Errors are logged with their traceback and not re-raised.

    Args:
        audio_path: Audio file used for both runs.
        diarization_manager: Service timed for the sequential run.
        parallel_processor: Processor timed for the parallel run.
    """
    logger.info("Comparing sequential vs parallel processing")
    try:
        # NOTE(review): the sequential leg runs diarization only, while the
        # parallel processor is described in this file as doing diarization
        # *and* transcription. Confirm both legs do comparable work before
        # trusting the speedup figure.

        # Sequential processing (result unused; only the timing matters).
        # perf_counter is monotonic, unlike time.time().
        logger.info("Running sequential processing...")
        start_time = time.perf_counter()
        diarization_manager.process_audio(audio_path)
        sequential_time = time.perf_counter() - start_time

        # Parallel processing
        logger.info("Running parallel processing...")
        start_time = time.perf_counter()
        parallel_result = parallel_processor.process_file(audio_path)
        parallel_time = time.perf_counter() - start_time

        if not (parallel_result.success and parallel_time > 0):
            logger.warning("Could not calculate speedup due to parallel processing failure")
            return

        speedup = sequential_time / parallel_time
        logger.info(f"Sequential time: {sequential_time:.2f}s")
        logger.info(f"Parallel time: {parallel_time:.2f}s")
        logger.info(f"Speedup: {speedup:.2f}x")
        # Update processor stats
        parallel_processor.estimate_speedup(sequential_time, parallel_time)
    except Exception as e:
        logger.exception(f"Performance comparison failed: {e}")
def main():
    """Main function demonstrating the diarization pipeline.

    Creates the services once, then runs each demonstration step on the
    sample files under ``tests/`` that exist. Returns early (after cleanup)
    if none exist. Each step function in this module catches and logs its
    own exceptions.

    All three services are always cleaned up. A ``cleanup()`` that raises is
    logged and does not prevent the remaining services from being cleaned up.
    """
    logger.info("Starting diarization pipeline demonstration")
    # Set up services
    diarization_manager, profile_manager, parallel_processor = setup_services()
    try:
        # Example audio files (adjust paths as needed)
        audio_files = [
            Path("tests/sample_5s.wav"),
            Path("tests/sample_30s.mp3"),
            Path("tests/sample_2m.mp4")
        ]
        # Filter to existing files
        existing_files = [f for f in audio_files if f.exists()]
        if not existing_files:
            logger.warning("No test audio files found. Please adjust paths.")
            return
        logger.info(f"Found {len(existing_files)} audio files for processing")

        # Demonstrate different processing approaches on the first 2 files
        for audio_file in existing_files[:2]:
            logger.info(f"\n--- Processing {audio_file.name} ---")
            process_single_file(audio_file, diarization_manager, profile_manager)
            process_with_profiles(audio_file, diarization_manager, profile_manager)
            process_parallel(audio_file, parallel_processor)

        # Batch processing
        if len(existing_files) > 1:
            logger.info("\n--- Batch Processing ---")
            batch_process(existing_files, parallel_processor)

        # Speaker profile demonstration
        logger.info("\n--- Speaker Profile Demonstration ---")
        demonstrate_speaker_profiles(profile_manager)

        # Performance comparison. existing_files is non-empty here because of
        # the early return above, so no further guard is needed.
        logger.info("\n--- Performance Comparison ---")
        performance_comparison(existing_files[0], diarization_manager, parallel_processor)

        logger.info("\nDemonstration completed successfully!")
    except Exception as e:
        logger.exception(f"Demonstration failed: {e}")
    finally:
        # Cleanup
        logger.info("Cleaning up resources...")
        # Clean up each service independently. Previously an exception from
        # one cleanup() skipped the remaining ones and leaked their resources.
        for service in (diarization_manager, profile_manager, parallel_processor):
            try:
                service.cleanup()
            except Exception:
                logger.exception(f"Cleanup failed for {type(service).__name__}")
# Script entry point: run the demonstration only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()