322 lines
11 KiB
Python
322 lines
11 KiB
Python
"""Example usage of the diarization pipeline components.
|
|
|
|
This script demonstrates how to use the DiarizationManager, SpeakerProfileManager,
and ParallelProcessor for speaker diarization and profile management.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
from src.services.diarization_types import (
|
|
DiarizationConfig, ParallelProcessingConfig
|
|
)
|
|
from src.services.diarization_service import DiarizationManager
|
|
from src.services.speaker_profile_manager import SpeakerProfileManager
|
|
from src.services.parallel_processor import ParallelProcessor
|
|
|
|
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def setup_services():
    """Build the three services used throughout the demonstration.

    Returns:
        A ``(diarization_manager, profile_manager, parallel_processor)`` tuple.
    """
    # Both configuration objects are created before any service is constructed.
    diarization_settings = DiarizationConfig(
        model_path="pyannote/speaker-diarization-3.0",
        device="auto",
        memory_optimization=True,
        min_duration=0.5,
        threshold=0.5,
    )
    parallel_settings = ParallelProcessingConfig(
        max_workers=2,
        timeout_seconds=300,
        memory_limit_mb=6000,
    )

    # Tuple elements are evaluated left to right, so the services are
    # instantiated in the same order in which they are returned.
    services = (
        DiarizationManager(diarization_settings),
        SpeakerProfileManager(),
        ParallelProcessor(parallel_settings),
    )
    return services
|
|
|
|
|
|
def process_single_file(
    audio_path: Path,
    diarization_manager: DiarizationManager,
    profile_manager: SpeakerProfileManager
):
    """Diarize one audio file and register demo profiles for its speakers.

    Args:
        audio_path: Audio file handed to ``diarization_manager.process_audio``.
        diarization_manager: Service that produces the diarization result.
        profile_manager: Store that receives one profile per distinct speaker
            (at most three), each with a random mock embedding.

    Returns:
        The diarization result, or ``None`` if any step raised an exception
        (the failure is logged, not re-raised).
    """
    logger.info(f"Processing: {audio_path.name}")

    try:
        # Process diarization
        start_time = time.time()
        result = diarization_manager.process_audio(audio_path)
        processing_time = time.time() - start_time

        logger.info(f"Diarization completed in {processing_time:.2f}s")
        logger.info(f"Found {result.speaker_count} speakers")
        logger.info(f"Confidence: {result.confidence_score:.2f}")

        # Create speaker profiles (mock embeddings for demonstration)
        import numpy as np

        # Several segments can belong to the same speaker, so collect the
        # first three *distinct* speaker ids (in order of first appearance)
        # instead of slicing the first three segments.
        max_profiles = 3
        speaker_ids = []
        for segment in result.segments:
            if segment.speaker_id in speaker_ids:
                continue
            speaker_ids.append(segment.speaker_id)
            if len(speaker_ids) == max_profiles:
                break

        for i, speaker_id in enumerate(speaker_ids, start=1):
            embedding = np.random.rand(512)  # Mock embedding
            profile = profile_manager.add_speaker(
                speaker_id,
                embedding,
                name=f"Speaker {i}"
            )
            logger.info(f"Created profile for {profile.speaker_id}")

        return result

    except Exception as e:
        logger.error(f"Failed to process {audio_path}: {e}")
        return None
|
|
|
|
|
|
def process_with_profiles(
    audio_path: Path,
    diarization_manager: DiarizationManager,
    profile_manager: SpeakerProfileManager
):
    """Diarize a file and look up each segment among the stored profiles.

    A random mock embedding is generated per segment and passed to
    ``profile_manager.find_similar_speakers`` with a 0.7 threshold; the
    outcome of every lookup is logged.

    Returns:
        The diarization result, or ``None`` if any step raised an exception.
    """
    logger.info(f"Processing with profiles: {audio_path.name}")

    try:
        result = diarization_manager.process_audio(audio_path)

        import numpy as np

        for segment in result.segments:
            # Stand-in for a real speaker embedding (demonstration only).
            mock_embedding = np.random.rand(512)
            candidates = profile_manager.find_similar_speakers(mock_embedding, threshold=0.7)

            if not candidates:
                logger.info(f"No match found for {segment.speaker_id}")
                continue

            # NOTE(review): the first entry is treated as the best match;
            # presumably results are ordered best-first -- confirm against
            # find_similar_speakers.
            top = candidates[0]
            logger.info(f"Matched {segment.speaker_id} to {top.speaker_id} "
                        f"(similarity: {top.similarity_score:.2f})")

        return result

    except Exception as err:
        logger.error(f"Failed to process with profiles: {err}")
        return None
|
|
|
|
|
|
def process_parallel(
    audio_path: Path,
    parallel_processor: ParallelProcessor
):
    """Run one file through ``parallel_processor.process_file`` and log the outcome.

    Returns:
        The processor's result object (successful or not), or ``None`` if
        the call itself raised an exception.
    """
    logger.info(f"Parallel processing: {audio_path.name}")

    try:
        started = time.time()
        result = parallel_processor.process_file(audio_path)
        elapsed = time.time() - started

        # Failure is reported by the result object rather than by raising.
        if not result.success:
            logger.error(f"Parallel processing failed: {result.error_message}")
            return result

        logger.info(f"Parallel processing completed in {elapsed:.2f}s")
        logger.info(f"Total processing time: {result.processing_time:.2f}s")

        merged = result.merged_result
        if merged:
            logger.info(f"Speaker count: {merged.get('speaker_count', 0)}")
            logger.info(f"Segments: {len(merged.get('segments', []))}")

        return result

    except Exception as err:
        logger.error(f"Failed to process in parallel: {err}")
        return None
|
|
|
|
|
|
def batch_process(
    audio_paths: List[Path],
    parallel_processor: ParallelProcessor
):
    """Send several files through ``process_batch`` and log a summary.

    Returns:
        The results returned by ``parallel_processor.process_batch``, or an
        empty list if anything raised an exception.
    """
    logger.info(f"Batch processing {len(audio_paths)} files")

    try:
        started = time.time()
        results = parallel_processor.process_batch(audio_paths)
        elapsed = time.time() - started

        # Tally outcomes using each result's ``success`` flag.
        succeeded = len([outcome for outcome in results if outcome.success])
        failed = len(results) - succeeded

        logger.info(f"Batch processing completed in {elapsed:.2f}s")
        logger.info(f"Successful: {succeeded}, Failed: {failed}")

        # Aggregate statistics kept by the processor; missing keys log as 0.
        stats = parallel_processor.get_processing_stats()
        logger.info(f"Success rate: {stats.get('success_rate', 0):.2f}")
        logger.info(f"Average processing time: {stats.get('average_processing_time', 0):.2f}s")

        return results

    except Exception as err:
        logger.error(f"Batch processing failed: {err}")
        return []
|
|
|
|
|
|
def demonstrate_speaker_profiles(profile_manager: SpeakerProfileManager):
    """Exercise profile creation, similarity lookup and statistics.

    Adds three sample profiles with random mock embeddings, queries the
    manager with another random embedding (threshold 0.5), and logs the
    profile statistics. Any exception is logged and swallowed.
    """
    logger.info("Demonstrating speaker profile features")

    try:
        import numpy as np

        # Sample speakers: id -> display name (insertion order is kept).
        demo_names = {
            "alice": "Alice Johnson",
            "bob": "Bob Smith",
            "charlie": "Charlie Brown",
        }
        for speaker_id, display_name in demo_names.items():
            vector = np.random.rand(512)
            created = profile_manager.add_speaker(speaker_id, vector, name=display_name)
            logger.info(f"Added profile: {created.name} ({created.speaker_id})")

        # Query with an unrelated random embedding.
        probe = np.random.rand(512)
        hits = profile_manager.find_similar_speakers(probe, threshold=0.5)

        logger.info(f"Found {len(hits)} similar speakers")
        for hit in hits[:2]:  # only the first two matches are shown
            logger.info(f"Match: {hit.profile.name} (similarity: {hit.similarity_score:.2f})")

        # Summary figures reported by the profile manager.
        summary = profile_manager.get_profile_stats()
        logger.info(f"Total profiles: {summary['total_profiles']}")
        logger.info(f"Profiles with embeddings: {summary['profiles_with_embeddings']}")

    except Exception as err:
        logger.error(f"Speaker profile demonstration failed: {err}")
|
|
|
|
|
|
def performance_comparison(
    audio_path: Path,
    diarization_manager: DiarizationManager,
    parallel_processor: ParallelProcessor
):
    """Time ``process_audio`` against ``process_file`` on the same input.

    Args:
        audio_path: Audio file processed by both code paths.
        diarization_manager: Service timed for the sequential run.
        parallel_processor: Service timed for the parallel run; it also
            receives both durations through ``estimate_speedup``.

    Returns:
        None. Timings and the speedup ratio are logged; any exception is
        logged and swallowed.
    """
    logger.info("Comparing sequential vs parallel processing")

    try:
        # perf_counter() is monotonic and high-resolution, so the measured
        # intervals cannot be skewed by system clock adjustments.
        logger.info("Running sequential processing...")
        start_time = time.perf_counter()
        # Only the elapsed time matters here; the result itself is unused.
        diarization_manager.process_audio(audio_path)
        sequential_time = time.perf_counter() - start_time

        logger.info("Running parallel processing...")
        start_time = time.perf_counter()
        parallel_result = parallel_processor.process_file(audio_path)
        parallel_time = time.perf_counter() - start_time

        # Guard against a failed run and against division by zero.
        if parallel_result.success and parallel_time > 0:
            speedup = sequential_time / parallel_time
            logger.info(f"Sequential time: {sequential_time:.2f}s")
            logger.info(f"Parallel time: {parallel_time:.2f}s")
            logger.info(f"Speedup: {speedup:.2f}x")

            # Update processor stats
            parallel_processor.estimate_speedup(sequential_time, parallel_time)
        else:
            logger.warning("Could not calculate speedup due to parallel processing failure")

    except Exception as e:
        logger.error(f"Performance comparison failed: {e}")
|
|
|
|
|
|
def main():
    """Run the full diarization pipeline demonstration.

    Builds the services, then runs the per-file, batch, speaker-profile and
    performance-comparison demos on whichever sample files exist. Exceptions
    raised by the demos are logged, and every service's ``cleanup()`` is
    attempted afterwards, including when no sample file is found.
    """
    logger.info("Starting diarization pipeline demonstration")

    # Set up services
    diarization_manager, profile_manager, parallel_processor = setup_services()

    try:
        # Example audio files (adjust paths as needed)
        audio_files = [
            Path("tests/sample_5s.wav"),
            Path("tests/sample_30s.mp3"),
            Path("tests/sample_2m.mp4")
        ]

        # Filter to existing files
        existing_files = [f for f in audio_files if f.exists()]

        if not existing_files:
            logger.warning("No test audio files found. Please adjust paths.")
            return

        logger.info(f"Found {len(existing_files)} audio files for processing")

        # Demonstrate different processing approaches
        for audio_file in existing_files[:2]:  # Process first 2 files
            logger.info(f"\n--- Processing {audio_file.name} ---")

            # Single file processing
            process_single_file(audio_file, diarization_manager, profile_manager)

            # Profile-based processing
            process_with_profiles(audio_file, diarization_manager, profile_manager)

            # Parallel processing
            process_parallel(audio_file, parallel_processor)

        # Batch processing
        if len(existing_files) > 1:
            logger.info("\n--- Batch Processing ---")
            batch_process(existing_files, parallel_processor)

        # Speaker profile demonstration
        logger.info("\n--- Speaker Profile Demonstration ---")
        demonstrate_speaker_profiles(profile_manager)

        # Performance comparison (existing_files is known to be non-empty
        # here because of the early return above)
        logger.info("\n--- Performance Comparison ---")
        performance_comparison(existing_files[0], diarization_manager, parallel_processor)

        logger.info("\nDemonstration completed successfully!")

    except Exception as e:
        logger.error(f"Demonstration failed: {e}")

    finally:
        # Cleanup
        logger.info("Cleaning up resources...")
        # Release each service independently so that one failing cleanup()
        # cannot prevent the remaining services from being released.
        for service in (diarization_manager, profile_manager, parallel_processor):
            try:
                service.cleanup()
            except Exception as e:
                logger.error(f"Cleanup failed for {type(service).__name__}: {e}")
|
|
|
|
|
|
# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    main()
|