Trax v2 Technical Architecture: High-Performance Single-Node Design
🎯 Architecture Overview
Trax v2 represents a significant evolution from v1, focusing on high performance and speaker diarization rather than distributed scalability. The architecture is designed as a highly optimized, single-node, multi-process application that leverages the full power of modern hardware while maintaining simplicity and determinism.
Key Architectural Principles
- Single-Node Optimization: Maximize utilization of one powerful machine rather than distributing across multiple nodes
- Multi-Pass Pipeline: Intelligent multi-stage processing for 99.5%+ accuracy
- Parallel Processing: Concurrent execution of independent tasks within the same job
- Model Caching: Persistent model management to avoid reloading overhead
- Memory Efficiency: 8-bit quantization and smart resource management
- Deterministic Processing: Predictable, reproducible results across runs (see the seeding sketch below)
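The determinism principle is typically enforced by pinning every random seed before any model runs. A minimal sketch, assuming PyTorch-based models (the helper name is illustrative, not a Trax API):

```python
import random
import numpy as np
import torch

def set_deterministic_seeds(seed: int = 42) -> None:
    # Pin all RNGs so repeated runs over the same audio
    # produce identical transcripts.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Fail loudly if an op has no deterministic implementation.
    torch.use_deterministic_algorithms(True)
```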
🏗️ Core Architecture Components
1. Enhanced Task System
The task system evolves from simple single-action tasks to complex pipeline workflows:
```python
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID

@dataclass
class PipelineTask:
    """Enhanced task definition for v2 pipeline workflows."""
    id: UUID
    media_file_id: UUID
    pipeline_stages: List[str]       # ["transcribe", "enhance", "diarize", "merge"]
    pipeline_config: Dict[str, Any]  # Model selection, domain, quality settings
    status: TaskStatus
    current_stage: Optional[str]
    progress_percentage: float
    error_message: Optional[str]
    created_at: datetime
    updated_at: datetime
```
Pipeline Stages
- transcribe: Multi-pass transcription with confidence scoring
- enhance: AI-powered text refinement using DeepSeek
- diarize: Speaker identification using Pyannote.audio
- merge: Combine transcript and diarization results
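The `status` field on `PipelineTask` assumes a `TaskStatus` enum shared with the job manager later in this document; a minimal sketch using the values that appear there:

```python
from enum import Enum

class TaskStatus(str, Enum):
    """Lifecycle states assumed by PipelineTask and ProcessingJobManager."""
    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
```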
2. ModelManager Singleton
Central model management to prevent memory duplication and enable fast model switching:
```python
class ModelManager:
    """Singleton for managing AI model lifecycle and caching."""

    _instance = None

    def __new__(cls):
        # Every ModelManager() call returns the same shared instance,
        # so services never duplicate loaded models in memory.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._models: Dict[str, Any] = {}
            cls._instance._lora_adapters: Dict[str, Any] = {}
            cls._instance._model_configs: Dict[str, Dict] = {}
        return cls._instance

    async def get_model(self, model_type: str, config: Dict) -> Any:
        """Get or load a model, caching it for reuse."""
        cache_key = self._generate_cache_key(model_type, config)
        if cache_key not in self._models:
            self._models[cache_key] = await self._load_model(model_type, config)
        return self._models[cache_key]

    async def load_lora_adapter(self, domain: str) -> Any:
        """Load a LoRA adapter for domain-specific processing."""
        if domain not in self._lora_adapters:
            # _load_lora_adapter is the underlying (private) loader
            self._lora_adapters[domain] = await self._load_lora_adapter(domain)
        return self._lora_adapters[domain]
```
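`_generate_cache_key` is referenced but not defined above; a minimal sketch, assuming the key only needs to be stable across dict orderings:

```python
import hashlib
import json

def _generate_cache_key(self, model_type: str, config: Dict) -> str:
    # Sort keys so logically identical configs map to the same cache entry.
    payload = json.dumps(config, sort_keys=True, default=str)
    digest = hashlib.sha256(payload.encode()).hexdigest()[:16]
    return f"{model_type}:{digest}"
```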
3. Multi-Pass Transcription Pipeline
Intelligent multi-stage processing for maximum accuracy:
```python
class MultiPassTranscriptionPipeline:
    """Multi-pass transcription pipeline targeting 99.5%+ accuracy."""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.model_manager = ModelManager()

    async def process(self, audio_file: Path) -> TranscriptionResult:
        """Execute the multi-pass transcription pipeline."""
        # Stage 1: fast initial transcription
        initial_result = await self._fast_pass(audio_file)

        # Stage 2: confidence scoring and selective refinement
        confidence_scores = self._calculate_confidence(initial_result)
        refinement_segments = self._identify_low_confidence_segments(confidence_scores)
        if refinement_segments:
            refined_result = await self._refinement_pass(audio_file, refinement_segments)
            merged_result = self._merge_transcripts(initial_result, refined_result)
        else:
            merged_result = initial_result

        # Stage 3: AI enhancement (optional)
        if self.config.enable_enhancement:
            return await self._enhancement_pass(merged_result)
        return merged_result

    async def _fast_pass(self, audio_file: Path) -> TranscriptionResult:
        """First pass using the fast model (distil-small.en)."""
        model = await self.model_manager.get_model("whisper", {
            "model": "distil-small.en",
            "quantized": True,
        })
        return await self._transcribe_with_model(audio_file, model)

    async def _refinement_pass(self, audio_file: Path, segments: List[Segment]) -> TranscriptionResult:
        """Refinement pass using the accurate model (distil-large-v3)."""
        model = await self.model_manager.get_model("whisper", {
            "model": "distil-large-v3",
            "quantized": True,
            "segments": segments,
        })
        return await self._transcribe_with_model(audio_file, model)

    async def _enhancement_pass(self, transcript: TranscriptionResult) -> TranscriptionResult:
        """AI enhancement using DeepSeek."""
        enhancer = await self.model_manager.get_model("deepseek", {})
        return await enhancer.enhance_transcript(transcript)
```
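`_calculate_confidence` and `_identify_low_confidence_segments` are left abstract above. A minimal sketch, assuming each segment exposes a model-reported `confidence` in [0, 1] and reusing the `confidence_threshold` from `PipelineConfig`:

```python
def _calculate_confidence(self, result: TranscriptionResult) -> Dict[int, float]:
    # Map each segment index to its model-reported confidence score.
    return {i: seg.confidence for i, seg in enumerate(result.segments)}

def _identify_low_confidence_segments(self, scores: Dict[int, float]) -> List[int]:
    # Only segments below the configured threshold are re-transcribed,
    # keeping the expensive model on a fraction of the audio.
    return [i for i, score in scores.items()
            if score < self.config.confidence_threshold]
```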
4. Speaker Diarization Service
Pyannote.audio integration for speaker identification:
```python
class SpeakerDiarizationService:
    """Speaker diarization using Pyannote.audio."""

    def __init__(self, config: DiarizationConfig):
        self.config = config
        self.model_manager = ModelManager()
        self._embedding_model = None
        self._clustering_model = None

    async def diarize(self, audio_file: Path) -> DiarizationResult:
        """Perform speaker diarization on an audio file."""
        # Load models (cached by the ModelManager)
        if not self._embedding_model:
            self._embedding_model = await self.model_manager.get_model("pyannote_embedding", {})
        if not self._clustering_model:
            self._clustering_model = await self.model_manager.get_model("pyannote_clustering", {})

        # Extract speaker embeddings
        embeddings = await self._extract_embeddings(audio_file)
        # Cluster embeddings into speakers
        speaker_segments = await self._cluster_speakers(embeddings)
        # Post-process and validate
        validated_segments = self._validate_speaker_segments(speaker_segments)

        return DiarizationResult(
            speaker_segments=validated_segments,
            speaker_count=len({seg.speaker_id for seg in validated_segments}),
            confidence_score=self._calculate_diarization_confidence(validated_segments),
        )

    async def _extract_embeddings(self, audio_file: Path) -> List[Embedding]:
        """Extract speaker embeddings from audio."""
        # Implementation using the Pyannote.audio embedding model
        pass

    async def _cluster_speakers(self, embeddings: List[Embedding]) -> List[SpeakerSegment]:
        """Cluster embeddings to identify speakers."""
        # Implementation using Pyannote.audio clustering
        pass
```
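For comparison, the same flow can be driven end to end by pyannote.audio's pretrained pipeline, which bundles embedding extraction and clustering behind a single call. A hedged sketch (model name, token handling, and the output tuple format are illustrative choices, not Trax APIs):

```python
from pathlib import Path
from pyannote.audio import Pipeline

def diarize_with_pretrained(audio_file: Path, hf_token: str):
    # The pretrained pipeline handles VAD, embedding extraction,
    # and clustering internally.
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
    )
    diarization = pipeline(str(audio_file))
    # Flatten the annotation into (start, end, speaker) tuples.
    return [
        (turn.start, turn.end, speaker)
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]
```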
5. Domain Adaptation with LoRA
Lightweight domain-specific model adaptation:
```python
class LoRAAdapterManager:
    """Manage LoRA adapters for domain-specific processing."""

    def __init__(self):
        self.model_manager = ModelManager()
        self._base_model = None
        self._current_adapter = None

    async def load_domain_adapter(self, domain: str) -> None:
        """Load the LoRA adapter for a specific domain."""
        # Load the base model if it is not loaded yet
        if not self._base_model:
            self._base_model = await self.model_manager.get_model("whisper_base", {})
        # Load the domain-specific adapter (cached by the ModelManager)
        adapter = await self.model_manager.load_lora_adapter(domain)
        # Apply the adapter to the base model
        self._base_model.load_adapter(adapter)
        self._current_adapter = domain

    async def auto_detect_domain(self, audio_file: Path) -> str:
        """Automatically detect the content domain."""
        # Keyword analysis or content classification; returns one of
        # technical, medical, academic, general
        pass

    async def transcribe_with_domain(self, audio_file: Path, domain: str) -> TranscriptionResult:
        """Transcribe with a domain-specific model."""
        await self.load_domain_adapter(domain)
        return await self._base_model.transcribe(audio_file)
```
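`auto_detect_domain` is left open above. A minimal keyword-frequency sketch over the four domains named in this document, assuming a fast-pass transcript sample is available as text; the keyword lists are illustrative placeholders, not a shipped vocabulary:

```python
DOMAIN_KEYWORDS = {
    "technical": {"api", "server", "deploy", "compile", "latency"},
    "medical": {"patient", "diagnosis", "dosage", "clinical", "symptom"},
    "academic": {"hypothesis", "literature", "methodology", "citation"},
}

def detect_domain_from_text(sample_text: str) -> str:
    # Score each domain by keyword hits in the transcript sample;
    # fall back to "general" when nothing matches.
    words = set(sample_text.lower().split())
    scores = {d: len(words & kw) for d, kw in DOMAIN_KEYWORDS.items()}
    best = max(scores, key=scores.get)
    return best if scores[best] > 0 else "general"
```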
🔄 Data Flow Architecture
Parallel Processing Flow
```
                  ┌────────────┐
                  │ Audio File │
                  └─────┬──────┘
                        ▼
┌───────────────────────────────────────────────┐
│         Parallel Processing Pipeline          │
│  ┌───────────────────┐  ┌──────────────────┐  │
│  │   Transcription   │  │   Diarization    │  │
│  │     Pipeline      │  │     Pipeline     │  │
│  │                   │  │                  │  │
│  │ • Fast Pass       │  │ • Embedding      │  │
│  │ • Refinement      │  │   Extraction     │  │
│  │ • Enhancement     │  │ • Speaker        │  │
│  │                   │  │   Clustering     │  │
│  │                   │  │ • Segment        │  │
│  │                   │  │   Validation     │  │
│  └─────────┬─────────┘  └────────┬─────────┘  │
└────────────┼─────────────────────┼────────────┘
             ▼                     ▼
   ┌──────────────────┐   ┌──────────────────┐
   │  Transcription   │   │   Diarization    │
   │      Result      │   │      Result      │
   └─────────┬────────┘   └────────┬─────────┘
             └──────────┬──────────┘
                        ▼
         ┌─────────────────────────────┐
         │        Merge Service        │
         │                             │
         │ • Align timestamps          │
         │ • Combine speaker labels    │
         │ • Validate consistency     │
         └──────────────┬──────────────┘
                        ▼
         ┌─────────────────────────────┐
         │      Final Transcript       │
         │                             │
         │ • High-accuracy text        │
         │ • Speaker identification    │
         │ • Confidence scores         │
         └─────────────────────────────┘
```
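The Merge Service's "combine speaker labels" step amounts to giving each transcript segment the speaker whose diarization turn overlaps it most. A minimal sketch of that alignment, assuming both results expose start/end times in seconds:

```python
from typing import List, Tuple

def assign_speakers(transcript_segments, speaker_turns) -> List[Tuple]:
    """Label each transcript segment with the speaker of maximal time overlap."""
    merged = []
    for seg in transcript_segments:
        best_speaker, best_overlap = "UNKNOWN", 0.0
        for turn in speaker_turns:
            # Overlap is positive only when the two intervals intersect.
            overlap = min(seg.end, turn.end) - max(seg.start, turn.start)
            if overlap > best_overlap:
                best_speaker, best_overlap = turn.speaker_id, overlap
        merged.append((seg, best_speaker))
    return merged
```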
State Management Flow
```python
class ProcessingJobManager:
    """Manage processing job lifecycle and state transitions."""

    async def create_job(self, media_file_id: UUID, config: PipelineConfig) -> ProcessingJob:
        """Create a new processing job."""
        job = ProcessingJob(
            id=uuid4(),
            media_file_id=media_file_id,
            pipeline_config=config,
            status=TaskStatus.QUEUED,
            created_at=datetime.utcnow(),
        )
        await self._save_job(job)
        return job

    async def execute_job(self, job: ProcessingJob) -> None:
        """Execute a processing job with state management."""
        try:
            await self._update_job_status(job.id, TaskStatus.PROCESSING)
            # Execute pipeline stages in order
            for stage in job.pipeline_config.stages:
                await self._update_job_stage(job.id, stage)
                result = await self._execute_stage(job, stage)
                if not result.success:
                    raise ProcessingError(f"Stage {stage} failed: {result.error}")
            await self._update_job_status(job.id, TaskStatus.COMPLETED)
        except Exception as e:
            await self._update_job_status(job.id, TaskStatus.FAILED, str(e))
            raise
```
🚀 Performance Optimization Strategies
1. Memory Optimization
```python
class MemoryOptimizer:
    """Memory optimization strategies for v2."""

    def __init__(self):
        self.max_memory_gb = 8
        self.current_usage_gb = 0

    async def optimize_model_loading(self, model_config: Dict) -> Dict:
        """Apply memory optimizations to a model-loading config."""
        # 8-bit quantization (the concrete flag depends on the backend,
        # e.g. an int8 compute type for CTranslate2-based Whisper builds)
        if model_config.get("quantized", True):
            model_config["torch_dtype"] = torch.int8
        # Gradient checkpointing for large models
        if model_config.get("model_size") == "large":
            model_config["gradient_checkpointing"] = True
        # Offload layers automatically when memory pressure is high
        if self.current_usage_gb > self.max_memory_gb * 0.8:
            model_config["device_map"] = "auto"
        return model_config

    async def cleanup_unused_models(self) -> None:
        """Unload unused models to free memory."""
        for model in self._identify_unused_models():
            await self._unload_model(model)
```
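As a concrete illustration of the quantization path, a sketch using faster-whisper, one possible backend for the distil models named earlier (Trax's actual loader may differ):

```python
from faster_whisper import WhisperModel

def load_quantized_whisper(model_name: str = "distil-large-v3") -> WhisperModel:
    # compute_type="int8" runs inference with 8-bit weights,
    # roughly halving memory versus float16.
    return WhisperModel(model_name, device="cpu", compute_type="int8")

model = load_quantized_whisper()
segments, info = model.transcribe("talk.wav")
for seg in segments:
    print(f"[{seg.start:.1f}-{seg.end:.1f}] {seg.text}")
```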
2. CPU Optimization
```python
class CPUOptimizer:
    """CPU optimization for parallel processing."""

    def __init__(self):
        self.cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
        self.optimal_worker_count = min(self.cpu_count, 8)

    async def configure_worker_pool(self) -> AsyncWorkerPool:
        """Configure a worker pool capped at the optimal size."""
        return AsyncWorkerPool(
            max_workers=self.optimal_worker_count,
            thread_name_prefix="trax_worker",
        )

    async def optimize_audio_processing(self, audio_file: Path) -> Path:
        """Normalize audio for processing."""
        # Convert to the optimal format (16 kHz mono WAV),
        # apply noise reduction if needed, and chunk large files
        pass
```
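The 16 kHz mono conversion can be done by shelling out to ffmpeg; a minimal sketch (the output naming is illustrative):

```python
import subprocess
from pathlib import Path

def convert_to_16k_mono(audio_file: Path) -> Path:
    """Convert any supported input to 16 kHz mono WAV via ffmpeg."""
    out_path = audio_file.with_suffix(".16k.wav")
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(audio_file),
         "-ac", "1",        # mono
         "-ar", "16000",    # 16 kHz sample rate
         str(out_path)],
        check=True, capture_output=True,
    )
    return out_path
```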
3. Pipeline Optimization
```python
class PipelineOptimizer:
    """Optimize pipeline execution for performance."""

    async def execute_parallel_stages(self, job: ProcessingJob) -> Dict[str, Any]:
        """Execute independent stages in parallel."""
        parallel_stages = self._identify_parallel_stages(job.pipeline_config)
        tasks = [
            asyncio.create_task(self._execute_stage(job, stage))
            for stage in parallel_stages
        ]
        # return_exceptions=True keeps one failed stage from cancelling the rest
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return dict(zip(parallel_stages, results))

    def _identify_parallel_stages(self, config: PipelineConfig) -> List[str]:
        """Identify stages that can run in parallel."""
        # Transcription and diarization are independent of each other;
        # enhancement must wait for transcription, and merging must wait
        # for both transcription and diarization.
        independent = {"transcribe", "diarize"}
        return [s for s in config.stages if s in independent]
```
💻 CLI Interface Architecture
Enhanced CLI Interface
```python
class TraxCLI:
    """Enhanced CLI interface for Trax v2."""

    def __init__(self):
        self.progress_reporter = ProgressReporter()
        self.batch_processor = BatchProcessor()
        self.logger = setup_logging()

    async def transcribe_single(self, file_path: Path, config: PipelineConfig) -> None:
        """Transcribe a single file with enhanced progress reporting."""
        self._validate_file(file_path)
        job = await self._create_job(file_path, config)
        # Process with real-time progress
        await self._process_with_progress(job)
        self._display_results(job)

    async def transcribe_batch(self, directory: Path, config: PipelineConfig) -> None:
        """Process a batch of files with enhanced progress reporting."""
        files = self._validate_directory(directory)
        batch_job = await self._create_batch_job(files, config)
        # Process the batch with progress reporting
        await self._process_batch_with_progress(batch_job)
        self._display_batch_results(batch_job)

    def _validate_file(self, file_path: Path) -> None:
        """Validate a single file for processing."""
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
        if file_path.stat().st_size > 500 * 1024 * 1024:  # 500 MB limit
            raise ValueError(f"File too large: {file_path}")
        if file_path.suffix.lower() not in {'.mp3', '.mp4', '.wav', '.m4a', '.webm'}:
            raise ValueError(f"Unsupported format: {file_path.suffix}")

    def _validate_directory(self, directory: Path) -> List[Path]:
        """Validate a directory and return the supported files in it."""
        if not directory.exists():
            raise FileNotFoundError(f"Directory not found: {directory}")
        supported_extensions = {'.mp3', '.mp4', '.wav', '.m4a', '.webm'}
        files = [
            f for f in directory.iterdir()
            if f.is_file() and f.suffix.lower() in supported_extensions
        ]
        if not files:
            raise ValueError(f"No supported files found in: {directory}")
        return files
```
Progress Reporting
```python
class ProgressReporter:
    """Real-time progress reporting for the CLI."""

    def __init__(self):
        self.start_time = None
        self.current_stage = None

    async def report_progress(self, job: ProcessingJob) -> None:
        """Report real-time progress for a job."""
        if self.start_time is None:
            self.start_time = time.time()
        elapsed = time.time() - self.start_time
        self._display_progress_bar(job.progress_percentage, elapsed)
        # Announce stage transitions
        if job.current_stage != self.current_stage:
            self.current_stage = job.current_stage
            self._display_stage_info(job.current_stage)
        self._display_performance_metrics(job)

    def _display_progress_bar(self, progress: float, elapsed: float) -> None:
        """Display an ASCII progress bar."""
        bar_length = 50
        filled_length = int(bar_length * progress / 100)
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        print(f"\rProgress: [{bar}] {progress:.1f}% ({elapsed:.1f}s)", end='', flush=True)

    def _display_stage_info(self, stage: str) -> None:
        """Display the current processing stage."""
        print(f"\n🔄 {stage.title()}...")

    def _display_performance_metrics(self, job: ProcessingJob) -> None:
        """Display performance metrics."""
        if hasattr(job, 'performance_metrics'):
            metrics = job.performance_metrics
            print(f"  CPU: {metrics.get('cpu_usage', 0):.1f}% | "
                  f"Memory: {metrics.get('memory_usage_gb', 0):.1f}GB | "
                  f"Speed: {metrics.get('processing_speed', 0):.1f}x")
```
🔧 Configuration Management
Pipeline Configuration
```python
@dataclass
class PipelineConfig:
    """Configuration for the v2 processing pipeline."""
    # Pipeline version
    version: str = "v2"  # v1, v2, v2+
    # Ordered stages, consumed by ProcessingJobManager.execute_job
    stages: List[str] = field(
        default_factory=lambda: ["transcribe", "enhance", "diarize", "merge"]
    )
    # Model selection
    transcription_model: str = "distil-large-v3"
    enhancement_model: str = "deepseek"
    diarization_model: str = "pyannote"
    # Quality settings
    accuracy_threshold: float = 0.995  # 99.5%
    confidence_threshold: float = 0.8
    # Domain settings
    domain: Optional[str] = None  # technical, medical, academic, general
    auto_detect_domain: bool = True
    # Performance settings
    enable_quantization: bool = True
    parallel_processing: bool = True
    max_workers: int = 8
    # Diarization settings
    enable_diarization: bool = False
    min_speaker_count: int = 2
    max_speaker_count: int = 10
    # Enhancement settings
    enable_enhancement: bool = True
    enhancement_prompts: Dict[str, str] = field(default_factory=dict)
```
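For reference, a hypothetical instantiation for a v2+ run with diarization enabled:

```python
# Hypothetical v2+ run: technical content, speakers identified.
config = PipelineConfig(
    version="v2+",
    domain="technical",
    enable_diarization=True,
)
```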
Environment Configuration
```python
class TraxV2Config:
    """Configuration management for Trax v2."""

    def __init__(self):
        self.load_from_env()
        self.load_from_file()
        self.validate()

    def load_from_env(self):
        """Load configuration from environment variables."""
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        self.deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
        # Performance settings
        self.max_memory_gb = int(os.getenv("TRAX_MAX_MEMORY_GB", "8"))
        self.max_workers = int(os.getenv("TRAX_MAX_WORKERS", "8"))
        self.enable_quantization = os.getenv("TRAX_ENABLE_QUANTIZATION", "true").lower() == "true"

    def validate(self):
        """Validate configuration."""
        if not self.openai_api_key:
            raise ConfigurationError("OPENAI_API_KEY is required")
        if not self.deepseek_api_key:
            raise ConfigurationError("DEEPSEEK_API_KEY is required")
        if not self.huggingface_token:
            raise ConfigurationError("HUGGINGFACE_TOKEN is required for diarization")
```
🧪 Testing Architecture
Unit Testing Strategy
```python
class TraxV2TestSuite:
    """Comprehensive test suite for Trax v2."""
    # These tests await coroutines, so they assume an async-aware
    # runner (e.g. pytest-asyncio).

    async def test_multi_pass_pipeline(self):
        """Test the multi-pass transcription pipeline."""
        pipeline = MultiPassTranscriptionPipeline(self.test_config)
        result = await pipeline.process(self.test_audio_file)
        assert result.accuracy_estimate >= 0.995
        assert result.processing_time_ms < 25000  # < 25 seconds
        assert len(result.segments) > 0

    async def test_diarization_service(self):
        """Test speaker diarization."""
        diarization = SpeakerDiarizationService(self.test_config)
        result = await diarization.diarize(self.test_multi_speaker_file)
        assert result.speaker_count >= 2
        assert result.confidence_score >= 0.9
        assert len(result.speaker_segments) > 0

    async def test_lora_adapter_manager(self):
        """Test domain adaptation."""
        adapter_manager = LoRAAdapterManager()
        await adapter_manager.load_domain_adapter("technical")
        result = await adapter_manager.transcribe_with_domain(
            self.test_technical_file, "technical"
        )
        assert result.domain_used == "technical"
        assert result.accuracy_estimate > 0.99
```
Integration Testing Strategy
```python
class TraxV2IntegrationTests:
    """Integration tests for the complete v2 pipeline."""

    async def test_complete_v2_pipeline(self):
        """Test the complete v2 pipeline with diarization."""
        job_manager = ProcessingJobManager()
        # Create job
        job = await job_manager.create_job(
            self.test_file_id,
            PipelineConfig(version="v2+", enable_diarization=True)
        )
        # Execute job
        await job_manager.execute_job(job)
        # Verify results
        assert job.status == TaskStatus.COMPLETED
        assert job.transcript.accuracy_estimate >= 0.995
        assert job.transcript.speaker_count >= 2
        assert job.processing_time_ms < 25000

    async def test_cli_batch_processing(self):
        """Test CLI batch processing with multiple files."""
        cli = TraxCLI()
        # Process a batch of files
        test_directory = Path("test_files")
        config = PipelineConfig(version="v2", enable_diarization=True)
        await cli.transcribe_batch(test_directory, config)
        # Verify all files were processed
        results = await self._get_batch_results()
        assert len(results) == len(list(test_directory.glob("*.mp3")))
        assert all(result.status == "completed" for result in results)
```
📊 Performance Monitoring
Metrics Collection
```python
class PerformanceMonitor:
    """Monitor and collect performance metrics."""

    # Metrics where larger values are better; everything else is
    # treated as "lower is better" (time, memory).
    HIGHER_IS_BETTER = {"accuracy_threshold", "diarization_accuracy"}

    def __init__(self):
        self.metrics: Dict[str, List[float]] = defaultdict(list)

    async def record_metric(self, metric_name: str, value: float):
        """Record a performance metric."""
        self.metrics[metric_name].append(value)

    async def get_performance_report(self) -> Dict[str, Dict]:
        """Generate a performance report."""
        report = {}
        for metric_name, values in self.metrics.items():
            report[metric_name] = {
                "count": len(values),
                "mean": statistics.mean(values),
                "median": statistics.median(values),
                "min": min(values),
                "max": max(values),
                "std": statistics.stdev(values) if len(values) > 1 else 0,
            }
        return report

    async def check_performance_targets(self) -> Dict[str, bool]:
        """Check whether performance targets are met."""
        targets = {
            "processing_time_5min": 25000,  # < 25 seconds for 5 min of audio
            "accuracy_threshold": 0.995,    # 99.5%+
            "memory_usage_gb": 8,           # < 8 GB
            "diarization_accuracy": 0.9,    # 90%+
        }
        results = {}
        for target_name, target_value in targets.items():
            if target_name in self.metrics:
                current_value = statistics.mean(self.metrics[target_name])
                # Direction matters: mean accuracy must meet or exceed its
                # target, while time and memory must stay under theirs.
                if target_name in self.HIGHER_IS_BETTER:
                    results[target_name] = current_value >= target_value
                else:
                    results[target_name] = current_value <= target_value
        return results
```
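A brief usage sketch (inside an async context):

```python
# e.g. within an async main or test
monitor = PerformanceMonitor()
await monitor.record_metric("processing_time_5min", 21800.0)  # ms
await monitor.record_metric("accuracy_threshold", 0.996)
print(await monitor.check_performance_targets())
# -> {'processing_time_5min': True, 'accuracy_threshold': True}
```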
🔄 Migration Strategy
From v1 to v2
```python
class TraxV2Migration:
    """Migration utilities for upgrading from v1 to v2."""

    async def migrate_database_schema(self):
        """Migrate the database schema for v2 features."""
        # Add new tables
        await self._create_speaker_profiles_table()
        await self._create_processing_jobs_table()
        # Modify existing tables
        await self._add_v2_columns_to_transcripts()
        await self._add_v2_columns_to_media_files()

    async def migrate_existing_transcripts(self):
        """Migrate existing v1 transcripts to the v2 format."""
        v1_transcripts = await self._get_v1_transcripts()
        for transcript in v1_transcripts:
            # Tag provenance and map the old content field
            transcript.pipeline_version = "v1"
            transcript.merged_content = transcript.raw_content
            await self._save_transcript(transcript)

    async def validate_migration(self) -> bool:
        """Validate a successful migration."""
        schema_valid = await self._validate_schema()
        data_valid = await self._validate_data_integrity()
        functionality_valid = await self._test_v2_functionality()
        return schema_valid and data_valid and functionality_valid
```
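A hedged sketch of what `_add_v2_columns_to_transcripts` might execute; the column names mirror transcript fields used in this document, and the connection API and SQL dialect are assumptions:

```python
async def _add_v2_columns_to_transcripts(self, conn) -> None:
    # Column names mirror the transcript fields used in this document;
    # adjust types and ALTER syntax to the actual database engine.
    await conn.execute("""
        ALTER TABLE transcripts
            ADD COLUMN pipeline_version TEXT DEFAULT 'v1',
            ADD COLUMN merged_content   TEXT,
            ADD COLUMN speaker_count    INTEGER
    """)
```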
This architecture document provides the technical foundation for implementing Trax v2, focusing on high performance and speaker diarization while maintaining the simplicity and determinism of the single-node design.