# trax/examples/lora_adapter_example.py

"""
Example demonstrating LoRA adapter usage for domain-specific adaptation.
This example shows how to create, switch between, save, and reload
domain-specific LoRA adapters for adapting a Whisper model. It runs against a
mocked model, so no weights need to be downloaded.
"""
import tempfile
from pathlib import Path
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from src.adapters import DomainAdapter, LoRAConfig
def _build_mock_model():
    """Return a ``unittest.mock.Mock`` standing in for a Whisper model.

    Only a few ``config`` size fields are set explicitly; any other attribute
    access is auto-created by ``Mock``. In real usage, load the actual model::

        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
        processor = WhisperProcessor.from_pretrained("openai/whisper-base")
    """
    from unittest.mock import Mock

    model = Mock()
    model.config = Mock()
    model.config.hidden_size = 768
    model.config.num_attention_heads = 12
    model.config.num_hidden_layers = 12
    return model


def _print_adapter_info(domain_adapter):
    """Print rank, alpha, dropout, target modules and active flag for every adapter.

    Expects ``domain_adapter.get_adapter_info(name)`` to return a mapping with
    the keys 'rank', 'alpha', 'dropout', 'target_modules' and 'is_active'.
    """
    print("\n📊 Adapter Information:")
    for adapter_name in domain_adapter.list_adapters():
        info = domain_adapter.get_adapter_info(adapter_name)
        print(f" {adapter_name}:")
        print(f" - Rank: {info['rank']}")
        print(f" - Alpha: {info['alpha']}")
        print(f" - Dropout: {info['dropout']}")
        print(f" - Target Modules: {info['target_modules']}")
        print(f" - Active: {info['is_active']}")


def main():
    """Demonstrate LoRA adapter functionality.

    Walks through the ``DomainAdapter`` lifecycle against a mocked Whisper
    model:

    1. create one adapter per domain;
    2. switch between the adapters;
    3. print each adapter's settings;
    4. save the adapters under a temporary directory;
    5. reload them into a fresh ``DomainAdapter``;
    6. show two expected error cases;
    7. switch back to the base model.

    The adapter directory lives inside a ``tempfile.TemporaryDirectory``,
    so it is removed when the example finishes.
    """
    print("🚀 LoRA Adapter Example")
    print("=" * 50)

    # Create a temporary directory for adapters; it is cleaned up on exit.
    with tempfile.TemporaryDirectory() as temp_dir:
        adapter_dir = Path(temp_dir) / "adapters"

        print("📦 Loading base Whisper model...")
        model = _build_mock_model()

        print("🔧 Initializing DomainAdapter...")
        domain_adapter = DomainAdapter(base_model=model, adapter_dir=adapter_dir)

        print("\n🎯 Creating domain-specific adapters...")
        # Single source of truth for the demo domains. dict keeps insertion
        # order, so adapters are created, switched and reloaded in this order.
        domain_configs = {
            "technical": LoRAConfig(
                rank=8,
                alpha=16,
                dropout=0.1,
                target_modules=["q_proj", "v_proj"],
            ),
            "medical": LoRAConfig(
                rank=16,
                alpha=32,
                dropout=0.15,
                target_modules=["q_proj", "v_proj", "k_proj"],
            ),
            "legal": LoRAConfig(
                rank=12,
                alpha=24,
                dropout=0.1,
                target_modules=["q_proj", "v_proj"],
            ),
        }
        for name, config in domain_configs.items():
            domain_adapter.create_adapter(name, config)
            print(f"✅ Created {name} domain adapter")

        print(f"\n📋 Available adapters: {domain_adapter.list_adapters()}")

        print("\n🔄 Switching between adapters...")
        for name in domain_configs:
            domain_adapter.switch_adapter(name)
            print(f"✅ Switched to {name} adapter: {domain_adapter.get_active_adapter()}")

        _print_adapter_info(domain_adapter)

        print("\n💾 Saving adapters to disk...")
        for adapter_name in domain_adapter.list_adapters():
            domain_adapter.save_adapter(adapter_name)
            print(f"✅ Saved {adapter_name} adapter")

        print("\n📂 Loading adapters from disk...")
        # A fresh instance shows that adapters round-trip through adapter_dir.
        new_domain_adapter = DomainAdapter(base_model=model, adapter_dir=adapter_dir)
        for name in domain_configs:
            new_domain_adapter.load_adapter(name)
            print(f"✅ Loaded {name} adapter")
        print(f"📋 Loaded adapters: {new_domain_adapter.list_adapters()}")

        print("\n⚠️ Error handling examples:")
        # The concrete exception types are defined by DomainAdapter, which is
        # not visible here, so the demo catches Exception and reports the type.
        # The else branches make it visible when no error is raised.
        try:
            domain_adapter.switch_adapter("nonexistent")
        except Exception as e:
            print(f"❌ Expected error when switching to nonexistent adapter: {type(e).__name__}")
        else:
            print("⚠️ No error raised when switching to nonexistent adapter")
        try:
            domain_adapter.create_adapter("technical", domain_configs["technical"])
        except Exception as e:
            print(f"❌ Expected error when creating duplicate adapter: {type(e).__name__}")
        else:
            print("⚠️ No error raised when creating duplicate adapter")

        print("\n🏠 Switching back to base model...")
        domain_adapter.switch_adapter(None)  # None deactivates all adapters
        print(f"✅ Active adapter: {domain_adapter.get_active_adapter()}")

    print("\n🎉 LoRA Adapter example completed successfully!")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()