Trax v2 Technical Architecture: High-Performance Single-Node Design

🎯 Architecture Overview

Trax v2 represents a significant evolution from v1, focusing on high performance and speaker diarization rather than distributed scalability. The architecture is designed as a highly optimized, single-node, multi-process application that leverages the full power of modern hardware while maintaining simplicity and determinism.

Key Architectural Principles

  1. Single-Node Optimization: Maximize utilization of one powerful machine rather than distributing across multiple nodes
  2. Multi-Pass Pipeline: Intelligent multi-stage processing for 99.5%+ accuracy
  3. Parallel Processing: Concurrent execution of independent tasks within the same job
  4. Model Caching: Persistent model management to avoid reloading overhead
  5. Memory Efficiency: 8-bit quantization and smart resource management
  6. Deterministic Processing: Predictable, reproducible results across runs

🏗️ Core Architecture Components

1. Enhanced Task System

The task system evolves from simple single-action tasks to complex pipeline workflows:

@dataclass
class PipelineTask:
    """Enhanced task definition for v2 pipeline workflows"""
    id: UUID
    media_file_id: UUID
    pipeline_stages: List[str]  # ["transcribe", "enhance", "diarize", "merge"]
    pipeline_config: Dict[str, Any]  # Model selection, domain, quality settings
    status: TaskStatus
    current_stage: Optional[str]
    progress_percentage: float
    error_message: Optional[str]
    created_at: datetime
    updated_at: datetime

Pipeline Stages

  • transcribe: Multi-pass transcription with confidence scoring
  • enhance: AI-powered text refinement using DeepSeek
  • diarize: Speaker identification using Pyannote.audio
  • merge: Combine transcript and diarization results
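
For illustration, a full v2 job could be expressed as a single PipelineTask covering all four stages. This is a minimal sketch; the pipeline_config keys shown are illustrative, not a fixed schema.

from uuid import uuid4
from datetime import datetime, timezone

task = PipelineTask(
    id=uuid4(),
    media_file_id=uuid4(),  # placeholder; in practice this references an ingested media file
    pipeline_stages=["transcribe", "enhance", "diarize", "merge"],
    pipeline_config={
        "transcription_model": "distil-large-v3",
        "domain": "technical",
        "enable_diarization": True,
    },
    status=TaskStatus.QUEUED,
    current_stage=None,
    progress_percentage=0.0,
    error_message=None,
    created_at=datetime.now(timezone.utc),
    updated_at=datetime.now(timezone.utc),
)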

2. ModelManager Singleton

Central model management to prevent memory duplication and enable fast model switching:

class ModelManager:
    """Singleton for managing AI model lifecycle and caching"""
    
    def __init__(self):
        self._models: Dict[str, Any] = {}
        self._lora_adapters: Dict[str, Any] = {}
        self._model_configs: Dict[str, Dict] = {}
    
    async def get_model(self, model_type: str, config: Dict) -> Any:
        """Get or load model with caching"""
        cache_key = self._generate_cache_key(model_type, config)
        
        if cache_key not in self._models:
            model = await self._load_model(model_type, config)
            self._models[cache_key] = model
        
        return self._models[cache_key]
    
    async def load_lora_adapter(self, domain: str) -> Any:
        """Load LoRA adapter for domain-specific processing"""
        if domain not in self._lora_adapters:
            adapter = await self._load_lora_adapter(domain)
            self._lora_adapters[domain] = adapter
        
        return self._lora_adapters[domain]
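
The cache key has to be deterministic so that identical (model_type, config) pairs resolve to the same loaded instance. _generate_cache_key is not defined in this document; a minimal sketch, assuming config values are JSON-serializable (with default=str as a fallback):

import hashlib
import json
from typing import Dict

def generate_cache_key(model_type: str, config: Dict) -> str:
    """Stable key: model type plus a short hash of the canonicalized config."""
    canonical = json.dumps(config, sort_keys=True, default=str)
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
    return f"{model_type}:{digest}"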

3. Multi-Pass Transcription Pipeline

Intelligent multi-stage processing for maximum accuracy:

class MultiPassTranscriptionPipeline:
    """Multi-pass transcription pipeline for 99.5%+ accuracy"""
    
    def __init__(self, config: PipelineConfig):
        self.config = config
        self.model_manager = ModelManager()
    
    async def process(self, audio_file: Path) -> TranscriptionResult:
        """Execute multi-pass transcription pipeline"""
        
        # Stage 1: Fast initial transcription
        initial_result = await self._fast_pass(audio_file)
        
        # Stage 2: Confidence scoring and refinement
        confidence_scores = self._calculate_confidence(initial_result)
        refinement_segments = self._identify_low_confidence_segments(confidence_scores)
        
        if refinement_segments:
            refined_result = await self._refinement_pass(audio_file, refinement_segments)
            merged_result = self._merge_transcripts(initial_result, refined_result)
        else:
            merged_result = initial_result
        
        # Stage 3: AI enhancement (optional)
        if self.config.enable_enhancement:
            enhanced_result = await self._enhancement_pass(merged_result)
            return enhanced_result
        
        return merged_result
    
    async def _fast_pass(self, audio_file: Path) -> TranscriptionResult:
        """First pass using fast model (distil-small.en)"""
        model = await self.model_manager.get_model("whisper", {
            "model": "distil-small.en",
            "quantized": True
        })
        return await self._transcribe_with_model(audio_file, model)
    
    async def _refinement_pass(self, audio_file: Path, segments: List[Segment]) -> TranscriptionResult:
        """Refinement pass using accurate model (distil-large-v3)"""
        model = await self.model_manager.get_model("whisper", {
            "model": "distil-large-v3",
            "quantized": True
        })
        # Pass the low-confidence segments to the transcription call rather than the
        # model config, so the cached model instance is reused across jobs
        return await self._transcribe_with_model(audio_file, model, segments=segments)
    
    async def _enhancement_pass(self, transcript: TranscriptionResult) -> TranscriptionResult:
        """AI enhancement using DeepSeek"""
        enhancer = await self.model_manager.get_model("deepseek", {})
        return await enhancer.enhance_transcript(transcript)
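
The confidence-scoring helpers above decide whether the refinement pass runs at all. The document leaves their details open; one plausible sketch, assuming each segment exposes the avg_logprob reported by the Whisper decode and that the cutoff comes from PipelineConfig.confidence_threshold:

import math
from typing import List, Tuple

def calculate_confidence(result: TranscriptionResult) -> List[Tuple[Segment, float]]:
    """Pair each segment with a confidence in (0, 1], derived from its average token log-probability."""
    return [(seg, math.exp(seg.avg_logprob)) for seg in result.segments]

def identify_low_confidence_segments(
    scored: List[Tuple[Segment, float]], threshold: float
) -> List[Segment]:
    """Segments below the confidence threshold are queued for the refinement pass."""
    return [seg for seg, score in scored if score < threshold]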

4. Speaker Diarization Service

Pyannote.audio integration for speaker identification:

class SpeakerDiarizationService:
    """Speaker diarization using Pyannote.audio"""
    
    def __init__(self, config: DiarizationConfig):
        self.config = config
        self.model_manager = ModelManager()
        self._embedding_model = None
        self._clustering_model = None
    
    async def diarize(self, audio_file: Path) -> DiarizationResult:
        """Perform speaker diarization on audio file"""
        
        # Load models (cached)
        if not self._embedding_model:
            self._embedding_model = await self.model_manager.get_model("pyannote_embedding", {})
        if not self._clustering_model:
            self._clustering_model = await self.model_manager.get_model("pyannote_clustering", {})
        
        # Extract speaker embeddings
        embeddings = await self._extract_embeddings(audio_file)
        
        # Perform clustering
        speaker_segments = await self._cluster_speakers(embeddings)
        
        # Post-process and validate
        validated_segments = self._validate_speaker_segments(speaker_segments)
        
        return DiarizationResult(
            speaker_segments=validated_segments,
            speaker_count=len(set(seg.speaker_id for seg in validated_segments)),
            confidence_score=self._calculate_diarization_confidence(validated_segments)
        )
    
    async def _extract_embeddings(self, audio_file: Path) -> List[Embedding]:
        """Extract speaker embeddings from audio"""
        # Implementation using Pyannote.audio embedding model
        pass
    
    async def _cluster_speakers(self, embeddings: List[Embedding]) -> List[SpeakerSegment]:
        """Cluster embeddings to identify speakers"""
        # Implementation using Pyannote.audio clustering
        pass
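
The embedding and clustering stubs can also be backed by pyannote.audio's pretrained end-to-end pipeline rather than custom clustering. A minimal sketch, assuming a Hugging Face token with access to the gated checkpoint (the exact model name and version are an assumption):

from pathlib import Path
from pyannote.audio import Pipeline

def run_pyannote_diarization(audio_file: Path, hf_token: str):
    """Run pyannote's pretrained diarization pipeline and return (start, end, speaker) tuples."""
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=hf_token,
    )
    annotation = pipeline(str(audio_file))
    return [
        (turn.start, turn.end, speaker)
        for turn, _, speaker in annotation.itertracks(yield_label=True)
    ]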

5. Domain Adaptation with LoRA

Lightweight domain-specific model adaptation:

class LoRAAdapterManager:
    """Manage LoRA adapters for domain-specific processing"""
    
    def __init__(self):
        self.model_manager = ModelManager()
        self._base_model = None
        self._current_adapter = None
    
    async def load_domain_adapter(self, domain: str) -> None:
        """Load LoRA adapter for specific domain"""
        
        # Load base model if not loaded
        if not self._base_model:
            self._base_model = await self.model_manager.get_model("whisper_base", {})
        
        # Load domain-specific adapter
        adapter = await self.model_manager.load_lora_adapter(domain)
        
        # Apply adapter to base model
        self._base_model.load_adapter(adapter)
        self._current_adapter = domain
    
    async def auto_detect_domain(self, audio_file: Path) -> str:
        """Automatically detect content domain"""
        # Use keyword analysis or content classification
        # Return detected domain (technical, medical, academic, general)
        pass
    
    async def transcribe_with_domain(self, audio_file: Path, domain: str) -> TranscriptionResult:
        """Transcribe with domain-specific model"""
        await self.load_domain_adapter(domain)
        return await self._base_model.transcribe(audio_file)
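
auto_detect_domain is left open above. One lightweight approach, sketched under the assumption that a fast first-pass transcript is available to inspect, is keyword counting over a small per-domain vocabulary (the keyword lists below are placeholders):

DOMAIN_KEYWORDS = {
    "technical": {"api", "kubernetes", "latency", "compiler", "deployment"},
    "medical": {"patient", "diagnosis", "dosage", "clinical", "symptom"},
    "academic": {"hypothesis", "dataset", "citation", "methodology", "literature"},
}

def detect_domain_from_text(text: str) -> str:
    """Return the domain whose keyword set matches the text most often; fall back to 'general'."""
    words = set(text.lower().split())
    scores = {domain: len(words & keywords) for domain, keywords in DOMAIN_KEYWORDS.items()}
    best_domain, best_score = max(scores.items(), key=lambda item: item[1])
    return best_domain if best_score > 0 else "general"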

🔄 Data Flow Architecture

Parallel Processing Flow

┌─────────────────┐
│   Audio File    │
└─────────────────┘
         │
         ▼
┌─────────────────────────────────────────────────────────┐
│              Parallel Processing Pipeline               │
│  ┌─────────────────┐    ┌─────────────────────────────┐ │
│  │ Transcription   │    │    Diarization              │ │
│  │ Pipeline        │    │    Pipeline                 │ │
│  │                 │    │                             │ │
│  │ • Fast Pass     │    │ • Embedding Extraction      │ │
│  │ • Refinement    │    │ • Speaker Clustering        │ │
│  │ • Enhancement   │    │ • Segment Validation        │ │
│  └─────────────────┘    └─────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
         │                           │
         ▼                           ▼
┌─────────────────┐    ┌─────────────────────────────┐
│ Transcription   │    │    Diarization              │
│ Result          │    │    Result                   │
└─────────────────┘    └─────────────────────────────┘
         │                           │
         └───────────┬───────────────┘
                     ▼
         ┌─────────────────────────────┐
         │        Merge Service        │
         │                             │
         │ • Align timestamps          │
         │ • Combine speaker labels    │
         │ • Validate consistency      │
         └─────────────────────────────┘
                     ▼
         ┌─────────────────────────────┐
         │      Final Transcript       │
         │                             │
         │ • High accuracy text        │
         │ • Speaker identification    │
         │ • Confidence scores         │
         └─────────────────────────────┘
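
The merge service's timestamp alignment reduces to an overlap test between transcript segments and speaker turns. A minimal sketch, assuming both results expose start and end times in seconds:

def assign_speakers(transcript_segments, speaker_segments):
    """Label each transcript segment with the speaker whose turn overlaps it the most."""
    labeled = []
    for seg in transcript_segments:
        best_speaker, best_overlap = None, 0.0
        for turn in speaker_segments:
            overlap = min(seg.end, turn.end) - max(seg.start, turn.start)
            if overlap > best_overlap:
                best_speaker, best_overlap = turn.speaker_id, overlap
        labeled.append((seg, best_speaker))  # best_speaker stays None if no turn overlaps
    return labeled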

State Management Flow

class ProcessingJobManager:
    """Manage processing job lifecycle and state transitions"""
    
    async def create_job(self, media_file_id: UUID, config: PipelineConfig) -> ProcessingJob:
        """Create new processing job"""
        job = ProcessingJob(
            id=uuid4(),
            media_file_id=media_file_id,
            pipeline_config=config,
            status=TaskStatus.QUEUED,
            created_at=datetime.utcnow()
        )
        await self._save_job(job)
        return job
    
    async def execute_job(self, job: ProcessingJob) -> None:
        """Execute processing job with state management"""
        
        try:
            # Update status to processing
            await self._update_job_status(job.id, TaskStatus.PROCESSING)
            
            # Execute pipeline stages
            for stage in job.pipeline_config.stages:
                await self._update_job_stage(job.id, stage)
                result = await self._execute_stage(job, stage)
                
                if not result.success:
                    raise ProcessingError(f"Stage {stage} failed: {result.error}")
            
            # Mark as completed
            await self._update_job_status(job.id, TaskStatus.COMPLETED)
            
        except Exception as e:
            await self._update_job_status(job.id, TaskStatus.FAILED, str(e))
            raise
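
The job lifecycle above relies on a small, fixed status enum. The document references only four states; a sketch of the corresponding definition:

from enum import Enum

class TaskStatus(str, Enum):
    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"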

🚀 Performance Optimization Strategies

1. Memory Optimization

class MemoryOptimizer:
    """Memory optimization strategies for v2"""
    
    def __init__(self):
        self.max_memory_gb = 8
        self.current_usage_gb = 0
    
    async def optimize_model_loading(self, model_config: Dict) -> Dict:
        """Apply memory optimizations to model loading"""
        
        # 8-bit quantization (load weights in int8, e.g. via bitsandbytes)
        if model_config.get("quantized", True):
            model_config["load_in_8bit"] = True
        
        # Gradient checkpointing for large models
        if model_config.get("model_size") == "large":
            model_config["gradient_checkpointing"] = True
        
        # Model offloading for very large models
        if self.current_usage_gb > self.max_memory_gb * 0.8:
            model_config["device_map"] = "auto"
        
        return model_config
    
    async def cleanup_unused_models(self) -> None:
        """Clean up unused models to free memory"""
        unused_models = self._identify_unused_models()
        for model in unused_models:
            await self._unload_model(model)
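
current_usage_gb is never refreshed in the snippet above. One way to keep it current is to sample process memory before each loading decision; a sketch assuming psutil is available:

import psutil

def measure_process_memory_gb() -> float:
    """Resident set size of the current process, in GiB."""
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / (1024 ** 3)

# e.g. at the top of optimize_model_loading:
#     self.current_usage_gb = measure_process_memory_gb()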

2. CPU Optimization

class CPUOptimizer:
    """CPU optimization for parallel processing"""
    
    def __init__(self):
        self.cpu_count = os.cpu_count() or 1  # os.cpu_count() can return None
        self.optimal_worker_count = min(self.cpu_count, 8)
    
    async def configure_worker_pool(self) -> AsyncWorkerPool:
        """Configure optimal worker pool size"""
        return AsyncWorkerPool(
            max_workers=self.optimal_worker_count,
            thread_name_prefix="trax_worker"
        )
    
    async def optimize_audio_processing(self, audio_file: Path) -> Path:
        """Optimize audio for processing"""
        # Convert to optimal format (16kHz mono WAV)
        # Apply noise reduction if needed
        # Chunk large files appropriately
        pass
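
optimize_audio_processing is stubbed above. The format conversion can be done by shelling out to ffmpeg; a minimal sketch that covers only the 16 kHz mono WAV step (noise reduction and chunking omitted), assuming ffmpeg is on PATH:

import subprocess
from pathlib import Path

def convert_to_16k_mono_wav(audio_file: Path) -> Path:
    """Re-encode the input as 16 kHz mono PCM WAV, the format the transcription models expect."""
    output = audio_file.with_suffix(".16k.wav")
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(audio_file), "-ar", "16000", "-ac", "1", str(output)],
        check=True,
        capture_output=True,
    )
    return output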

3. Pipeline Optimization

class PipelineOptimizer:
    """Optimize pipeline execution for performance"""
    
    async def execute_parallel_stages(self, job: ProcessingJob) -> Dict[str, Any]:
        """Execute independent stages in parallel"""
        
        # Identify parallel stages
        parallel_stages = self._identify_parallel_stages(job.pipeline_config)
        
        # Execute in parallel
        tasks = []
        for stage in parallel_stages:
            task = asyncio.create_task(self._execute_stage(job, stage))
            tasks.append(task)
        
        # Wait for completion
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        return dict(zip(parallel_stages, results))
    
    def _identify_parallel_stages(self, config: PipelineConfig) -> List[str]:
        """Identify stages that can run in parallel"""
        # Transcription and diarization can run in parallel
        # Enhancement must wait for transcription
        # Merging must wait for both transcription and diarization
        pass
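
_identify_parallel_stages encodes the dependency rules noted in the comments: transcription and diarization are independent, enhancement waits on transcription, and merging waits on both. A sketch of one way to express that (it returns only the first parallel wave):

from typing import List

STAGE_DEPENDENCIES = {
    "transcribe": set(),
    "diarize": set(),
    "enhance": {"transcribe"},
    "merge": {"transcribe", "diarize"},
}

def identify_parallel_stages(requested_stages: List[str]) -> List[str]:
    """Stages whose dependencies are not among the requested stages can start together."""
    requested = set(requested_stages)
    return [
        stage for stage in requested_stages
        if not (STAGE_DEPENDENCIES.get(stage, set()) & requested)
    ]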

💻 CLI Interface Architecture

Enhanced CLI Interface

class TraxCLI:
    """Enhanced CLI interface for Trax v2"""
    
    def __init__(self):
        self.progress_reporter = ProgressReporter()
        self.batch_processor = BatchProcessor()
        self.logger = setup_logging()
    
    async def transcribe_single(self, file_path: Path, config: PipelineConfig) -> None:
        """Transcribe a single file with enhanced progress reporting"""
        
        # Validate file
        self._validate_file(file_path)
        
        # Create processing job
        job = await self._create_job(file_path, config)
        
        # Process with real-time progress
        await self._process_with_progress(job)
        
        # Display results
        self._display_results(job)
    
    async def transcribe_batch(self, directory: Path, config: PipelineConfig) -> None:
        """Process batch of files with enhanced progress reporting"""
        
        # Validate directory
        files = self._validate_directory(directory)
        
        # Create batch job
        batch_job = await self._create_batch_job(files, config)
        
        # Process batch with progress
        await self._process_batch_with_progress(batch_job)
        
        # Display batch results
        self._display_batch_results(batch_job)
    
    def _validate_file(self, file_path: Path) -> None:
        """Validate single file for processing"""
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
        
        if file_path.stat().st_size > 500 * 1024 * 1024:  # 500MB
            raise ValueError(f"File too large: {file_path}")
        
        if file_path.suffix.lower() not in ['.mp3', '.mp4', '.wav', '.m4a', '.webm']:
            raise ValueError(f"Unsupported format: {file_path.suffix}")
    
    def _validate_directory(self, directory: Path) -> List[Path]:
        """Validate directory and return list of supported files"""
        if not directory.exists():
            raise FileNotFoundError(f"Directory not found: {directory}")
        
        supported_extensions = {'.mp3', '.mp4', '.wav', '.m4a', '.webm'}
        files = [
            f for f in directory.iterdir()
            if f.is_file() and f.suffix.lower() in supported_extensions
        ]
        
        if not files:
            raise ValueError(f"No supported files found in: {directory}")
        
        return files

Progress Reporting

class ProgressReporter:
    """Real-time progress reporting for CLI"""
    
    def __init__(self):
        self.start_time = None
        self.current_stage = None
    
    async def report_progress(self, job: ProcessingJob) -> None:
        """Report real-time progress for a job"""
        
        if self.start_time is None:
            self.start_time = time.time()
        
        # Calculate progress
        elapsed = time.time() - self.start_time
        progress = job.progress_percentage
        
        # Display progress bar
        self._display_progress_bar(progress, elapsed)
        
        # Display current stage
        if job.current_stage != self.current_stage:
            self.current_stage = job.current_stage
            self._display_stage_info(job.current_stage)
        
        # Display performance metrics
        self._display_performance_metrics(job)
    
    def _display_progress_bar(self, progress: float, elapsed: float) -> None:
        """Display ASCII progress bar"""
        bar_length = 50
        filled_length = int(bar_length * progress / 100)
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        
        print(f"\rProgress: [{bar}] {progress:.1f}% ({elapsed:.1f}s)", end='', flush=True)
    
    def _display_stage_info(self, stage: str) -> None:
        """Display current processing stage"""
        print(f"\n🔄 {stage.title()}...")
    
    def _display_performance_metrics(self, job: ProcessingJob) -> None:
        """Display performance metrics"""
        if hasattr(job, 'performance_metrics'):
            metrics = job.performance_metrics
            print(f"   CPU: {metrics.get('cpu_usage', 0):.1f}% | "
                  f"Memory: {metrics.get('memory_usage_gb', 0):.1f}GB | "
                  f"Speed: {metrics.get('processing_speed', 0):.1f}x")

🔧 Configuration Management

Pipeline Configuration

@dataclass
class PipelineConfig:
    """Configuration for v2 processing pipeline"""
    
    # Pipeline version
    version: str = "v2"  # v1, v2, v2+
    
    # Model selection
    transcription_model: str = "distil-large-v3"
    enhancement_model: str = "deepseek"
    diarization_model: str = "pyannote"
    
    # Quality settings
    accuracy_threshold: float = 0.995  # 99.5%
    confidence_threshold: float = 0.8
    
    # Domain settings
    domain: Optional[str] = None  # technical, medical, academic, general
    auto_detect_domain: bool = True
    
    # Performance settings
    enable_quantization: bool = True
    parallel_processing: bool = True
    max_workers: int = 8
    
    # Diarization settings
    enable_diarization: bool = False
    min_speaker_count: int = 2
    max_speaker_count: int = 10
    
    # Enhancement settings
    enable_enhancement: bool = True
    enhancement_prompts: Dict[str, str] = field(default_factory=dict)
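
A typical v2+ run overrides only a few fields and keeps the remaining defaults; a minimal usage sketch:

config = PipelineConfig(
    version="v2+",
    domain="technical",
    enable_diarization=True,
    max_speaker_count=4,
)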

Environment Configuration

class TraxV2Config:
    """Configuration management for Trax v2"""
    
    def __init__(self):
        self.load_from_env()
        self.load_from_file()
        self.validate()
    
    def load_from_env(self):
        """Load configuration from environment variables"""
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        self.deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
        
        # Performance settings
        self.max_memory_gb = int(os.getenv("TRAX_MAX_MEMORY_GB", "8"))
        self.max_workers = int(os.getenv("TRAX_MAX_WORKERS", "8"))
        self.enable_quantization = os.getenv("TRAX_ENABLE_QUANTIZATION", "true").lower() == "true"
    
    def validate(self):
        """Validate configuration"""
        if not self.openai_api_key:
            raise ConfigurationError("OPENAI_API_KEY is required")
        if not self.deepseek_api_key:
            raise ConfigurationError("DEEPSEEK_API_KEY is required")
        if not self.huggingface_token:
            raise ConfigurationError("HUGGINGFACE_TOKEN is required for diarization")

🧪 Testing Architecture

Unit Testing Strategy

class TraxV2TestSuite:
    """Comprehensive test suite for Trax v2"""
    
    async def test_multi_pass_pipeline(self):
        """Test multi-pass transcription pipeline"""
        pipeline = MultiPassTranscriptionPipeline(self.test_config)
        result = await pipeline.process(self.test_audio_file)
        
        assert result.accuracy_estimate >= 0.995
        assert result.processing_time_ms < 25000  # <25 seconds
        assert len(result.segments) > 0
    
    async def test_diarization_service(self):
        """Test speaker diarization"""
        diarization = SpeakerDiarizationService(self.test_config)
        result = await diarization.diarize(self.test_multi_speaker_file)
        
        assert result.speaker_count >= 2
        assert result.confidence_score >= 0.9
        assert len(result.speaker_segments) > 0
    
    async def test_lora_adapter_manager(self):
        """Test domain adaptation"""
        adapter_manager = LoRAAdapterManager()
        await adapter_manager.load_domain_adapter("technical")
        
        result = await adapter_manager.transcribe_with_domain(
            self.test_technical_file, "technical"
        )
        
        assert result.domain_used == "technical"
        assert result.accuracy_estimate > 0.99

Integration Testing Strategy

class TraxV2IntegrationTests:
    """Integration tests for complete v2 pipeline"""
    
    async def test_complete_v2_pipeline(self):
        """Test complete v2 pipeline with diarization"""
        job_manager = ProcessingJobManager()
        
        # Create job
        job = await job_manager.create_job(
            self.test_file_id,
            PipelineConfig(version="v2+", enable_diarization=True)
        )
        
        # Execute job
        await job_manager.execute_job(job)
        
        # Verify results
        assert job.status == TaskStatus.COMPLETED
        assert job.transcript.accuracy_estimate >= 0.995
        assert job.transcript.speaker_count >= 2
        assert job.processing_time_ms < 25000
    
    async def test_cli_batch_processing(self):
        """Test CLI batch processing with multiple files"""
        cli = TraxCLI()
        
        # Process batch of files
        test_directory = Path("test_files")
        config = PipelineConfig(version="v2", enable_diarization=True)
        
        await cli.transcribe_batch(test_directory, config)
        
        # Verify all files processed
        results = await self._get_batch_results()
        assert len(results) == len(list(test_directory.glob("*.mp3")))
        assert all(result.status == "completed" for result in results)

📊 Performance Monitoring

Metrics Collection

class PerformanceMonitor:
    """Monitor and collect performance metrics"""
    
    def __init__(self):
        self.metrics: Dict[str, List[float]] = defaultdict(list)
    
    async def record_metric(self, metric_name: str, value: float):
        """Record a performance metric"""
        self.metrics[metric_name].append(value)
    
    async def get_performance_report(self) -> Dict[str, Dict]:
        """Generate performance report"""
        report = {}
        
        for metric_name, values in self.metrics.items():
            report[metric_name] = {
                "count": len(values),
                "mean": statistics.mean(values),
                "median": statistics.median(values),
                "min": min(values),
                "max": max(values),
                "std": statistics.stdev(values) if len(values) > 1 else 0
            }
        
        return report
    
    async def check_performance_targets(self) -> Dict[str, bool]:
        """Check if performance targets are met"""
        # Ceilings (time, memory) must stay below target; floors (accuracy) must stay above it
        targets = {
            "processing_time_5min": (25000, "max"),   # <25 seconds
            "accuracy_threshold": (0.995, "min"),     # 99.5%+
            "memory_usage_gb": (8, "max"),            # <8GB
            "diarization_accuracy": (0.9, "min")      # 90%+
        }
        
        results = {}
        for target_name, (target_value, direction) in targets.items():
            if target_name in self.metrics:
                current_value = statistics.mean(self.metrics[target_name])
                if direction == "max":
                    results[target_name] = current_value <= target_value
                else:
                    results[target_name] = current_value >= target_value
        
        return results

🔄 Migration Strategy

From v1 to v2

class TraxV2Migration:
    """Migration utilities for upgrading from v1 to v2"""
    
    async def migrate_database_schema(self):
        """Migrate database schema for v2 features"""
        
        # Add new tables
        await self._create_speaker_profiles_table()
        await self._create_processing_jobs_table()
        
        # Modify existing tables
        await self._add_v2_columns_to_transcripts()
        await self._add_v2_columns_to_media_files()
    
    async def migrate_existing_transcripts(self):
        """Migrate existing v1 transcripts to v2 format"""
        
        v1_transcripts = await self._get_v1_transcripts()
        
        for transcript in v1_transcripts:
            # Update schema
            transcript.pipeline_version = "v1"
            transcript.merged_content = transcript.raw_content
            
            # Save updated transcript
            await self._save_transcript(transcript)
    
    async def validate_migration(self) -> bool:
        """Validate successful migration"""
        
        # Check schema
        schema_valid = await self._validate_schema()
        
        # Check data integrity
        data_valid = await self._validate_data_integrity()
        
        # Check functionality
        functionality_valid = await self._test_v2_functionality()
        
        return schema_valid and data_valid and functionality_valid
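
The column additions referenced above reduce to ALTER TABLE statements. A sketch of _add_v2_columns_to_transcripts, assuming a PostgreSQL-style schema and a hypothetical _execute_sql helper; pipeline_version and merged_content are taken from the migration code above, while the speaker_count column is an assumption:

    async def _add_v2_columns_to_transcripts(self):
        """Add the v2 columns used by the transcript migration (sketch)."""
        await self._execute_sql(  # _execute_sql is a hypothetical DB helper
            """
            ALTER TABLE transcripts
                ADD COLUMN IF NOT EXISTS pipeline_version TEXT DEFAULT 'v1',
                ADD COLUMN IF NOT EXISTS merged_content TEXT,
                ADD COLUMN IF NOT EXISTS speaker_count INTEGER
            """
        )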

This architecture document provides the technical foundation for implementing Trax v2, focusing on high performance and speaker diarization while maintaining the simplicity and determinism of the single-node design.