"""Example demonstrating LoRA Adapter usage for domain-specific adaptation.

This example shows how to create, load, switch, and save domain-specific
LoRA adapters for Whisper model adaptation.
"""

import tempfile
from pathlib import Path

# Only referenced by the "real usage" note in ``_build_base_model``; kept so the
# example documents which classes a real run would load.
from transformers import WhisperForConditionalGeneration, WhisperProcessor  # noqa: F401

from src.adapters import DomainAdapter, LoRAConfig


def _build_base_model():
    """Return a lightweight stand-in for a Whisper model.

    In real usage you would load the actual model instead::

        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
        processor = WhisperProcessor.from_pretrained("openai/whisper-base")
    """
    from unittest.mock import Mock

    model = Mock()
    model.config = Mock()
    model.config.hidden_size = 768
    model.config.num_attention_heads = 12
    model.config.num_hidden_layers = 12
    return model


def _domain_configs():
    """Return a fresh ``{domain_name: LoRAConfig}`` mapping, in creation order."""
    return {
        "technical": LoRAConfig(
            rank=8, alpha=16, dropout=0.1, target_modules=["q_proj", "v_proj"]
        ),
        "medical": LoRAConfig(
            rank=16, alpha=32, dropout=0.15, target_modules=["q_proj", "v_proj", "k_proj"]
        ),
        "legal": LoRAConfig(
            rank=12, alpha=24, dropout=0.1, target_modules=["q_proj", "v_proj"]
        ),
    }


def _create_adapters(domain_adapter, configs):
    """Create one adapter on *domain_adapter* for each ``name -> config`` entry."""
    print("\n🎯 Creating domain-specific adapters...")
    for name, config in configs.items():
        domain_adapter.create_adapter(name, config)
        print(f"✅ Created {name} domain adapter")


def _demonstrate_switching(domain_adapter, names):
    """Activate each adapter in *names* in turn and report the active one."""
    print("\n🔄 Switching between adapters...")
    for name in names:
        domain_adapter.switch_adapter(name)
        print(f"✅ Switched to {name} adapter: {domain_adapter.get_active_adapter()}")


def _print_adapter_info(domain_adapter):
    """Print the fields returned by ``get_adapter_info`` for every adapter."""
    print("\n📊 Adapter Information:")
    for adapter_name in domain_adapter.list_adapters():
        info = domain_adapter.get_adapter_info(adapter_name)
        print(f"  {adapter_name}:")
        print(f"    - Rank: {info['rank']}")
        print(f"    - Alpha: {info['alpha']}")
        print(f"    - Dropout: {info['dropout']}")
        print(f"    - Target Modules: {info['target_modules']}")
        print(f"    - Active: {info['is_active']}")


def _save_all(domain_adapter):
    """Save every adapter on *domain_adapter*; return the saved names in order."""
    print("\n💾 Saving adapters to disk...")
    saved_names = list(domain_adapter.list_adapters())
    for name in saved_names:
        domain_adapter.save_adapter(name)
        print(f"✅ Saved {name} adapter")
    return saved_names


def _reload_from_disk(model, adapter_dir, names):
    """Load *names* from *adapter_dir* into a new ``DomainAdapter`` and report them."""
    print("\n📂 Loading adapters from disk...")
    # Use a fresh instance so the adapters are (presumably) read back from
    # ``adapter_dir`` rather than reused from the first instance's state.
    reloaded = DomainAdapter(base_model=model, adapter_dir=adapter_dir)
    for name in names:
        reloaded.load_adapter(name)
        print(f"✅ Loaded {name} adapter")
    print(f"📋 Loaded adapters: {reloaded.list_adapters()}")


def _demonstrate_error_handling(domain_adapter, duplicate_name, duplicate_config):
    """Show the errors raised for an unknown adapter and a duplicate name.

    The concrete exception classes are defined by ``DomainAdapter``, so any
    ``Exception`` is accepted here. If no error is raised, the unexpected
    success is reported instead of passing silently.
    """
    print("\n⚠️ Error handling examples:")

    try:
        domain_adapter.switch_adapter("nonexistent")
    except Exception as e:
        print(f"❌ Expected error when switching to nonexistent adapter: {type(e).__name__}")
    else:
        print("⚠️ Switching to a nonexistent adapter unexpectedly succeeded")

    try:
        domain_adapter.create_adapter(duplicate_name, duplicate_config)
    except Exception as e:
        print(f"❌ Expected error when creating duplicate adapter: {type(e).__name__}")
    else:
        print(f"⚠️ Creating duplicate adapter '{duplicate_name}' unexpectedly succeeded")


def main():
    """Demonstrate LoRA adapter functionality end to end.

    Creates three domain adapters on a mock Whisper model, switches between
    them, prints their settings, saves them to a temporary directory, reloads
    them into a new ``DomainAdapter``, shows error handling, and finally
    switches back to the base model (``switch_adapter(None)``).
    """
    print("🚀 LoRA Adapter Example")
    print("=" * 50)

    # Adapters are written under a temporary directory removed when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        adapter_dir = Path(temp_dir) / "adapters"

        print("📦 Loading base Whisper model...")
        model = _build_base_model()

        print("🔧 Initializing DomainAdapter...")
        domain_adapter = DomainAdapter(base_model=model, adapter_dir=adapter_dir)

        configs = _domain_configs()
        _create_adapters(domain_adapter, configs)

        print(f"\n📋 Available adapters: {domain_adapter.list_adapters()}")

        _demonstrate_switching(domain_adapter, list(configs))
        _print_adapter_info(domain_adapter)

        # Reload exactly what was saved instead of a hard-coded name list.
        saved_names = _save_all(domain_adapter)
        _reload_from_disk(model, adapter_dir, saved_names)

        _demonstrate_error_handling(domain_adapter, "technical", configs["technical"])

        print("\n🏠 Switching back to base model...")
        domain_adapter.switch_adapter(None)
        print(f"✅ Active adapter: {domain_adapter.get_active_adapter()}")

    print("\n🎉 LoRA Adapter example completed successfully!")


if __name__ == "__main__":
    main()