""" API Key Management and Rate Limiting Service Handles API key generation, validation, and rate limiting for the developer platform """ import secrets import hashlib import logging from datetime import datetime, timedelta from typing import Optional, Dict, Any, List from dataclasses import dataclass, asdict from enum import Enum import redis import json logger = logging.getLogger(__name__) class APITier(str, Enum): FREE = "free" PRO = "pro" ENTERPRISE = "enterprise" class RateLimitPeriod(str, Enum): MINUTE = "minute" HOUR = "hour" DAY = "day" MONTH = "month" @dataclass class APIKey: id: str user_id: str name: str key_prefix: str # Only store prefix, not full key key_hash: str # Hash of full key for validation tier: APITier rate_limits: Dict[RateLimitPeriod, int] is_active: bool created_at: datetime last_used_at: Optional[datetime] expires_at: Optional[datetime] metadata: Dict[str, Any] @dataclass class RateLimitStatus: allowed: bool remaining: int reset_time: datetime total_limit: int period: RateLimitPeriod @dataclass class APIUsage: total_requests: int successful_requests: int failed_requests: int last_request_at: Optional[datetime] daily_usage: Dict[str, int] # date -> count monthly_usage: Dict[str, int] # month -> count class APIKeyService: """ API Key Management Service Handles API key generation, validation, rate limiting, and usage tracking """ def __init__(self, redis_client: Optional[redis.Redis] = None): self.redis_client = redis_client or self._create_redis_client() # Default rate limits by tier self.default_rate_limits = { APITier.FREE: { RateLimitPeriod.MINUTE: 10, RateLimitPeriod.HOUR: 100, RateLimitPeriod.DAY: 1000, RateLimitPeriod.MONTH: 10000 }, APITier.PRO: { RateLimitPeriod.MINUTE: 100, RateLimitPeriod.HOUR: 2000, RateLimitPeriod.DAY: 25000, RateLimitPeriod.MONTH: 500000 }, APITier.ENTERPRISE: { RateLimitPeriod.MINUTE: 1000, RateLimitPeriod.HOUR: 10000, RateLimitPeriod.DAY: 100000, RateLimitPeriod.MONTH: 2000000 } } def _create_redis_client(self) -> redis.Redis: """Create Redis client with fallback to mock""" try: client = redis.Redis( host='localhost', port=6379, db=1, # Use db 1 for API keys decode_responses=True ) # Test connection client.ping() logger.info("Connected to Redis for API key service") return client except Exception as e: logger.warning(f"Redis connection failed, using in-memory fallback: {e}") return self._create_mock_redis() def _create_mock_redis(self): """Create mock Redis client for development""" class MockRedis: def __init__(self): self.data = {} self.expiry = {} def get(self, key): if key in self.expiry and datetime.now() > self.expiry[key]: del self.data[key] del self.expiry[key] return None return self.data.get(key) def set(self, key, value, ex=None): self.data[key] = value if ex: self.expiry[key] = datetime.now() + timedelta(seconds=ex) def incr(self, key): current = int(self.data.get(key, 0)) self.data[key] = str(current + 1) return current + 1 def expire(self, key, seconds): self.expiry[key] = datetime.now() + timedelta(seconds=seconds) def ttl(self, key): if key in self.expiry: remaining = (self.expiry[key] - datetime.now()).total_seconds() return max(0, int(remaining)) return -1 def exists(self, key): return key in self.data def delete(self, key): self.data.pop(key, None) self.expiry.pop(key, None) return MockRedis() def generate_api_key(self, user_id: str, name: str, tier: APITier = APITier.FREE, expires_in_days: Optional[int] = None) -> tuple[str, APIKey]: """ Generate a new API key Returns: tuple: (full_api_key, api_key_object) """ try: # Generate secure API key key_id = secrets.token_urlsafe(8) key_secret = secrets.token_urlsafe(32) # Format: ys_{tier}_{key_id}_{secret} full_key = f"ys_{tier.value}_{key_id}_{key_secret}" key_prefix = f"ys_{tier.value}_{key_id}" # Hash the full key for storage key_hash = hashlib.sha256(full_key.encode()).hexdigest() # Calculate expiry expires_at = None if expires_in_days: expires_at = datetime.now() + timedelta(days=expires_in_days) # Create API key object api_key = APIKey( id=key_id, user_id=user_id, name=name, key_prefix=key_prefix, key_hash=key_hash, tier=tier, rate_limits=self.default_rate_limits[tier].copy(), is_active=True, created_at=datetime.now(), last_used_at=None, expires_at=expires_at, metadata={} ) # Store in Redis/database self._store_api_key(api_key) logger.info(f"Generated API key {key_prefix} for user {user_id}") return full_key, api_key except Exception as e: logger.error(f"Failed to generate API key: {e}") raise def validate_api_key(self, api_key: str) -> Optional[APIKey]: """ Validate API key and return key info Args: api_key: Full API key string Returns: APIKey object if valid, None if invalid """ try: # Check format if not api_key.startswith("ys_"): return None # Extract key ID from format parts = api_key.split("_") if len(parts) < 4: return None key_id = parts[2] # Look up key in storage stored_key = self._get_api_key(key_id) if not stored_key: return None # Verify hash key_hash = hashlib.sha256(api_key.encode()).hexdigest() if key_hash != stored_key.key_hash: return None # Check if active if not stored_key.is_active: return None # Check expiry if stored_key.expires_at and datetime.now() > stored_key.expires_at: return None # Update last used stored_key.last_used_at = datetime.now() self._update_api_key(stored_key) return stored_key except Exception as e: logger.error(f"API key validation failed: {e}") return None def check_rate_limit(self, api_key: APIKey, period: RateLimitPeriod = RateLimitPeriod.HOUR) -> RateLimitStatus: """ Check rate limit for API key Args: api_key: Validated API key object period: Time period to check Returns: RateLimitStatus with current status """ try: limit = api_key.rate_limits.get(period, 0) # Create Redis key for rate limiting current_time = datetime.now() if period == RateLimitPeriod.MINUTE: time_key = current_time.strftime("%Y-%m-%d-%H-%M") reset_time = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1) ttl_seconds = 60 elif period == RateLimitPeriod.HOUR: time_key = current_time.strftime("%Y-%m-%d-%H") reset_time = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1) ttl_seconds = 3600 elif period == RateLimitPeriod.DAY: time_key = current_time.strftime("%Y-%m-%d") reset_time = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1) ttl_seconds = 86400 else: # MONTH time_key = current_time.strftime("%Y-%m") next_month = current_time.replace(day=1) + timedelta(days=32) reset_time = next_month.replace(day=1, hour=0, minute=0, second=0, microsecond=0) ttl_seconds = int((reset_time - current_time).total_seconds()) redis_key = f"rate_limit:{api_key.id}:{period.value}:{time_key}" # Get current usage current_usage = int(self.redis_client.get(redis_key) or 0) # Check if limit exceeded if current_usage >= limit: return RateLimitStatus( allowed=False, remaining=0, reset_time=reset_time, total_limit=limit, period=period ) # Increment usage new_usage = self.redis_client.incr(redis_key) if new_usage == 1: # First request in this period self.redis_client.expire(redis_key, ttl_seconds) remaining = max(0, limit - new_usage) return RateLimitStatus( allowed=True, remaining=remaining, reset_time=reset_time, total_limit=limit, period=period ) except Exception as e: logger.error(f"Rate limit check failed: {e}") # Allow request on error, but log it return RateLimitStatus( allowed=True, remaining=999, reset_time=datetime.now() + timedelta(hours=1), total_limit=1000, period=period ) def get_usage_stats(self, api_key: APIKey) -> APIUsage: """Get detailed usage statistics for API key""" try: # Get usage data from Redis user_stats_key = f"usage:{api_key.id}" stats_data = self.redis_client.get(user_stats_key) if stats_data: stats = json.loads(stats_data) return APIUsage( total_requests=stats.get("total_requests", 0), successful_requests=stats.get("successful_requests", 0), failed_requests=stats.get("failed_requests", 0), last_request_at=datetime.fromisoformat(stats["last_request_at"]) if stats.get("last_request_at") else None, daily_usage=stats.get("daily_usage", {}), monthly_usage=stats.get("monthly_usage", {}) ) else: return APIUsage( total_requests=0, successful_requests=0, failed_requests=0, last_request_at=None, daily_usage={}, monthly_usage={} ) except Exception as e: logger.error(f"Failed to get usage stats: {e}") return APIUsage( total_requests=0, successful_requests=0, failed_requests=0, last_request_at=None, daily_usage={}, monthly_usage={} ) def record_request(self, api_key: APIKey, success: bool = True): """Record API request for usage tracking""" try: user_stats_key = f"usage:{api_key.id}" current_stats = self.get_usage_stats(api_key) # Update stats current_stats.total_requests += 1 if success: current_stats.successful_requests += 1 else: current_stats.failed_requests += 1 current_stats.last_request_at = datetime.now() # Update daily/monthly counters today = datetime.now().strftime("%Y-%m-%d") this_month = datetime.now().strftime("%Y-%m") current_stats.daily_usage[today] = current_stats.daily_usage.get(today, 0) + 1 current_stats.monthly_usage[this_month] = current_stats.monthly_usage.get(this_month, 0) + 1 # Store updated stats stats_dict = asdict(current_stats) if stats_dict["last_request_at"]: stats_dict["last_request_at"] = stats_dict["last_request_at"].isoformat() self.redis_client.set(user_stats_key, json.dumps(stats_dict), ex=86400 * 30) # 30 days TTL except Exception as e: logger.error(f"Failed to record request: {e}") def revoke_api_key(self, key_id: str) -> bool: """Revoke (deactivate) an API key""" try: api_key = self._get_api_key(key_id) if api_key: api_key.is_active = False self._update_api_key(api_key) logger.info(f"Revoked API key {key_id}") return True return False except Exception as e: logger.error(f"Failed to revoke API key {key_id}: {e}") return False def list_api_keys(self, user_id: str) -> List[APIKey]: """List all API keys for a user""" try: # In production, this would query the database # For now, return mock data return [] except Exception as e: logger.error(f"Failed to list API keys for user {user_id}: {e}") return [] def _store_api_key(self, api_key: APIKey): """Store API key in Redis/database""" try: key_data = asdict(api_key) # Convert datetime objects to ISO format key_data["created_at"] = key_data["created_at"].isoformat() if key_data["last_used_at"]: key_data["last_used_at"] = key_data["last_used_at"].isoformat() if key_data["expires_at"]: key_data["expires_at"] = key_data["expires_at"].isoformat() redis_key = f"api_key:{api_key.id}" self.redis_client.set(redis_key, json.dumps(key_data), ex=86400 * 365) # 1 year TTL except Exception as e: logger.error(f"Failed to store API key: {e}") raise def _get_api_key(self, key_id: str) -> Optional[APIKey]: """Retrieve API key from Redis/database""" try: redis_key = f"api_key:{key_id}" key_data = self.redis_client.get(redis_key) if not key_data: return None data = json.loads(key_data) # Convert ISO format back to datetime data["created_at"] = datetime.fromisoformat(data["created_at"]) if data["last_used_at"]: data["last_used_at"] = datetime.fromisoformat(data["last_used_at"]) if data["expires_at"]: data["expires_at"] = datetime.fromisoformat(data["expires_at"]) # Convert tier back to enum data["tier"] = APITier(data["tier"]) # Convert rate_limits keys back to enums rate_limits = {} for period_str, limit in data["rate_limits"].items(): rate_limits[RateLimitPeriod(period_str)] = limit data["rate_limits"] = rate_limits return APIKey(**data) except Exception as e: logger.error(f"Failed to get API key {key_id}: {e}") return None def _update_api_key(self, api_key: APIKey): """Update API key in Redis/database""" try: self._store_api_key(api_key) except Exception as e: logger.error(f"Failed to update API key: {e}") raise # Global service instance api_key_service = APIKeyService()