trax/examples/caching_pipeline.py
#!/usr/bin/env python3
"""Example: Multi-layer caching for transcription pipeline.
This example demonstrates:
- Using the AI Assistant Library's cache components
- Multi-layer caching strategy (memory, database, filesystem)
- Cache invalidation and warming
- Performance metrics and cost savings
"""
import asyncio
import functools
import hashlib
import json
import logging
import shutil
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))


# Simplified cache classes for the example
class MemoryCache:
    """Simple in-memory cache used as layer 1."""

    def __init__(self, default_ttl=3600, max_size=100):
        self.cache = {}
        self.default_ttl = default_ttl
        self.max_size = max_size

    async def get(self, key):
        return self.cache.get(key)

    async def set(self, key, value, ttl=None):
        self.cache[key] = value

    async def delete(self, key):
        return self.cache.pop(key, None) is not None

    async def size(self):
        return len(self.cache)


class CacheManager:
    """Base cache manager."""
    pass
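

# Note: this simplified MemoryCache accepts default_ttl and max_size but does
# not enforce them. A production version would store (value, expires_at)
# pairs, drop expired pairs in get(), and evict the oldest entry once
# len(self.cache) exceeds max_size; that bookkeeping is omitted here to keep
# the example focused on the layering strategy.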


def cached(ttl=3600):
    """Simple cache decorator (the ttl argument is accepted but not enforced)."""
    def decorator(func):
        cache = {}

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            key = str(args) + str(kwargs)
            if key in cache:
                return cache[key]
            result = await func(*args, **kwargs)
            cache[key] = result
            return result

        return wrapper
    return decorator
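

# Note: keying on str(args) + str(kwargs) is fragile; it depends on argument
# reprs and keyword ordering. A more robust (illustrative) key would hash a
# canonical serialization, e.g.:
#   hashlib.sha256(repr((args, sorted(kwargs.items()))).encode()).hexdigest()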


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class TranscriptCache(CacheManager):
    """Multi-layer cache for transcription pipeline.

    Implements a three-tier caching strategy:
    1. Memory cache - Hot data, fast access
    2. Database cache - Persistent, searchable
    3. Filesystem cache - Large files, audio data
    """

    def __init__(self):
        """Initialize multi-layer cache."""
        super().__init__()
        # Layer 1: Memory cache for hot data
        self.memory_cache = MemoryCache(
            default_ttl=3600,  # 1 hour
            max_size=100,  # Maximum 100 entries
        )
        # Layer 2: Database cache (simulated with a JSON file)
        self.db_cache_file = Path("cache_db.json")
        self.db_cache_data = self._load_db_cache()
        # Layer 3: Filesystem cache for audio
        self.fs_cache_dir = Path("audio_cache")
        self.fs_cache_dir.mkdir(exist_ok=True)
        # Metrics tracking
        self.metrics = {
            "memory_hits": 0,
            "db_hits": 0,
            "fs_hits": 0,
            "misses": 0,
            "cost_saved": 0.0,
            "time_saved": 0.0,
        }

    def _load_db_cache(self) -> Dict[str, Any]:
        """Load database cache from file."""
        if self.db_cache_file.exists():
            try:
                with open(self.db_cache_file, 'r') as f:
                    return json.load(f)
            except json.JSONDecodeError:
                # A corrupted cache file should not crash the pipeline
                logger.warning("Corrupted cache file, starting with empty cache")
        return {}

    def _save_db_cache(self):
        """Save database cache to file."""
        with open(self.db_cache_file, 'w') as f:
            json.dump(self.db_cache_data, f, indent=2)
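
    # Note: a production database layer would persist atomically (e.g. dump
    # to a temporary file, then os.replace() it over cache_db.json) so a
    # crash mid-write cannot corrupt the cache. Plain overwrite keeps this
    # example short.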

    async def get_transcript(self, file_hash: str) -> Optional[Dict[str, Any]]:
        """Get transcript from cache with multi-layer lookup.

        Args:
            file_hash: Hash of the audio file

        Returns:
            Cached transcript if found, None otherwise
        """
        # Layer 1: Check memory cache
        cached_data = await self.memory_cache.get(f"transcript:{file_hash}")
        if cached_data:
            self.metrics["memory_hits"] += 1
            logger.info(f"✓ Memory cache hit for {file_hash[:8]}...")
            return cached_data
        # Layer 2: Check database cache
        if file_hash in self.db_cache_data:
            self.metrics["db_hits"] += 1
            logger.info(f"✓ Database cache hit for {file_hash[:8]}...")
            # Promote to memory cache
            data = self.db_cache_data[file_hash]
            await self.memory_cache.set(f"transcript:{file_hash}", data)
            return data
        # Layer 3: Check filesystem for processed audio
        audio_cache_path = self.fs_cache_dir / f"{file_hash}.wav"
        if audio_cache_path.exists():
            self.metrics["fs_hits"] += 1
            logger.info(f"✓ Filesystem cache hit for {file_hash[:8]}...")
            # Return path indicator (transcript would be re-generated from cached audio)
            return {"cached_audio_path": str(audio_cache_path)}
        # Cache miss
        self.metrics["misses"] += 1
        logger.info(f"✗ Cache miss for {file_hash[:8]}...")
        return None
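
    # Note: a layer-3 hit returns {"cached_audio_path": ...} rather than a
    # finished transcript, so callers (see cached_transcription below) check
    # for a "text" key before treating the result as a transcript.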

    async def set_transcript(self, file_hash: str, transcript: Dict[str, Any],
                             audio_path: Optional[Path] = None):
        """Store transcript in multi-layer cache.

        Args:
            file_hash: Hash of the audio file
            transcript: Transcript data to cache
            audio_path: Optional preprocessed audio to cache
        """
        # Layer 1: Store in memory cache
        await self.memory_cache.set(f"transcript:{file_hash}", transcript)
        # Layer 2: Store in database cache
        self.db_cache_data[file_hash] = {
            **transcript,
            "cached_at": datetime.now().isoformat(),
        }
        self._save_db_cache()
        # Layer 3: Store preprocessed audio if provided
        if audio_path and audio_path.exists():
            cache_path = self.fs_cache_dir / f"{file_hash}.wav"
            shutil.copy2(audio_path, cache_path)
            logger.info(f"✓ Cached audio to {cache_path.name}")

    async def invalidate(self, file_hash: str):
        """Invalidate cache entry across all layers.

        Args:
            file_hash: Hash of the file to invalidate
        """
        # Remove from memory cache
        await self.memory_cache.delete(f"transcript:{file_hash}")
        # Remove from database cache
        if file_hash in self.db_cache_data:
            del self.db_cache_data[file_hash]
            self._save_db_cache()
        # Remove from filesystem cache
        audio_cache_path = self.fs_cache_dir / f"{file_hash}.wav"
        if audio_cache_path.exists():
            audio_cache_path.unlink()
        logger.info(f"✓ Invalidated cache for {file_hash[:8]}...")

    async def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics and metrics.

        Returns:
            Cache performance metrics
        """
        total_hits = (
            self.metrics["memory_hits"] +
            self.metrics["db_hits"] +
            self.metrics["fs_hits"]
        )
        total_requests = total_hits + self.metrics["misses"]
        hit_rate = (total_hits / total_requests * 100) if total_requests > 0 else 0
        return {
            "hit_rate": f"{hit_rate:.1f}%",
            "memory_hits": self.metrics["memory_hits"],
            "db_hits": self.metrics["db_hits"],
            "fs_hits": self.metrics["fs_hits"],
            "misses": self.metrics["misses"],
            "cost_saved": f"${self.metrics['cost_saved']:.4f}",
            "time_saved": f"{self.metrics['time_saved']:.1f}s",
            # memory_cache.size() is a coroutine, so this method must be async
            "memory_size": await self.memory_cache.size(),
            "db_size": len(self.db_cache_data),
            "fs_size": len(list(self.fs_cache_dir.glob("*.wav"))),
        }
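

# Entries in the database layer carry a "cached_at" timestamp that the example
# stores but never checks. A sketch of how it could drive expiry (the helper
# name and 24-hour window below are illustrative, not part of the pipeline):
def db_entry_expired(entry: Dict[str, Any],
                     max_age: timedelta = timedelta(hours=24)) -> bool:
    """Return True if a db-cache entry is older than max_age."""
    cached_at = datetime.fromisoformat(entry["cached_at"])
    return datetime.now() - cached_at > max_age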


# Simulated expensive operations
async def expensive_transcription(audio_path: Path) -> Dict[str, Any]:
    """Simulate expensive transcription operation.

    This would normally call Whisper or another transcription service.
    """
    logger.info(f"🔄 Performing expensive transcription for {audio_path.name}...")
    # Simulate processing time
    await asyncio.sleep(3)
    # Simulate API cost
    cost = 0.01  # $0.01 per transcription
    return {
        "text": f"Transcribed content of {audio_path.name}",
        "segments": [
            {"start": 0, "end": 5, "text": "Segment 1"},
            {"start": 5, "end": 10, "text": "Segment 2"},
        ],
        "duration": 10.0,
        "cost": cost,
    }


async def expensive_enhancement(transcript: str) -> str:
    """Simulate expensive AI enhancement.

    This would normally call DeepSeek or another AI service.
    """
    logger.info("🔄 Performing expensive AI enhancement...")
    # Simulate processing time
    await asyncio.sleep(2)
    # Simulated API cost: $0.005 per enhancement
    return f"[ENHANCED] {transcript}"


# cached_transcription manages the multi-layer TranscriptCache directly rather
# than using the @cached decorator: memoizing here would bypass the layered
# lookup and break invalidation. Only cached_enhancement below uses the
# decorator.
async def cached_transcription(audio_path: Path, cache: TranscriptCache) -> Dict[str, Any]:
    """Transcription with multi-layer cache lookup."""
    # For the demo the hash is derived from the file path; a real pipeline
    # would hash the audio bytes instead.
    file_hash = hashlib.sha256(str(audio_path).encode()).hexdigest()
    # Check cache first
    cached_result = await cache.get_transcript(file_hash)
    if cached_result and "text" in cached_result:
        cache.metrics["time_saved"] += 3.0  # Saved 3 seconds
        cache.metrics["cost_saved"] += 0.01  # Saved $0.01
        return cached_result
    # Perform expensive operation
    result = await expensive_transcription(audio_path)
    # Cache the result
    await cache.set_transcript(file_hash, result)
    return result


@cached(ttl=86400)  # 24 hour cache for enhancement
async def cached_enhancement(transcript: str) -> str:
    """Cached AI enhancement."""
    # This uses the decorator's built-in memoization
    return await expensive_enhancement(transcript)


async def warm_cache(cache: TranscriptCache, files: List[Path]):
    """Warm the cache with predictive loading.

    Args:
        cache: Cache manager
        files: Files to pre-cache
    """
    logger.info(f"🔥 Warming cache with {len(files)} files...")
    for file_path in files:
        file_hash = hashlib.sha256(str(file_path).encode()).hexdigest()
        # Check if already cached
        if await cache.get_transcript(file_hash):
            continue
        # Pre-load into cache
        result = await expensive_transcription(file_path)
        await cache.set_transcript(file_hash, result)
    logger.info("✓ Cache warming complete")
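

# Sequential warming is simple but slow for large batches. A concurrent
# variant (illustrative; main() below does not use it) could fan out the
# uncached files with asyncio.gather:
async def warm_cache_concurrent(cache: TranscriptCache, files: List[Path]):
    """Warm the cache by transcribing uncached files concurrently."""
    async def _warm_one(file_path: Path):
        file_hash = hashlib.sha256(str(file_path).encode()).hexdigest()
        if await cache.get_transcript(file_hash):
            return
        result = await expensive_transcription(file_path)
        await cache.set_transcript(file_hash, result)

    await asyncio.gather(*(_warm_one(f) for f in files))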


async def main():
    """Run caching examples."""
    # Initialize cache
    cache = TranscriptCache()
    # Create test files
    test_files = []
    for i in range(5):
        file_path = Path(f"test_audio_{i}.mp3")
        file_path.touch()
        test_files.append(file_path)
    try:
        # Example 1: Basic caching with hit/miss demonstration
        print("\n" + "="*60)
        print("Example 1: Multi-layer Caching")
        print("="*60)
        # First access - cache miss
        print("\nFirst access (cache miss):")
        start = time.time()
        result1 = await cached_transcription(test_files[0], cache)
        time1 = time.time() - start
        print(f" Time: {time1:.2f}s")
        print(f" Result: {result1['text']}")
        # Second access - cache hit
        print("\nSecond access (cache hit):")
        start = time.time()
        result2 = await cached_transcription(test_files[0], cache)
        time2 = time.time() - start
        print(f" Time: {time2:.2f}s")
        # Guard against a near-zero denominator on fast cache hits
        print(f" Speedup: {time1 / max(time2, 1e-6):.1f}x faster")

        # Example 2: Cache warming for batch processing
        print("\n" + "="*60)
        print("Example 2: Cache Warming")
        print("="*60)
        # Warm cache with predicted files
        await warm_cache(cache, test_files[1:3])
        # Process files (should all be cache hits)
        print("\nProcessing pre-warmed files:")
        for file_path in test_files[1:3]:
            start = time.time()
            result = await cached_transcription(file_path, cache)
            elapsed = time.time() - start
            print(f" {file_path.name}: {elapsed:.3f}s (cached)")

        # Example 3: Cache invalidation
        print("\n" + "="*60)
        print("Example 3: Cache Invalidation")
        print("="*60)
        file_hash = hashlib.sha256(str(test_files[0]).encode()).hexdigest()
        print(f"\nInvalidating cache for {test_files[0].name}...")
        await cache.invalidate(file_hash)
        # Access after invalidation - cache miss again
        print("Access after invalidation:")
        start = time.time()
        result = await cached_transcription(test_files[0], cache)
        elapsed = time.time() - start
        print(f" Time: {elapsed:.2f}s (cache miss after invalidation)")

        # Example 4: Enhancement caching with decorator
        print("\n" + "="*60)
        print("Example 4: AI Enhancement Caching")
        print("="*60)
        transcript = "This is a sample transcript that needs enhancement."
        print("\nFirst enhancement (expensive):")
        start = time.time()
        enhanced1 = await cached_enhancement(transcript)
        time1 = time.time() - start
        print(f" Time: {time1:.2f}s")
        print(f" Result: {enhanced1}")
        print("\nSecond enhancement (cached):")
        start = time.time()
        enhanced2 = await cached_enhancement(transcript)
        time2 = time.time() - start
        print(f" Time: {time2:.3f}s")
        print(f" Speedup: {time1 / max(time2, 1e-6):.1f}x faster")

        # Example 5: Cache statistics and metrics
        print("\n" + "="*60)
        print("Example 5: Cache Performance Metrics")
        print("="*60)
        stats = await cache.get_stats()
        print("\nCache Statistics:")
        for key, value in stats.items():
            print(f" {key.replace('_', ' ').title()}: {value}")
        # Calculate ROI
        if cache.metrics["cost_saved"] > 0:
            print("\n💰 Cost Savings Analysis:")
            print(f" Total saved: ${cache.metrics['cost_saved']:.4f}")
            print(f" Time saved: {cache.metrics['time_saved']:.1f} seconds")
            print(f" Efficiency: {stats['hit_rate']} cache hit rate")

        # Example 6: Cache layer distribution
        print("\n" + "="*60)
        print("Example 6: Cache Layer Analysis")
        print("="*60)
        total_hits = (
            cache.metrics["memory_hits"] +
            cache.metrics["db_hits"] +
            cache.metrics["fs_hits"]
        )
        if total_hits > 0:
            print("\nCache Hit Distribution:")
            print(f" Memory Layer: {cache.metrics['memory_hits']/total_hits*100:.1f}%")
            print(f" Database Layer: {cache.metrics['db_hits']/total_hits*100:.1f}%")
            print(f" Filesystem Layer: {cache.metrics['fs_hits']/total_hits*100:.1f}%")
    finally:
        # Cleanup test files and cache artifacts
        for file_path in test_files:
            if file_path.exists():
                file_path.unlink()
        # Clean cache files
        if cache.db_cache_file.exists():
            cache.db_cache_file.unlink()
        for cached_file in cache.fs_cache_dir.glob("*.wav"):
            cached_file.unlink()
        if cache.fs_cache_dir.exists():
            cache.fs_cache_dir.rmdir()
        print("\n✓ Cleanup complete")


if __name__ == "__main__":
    print("Trax Caching Pipeline Example")
    print("Using AI Assistant Library for multi-layer caching")
    print("-" * 60)
    # Run the async main function
    asyncio.run(main())