# Source: youtube-summarizer/backend/services/api_key_service.py
# (file-viewer header: 485 lines, 17 KiB, Python)

"""
API Key Management and Rate Limiting Service
Handles API key generation, validation, and rate limiting for the developer platform
"""
import secrets
import hashlib
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, asdict
from enum import Enum
import redis
import json
# Module-level logger named after this module; APIKeyService reports
# successes, fallbacks and failures through it.
logger = logging.getLogger(__name__)
class APITier(str, Enum):
    """Subscription tier attached to an API key.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value.
    """

    FREE = "free"
    PRO = "pro"
    ENTERPRISE = "enterprise"
class RateLimitPeriod(str, Enum):
    """Length of the window over which requests are counted.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value.
    """

    MINUTE = "minute"
    HOUR = "hour"
    DAY = "day"
    MONTH = "month"
@dataclass
class APIKey:
    """Stored record describing one issued API key.

    The full key string is not a field of this record; only its prefix and
    a hash of it are kept.
    """

    id: str
    user_id: str
    name: str
    # Leading, non-secret part of the key (the full key is never stored).
    key_prefix: str
    # Hash of the complete key, used to validate a presented key.
    key_hash: str
    tier: APITier
    # Request allowance per counting period.
    rate_limits: Dict[RateLimitPeriod, int]
    is_active: bool
    created_at: datetime
    # None until the key has been used.
    last_used_at: Optional[datetime]
    # None when the key has no expiry.
    expires_at: Optional[datetime]
    metadata: Dict[str, Any]
@dataclass
class RateLimitStatus:
    """Result of a single rate-limit check."""

    # Whether the request may proceed.
    allowed: bool
    # Requests still available in the current window.
    remaining: int
    # Moment at which the current window ends.
    reset_time: datetime
    # Allowance for the whole window.
    total_limit: int
    # Window the figures above refer to.
    period: RateLimitPeriod
@dataclass
class APIUsage:
    """Aggregated request counters for one API key."""

    total_requests: int
    successful_requests: int
    failed_requests: int
    # None when no request has been recorded.
    last_request_at: Optional[datetime]
    # Maps a date string to the number of requests on that date.
    daily_usage: Dict[str, int]
    # Maps a month string to the number of requests in that month.
    monthly_usage: Dict[str, int]
class APIKeyService:
    """
    API Key Management Service
    Handles API key generation, validation, rate limiting, and usage tracking

    Everything is persisted through ``self.redis_client`` under three key
    families:

    * ``api_key:{key_id}`` - JSON-serialised APIKey record (1 year TTL)
    * ``rate_limit:{key_id}:{period}:{window}`` - request counter per window
    * ``usage:{key_id}`` - JSON-serialised APIUsage record (30 day TTL)
    """

    # Longest key ID that validate_api_key() will try to look up. IDs issued
    # by generate_api_key() are 16 hex characters; older IDs produced with
    # token_urlsafe(8) are 11 characters. The cap bounds the number of
    # storage lookups an arbitrary "ys_..." string can trigger.
    _MAX_KEY_ID_LENGTH = 32

    def __init__(self, redis_client: Optional[redis.Redis] = None):
        """
        Args:
            redis_client: Client to use for storage. When None, a client is
                created by _create_redis_client(), which falls back to an
                in-memory mock if Redis cannot be reached.
        """
        # Compare with None instead of relying on truthiness so an injected
        # client is never discarded just because it evaluates as falsy.
        if redis_client is not None:
            self.redis_client = redis_client
        else:
            self.redis_client = self._create_redis_client()
        # Default rate limits by tier
        self.default_rate_limits = {
            APITier.FREE: {
                RateLimitPeriod.MINUTE: 10,
                RateLimitPeriod.HOUR: 100,
                RateLimitPeriod.DAY: 1000,
                RateLimitPeriod.MONTH: 10000,
            },
            APITier.PRO: {
                RateLimitPeriod.MINUTE: 100,
                RateLimitPeriod.HOUR: 2000,
                RateLimitPeriod.DAY: 25000,
                RateLimitPeriod.MONTH: 500000,
            },
            APITier.ENTERPRISE: {
                RateLimitPeriod.MINUTE: 1000,
                RateLimitPeriod.HOUR: 10000,
                RateLimitPeriod.DAY: 100000,
                RateLimitPeriod.MONTH: 2000000,
            },
        }

    def _create_redis_client(self) -> redis.Redis:
        """Create Redis client with fallback to mock"""
        try:
            client = redis.Redis(
                host='localhost',
                port=6379,
                db=1,  # Use db 1 for API keys
                decode_responses=True
            )
            # Test connection
            client.ping()
            logger.info("Connected to Redis for API key service")
            return client
        except Exception as e:
            # Deliberate best-effort: any failure to reach Redis degrades to
            # the in-memory mock instead of preventing start-up.
            logger.warning(f"Redis connection failed, using in-memory fallback: {e}")
            return self._create_mock_redis()

    def _create_mock_redis(self):
        """
        Create mock Redis client for development

        The mock keeps values in a dict and implements only the commands this
        service uses (get, set, incr, expire, ttl, exists, delete). Expired
        keys are purged lazily whenever they are touched, so every command
        sees the same view of which keys exist.
        """

        class MockRedis:
            def __init__(self):
                self.data = {}
                self.expiry = {}

            def _purge_if_expired(self, key):
                """Drop *key* if its deadline has passed."""
                deadline = self.expiry.get(key)
                if deadline is not None and datetime.now() > deadline:
                    self.data.pop(key, None)
                    self.expiry.pop(key, None)

            def get(self, key):
                self._purge_if_expired(key)
                return self.data.get(key)

            def set(self, key, value, ex=None):
                self.data[key] = value
                if ex:
                    self.expiry[key] = datetime.now() + timedelta(seconds=ex)
                else:
                    # Like Redis SET, writing without a TTL clears any
                    # previous one instead of silently inheriting it.
                    self.expiry.pop(key, None)
                return True

            def incr(self, key):
                # An expired counter must restart from zero.
                self._purge_if_expired(key)
                updated = int(self.data.get(key, 0)) + 1
                self.data[key] = str(updated)
                return updated

            def expire(self, key, seconds):
                self._purge_if_expired(key)
                if key not in self.data:
                    # Never record a deadline for a key that does not exist.
                    return False
                self.expiry[key] = datetime.now() + timedelta(seconds=seconds)
                return True

            def ttl(self, key):
                self._purge_if_expired(key)
                if key not in self.data:
                    return -2  # Redis convention: key does not exist
                if key in self.expiry:
                    remaining = (self.expiry[key] - datetime.now()).total_seconds()
                    return max(0, int(remaining))
                return -1  # Redis convention: key exists without a TTL

            def exists(self, key):
                self._purge_if_expired(key)
                return key in self.data

            def delete(self, key):
                existed = key in self.data
                self.data.pop(key, None)
                self.expiry.pop(key, None)
                return int(existed)

        return MockRedis()

    def generate_api_key(self, user_id: str, name: str, tier: APITier = APITier.FREE,
                         expires_in_days: Optional[int] = None) -> tuple[str, APIKey]:
        """
        Generate a new API key

        Args:
            user_id: Owner of the key.
            name: Human-readable label for the key.
            tier: Tier whose default rate limits are copied onto the key.
                A plain tier string such as "pro" is accepted as well.
            expires_in_days: Lifetime in days. None (or 0) means the key
                does not expire.

        Returns:
            tuple: (full_api_key, api_key_object)

        Raises:
            Exception: Any failure (invalid tier, storage error) is logged
                and re-raised unchanged.
        """
        try:
            tier = APITier(tier)
            # Generate secure API key. The ID is hex on purpose: the key
            # format is "_"-delimited and token_urlsafe() may itself emit
            # "_", which would make the ID impossible to parse back out.
            key_id = secrets.token_hex(8)
            key_secret = secrets.token_urlsafe(32)
            # Format: ys_{tier}_{key_id}_{secret}
            full_key = f"ys_{tier.value}_{key_id}_{key_secret}"
            key_prefix = f"ys_{tier.value}_{key_id}"
            # Hash the full key for storage
            key_hash = hashlib.sha256(full_key.encode()).hexdigest()
            # One timestamp so created_at and expires_at share the same base
            now = datetime.now()
            # Calculate expiry
            expires_at = None
            if expires_in_days:
                expires_at = now + timedelta(days=expires_in_days)
            # Create API key object
            api_key = APIKey(
                id=key_id,
                user_id=user_id,
                name=name,
                key_prefix=key_prefix,
                key_hash=key_hash,
                tier=tier,
                rate_limits=self.default_rate_limits[tier].copy(),
                is_active=True,
                created_at=now,
                last_used_at=None,
                expires_at=expires_at,
                metadata={}
            )
            # Store in Redis/database
            self._store_api_key(api_key)
            logger.info(f"Generated API key {key_prefix} for user {user_id}")
            return full_key, api_key
        except Exception as e:
            logger.error(f"Failed to generate API key: {e}")
            raise

    def _candidate_key_ids(self, api_key: str) -> List[str]:
        """
        Return every key ID that *api_key* could contain, shortest first.

        A key looks like ``ys_{tier}_{key_id}_{secret}``. The secret (and the
        ID of keys issued before IDs became hex) may contain "_", so the
        boundary between ID and secret is ambiguous; each possible boundary
        yields one candidate. Returns an empty list for malformed keys.
        """
        tokens = api_key.split("_")
        candidates = []
        # tokens[0] is "ys", tokens[1] the tier; at least one token must be
        # left over for the secret, hence the exclusive upper bound.
        for end in range(3, len(tokens)):
            candidate = "_".join(tokens[2:end])
            if len(candidate) > self._MAX_KEY_ID_LENGTH:
                break
            if candidate:
                candidates.append(candidate)
        return candidates

    def validate_api_key(self, api_key: str) -> Optional[APIKey]:
        """
        Validate API key and return key info

        On success the stored record's last_used_at is refreshed and written
        back before the record is returned.

        Args:
            api_key: Full API key string

        Returns:
            APIKey object if valid, None if invalid (wrong format, unknown
            ID, hash mismatch, revoked, expired, or any internal error)
        """
        try:
            # Check format
            if not isinstance(api_key, str) or not api_key.startswith("ys_"):
                return None
            presented_hash = hashlib.sha256(api_key.encode()).hexdigest()
            # Look up key in storage, trying each possible ID boundary
            stored_key = None
            for key_id in self._candidate_key_ids(api_key):
                record = self._get_api_key(key_id)
                # Verify hash in constant time so response timing does not
                # reveal how much of the digest matched.
                if record and secrets.compare_digest(presented_hash, record.key_hash):
                    stored_key = record
                    break
            if stored_key is None:
                return None
            # Check if active
            if not stored_key.is_active:
                return None
            # Check expiry
            if stored_key.expires_at and datetime.now() > stored_key.expires_at:
                return None
            # Update last used
            stored_key.last_used_at = datetime.now()
            self._update_api_key(stored_key)
            return stored_key
        except Exception as e:
            logger.error(f"API key validation failed: {e}")
            return None

    @staticmethod
    def _rate_limit_window(period: RateLimitPeriod, now: datetime) -> tuple[str, datetime, int]:
        """
        Describe the counting window that contains *now*.

        Returns:
            tuple: (window label used in the Redis key, moment the window
            ends, TTL in seconds to put on the counter)
        """
        if period == RateLimitPeriod.MINUTE:
            window_start = now.replace(second=0, microsecond=0)
            return now.strftime("%Y-%m-%d-%H-%M"), window_start + timedelta(minutes=1), 60
        if period == RateLimitPeriod.HOUR:
            window_start = now.replace(minute=0, second=0, microsecond=0)
            return now.strftime("%Y-%m-%d-%H"), window_start + timedelta(hours=1), 3600
        if period == RateLimitPeriod.DAY:
            window_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
            return now.strftime("%Y-%m-%d"), window_start + timedelta(days=1), 86400
        # MONTH: the 1st plus 32 days always lands in the following month
        next_month = now.replace(day=1) + timedelta(days=32)
        reset_time = next_month.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        # Clamp to 1: in the last second of a month the difference truncates
        # to 0, and expiring a key with 0 seconds deletes it immediately.
        ttl_seconds = max(1, int((reset_time - now).total_seconds()))
        return now.strftime("%Y-%m"), reset_time, ttl_seconds

    def check_rate_limit(self, api_key: APIKey, period: RateLimitPeriod = RateLimitPeriod.HOUR) -> RateLimitStatus:
        """
        Check rate limit for API key

        The call itself consumes one request from the window. If anything
        goes wrong the request is allowed (fail-open) and the error logged.

        Args:
            api_key: Validated API key object
            period: Time period to check

        Returns:
            RateLimitStatus with current status
        """
        try:
            limit = api_key.rate_limits.get(period, 0)
            time_key, reset_time, ttl_seconds = self._rate_limit_window(period, datetime.now())
            # Create Redis key for rate limiting
            redis_key = f"rate_limit:{api_key.id}:{period.value}:{time_key}"
            # Increment first and judge the returned count: a separate
            # read-then-increment lets concurrent requests all pass the
            # check before any of them is counted.
            new_usage = self.redis_client.incr(redis_key)
            if new_usage == 1:  # First request in this period
                self.redis_client.expire(redis_key, ttl_seconds)
            # Check if limit exceeded
            if new_usage > limit:
                return RateLimitStatus(
                    allowed=False,
                    remaining=0,
                    reset_time=reset_time,
                    total_limit=limit,
                    period=period
                )
            return RateLimitStatus(
                allowed=True,
                remaining=limit - new_usage,
                reset_time=reset_time,
                total_limit=limit,
                period=period
            )
        except Exception as e:
            logger.error(f"Rate limit check failed: {e}")
            # Allow request on error, but log it
            return RateLimitStatus(
                allowed=True,
                remaining=999,
                reset_time=datetime.now() + timedelta(hours=1),
                total_limit=1000,
                period=period
            )

    @staticmethod
    def _empty_usage() -> APIUsage:
        """Return an APIUsage with all counters at zero."""
        return APIUsage(
            total_requests=0,
            successful_requests=0,
            failed_requests=0,
            last_request_at=None,
            daily_usage={},
            monthly_usage={}
        )

    def get_usage_stats(self, api_key: APIKey) -> APIUsage:
        """
        Get detailed usage statistics for API key

        Returns:
            The stored statistics, or zeroed statistics when nothing is
            stored or the stored data cannot be read.
        """
        try:
            # Get usage data from Redis
            user_stats_key = f"usage:{api_key.id}"
            stats_data = self.redis_client.get(user_stats_key)
            if not stats_data:
                return self._empty_usage()
            stats = json.loads(stats_data)
            last_request_raw = stats.get("last_request_at")
            return APIUsage(
                total_requests=stats.get("total_requests", 0),
                successful_requests=stats.get("successful_requests", 0),
                failed_requests=stats.get("failed_requests", 0),
                last_request_at=datetime.fromisoformat(last_request_raw) if last_request_raw else None,
                daily_usage=stats.get("daily_usage", {}),
                monthly_usage=stats.get("monthly_usage", {})
            )
        except Exception as e:
            logger.error(f"Failed to get usage stats: {e}")
            return self._empty_usage()

    def record_request(self, api_key: APIKey, success: bool = True):
        """
        Record API request for usage tracking

        Best-effort: failures are logged and never propagated to the caller.

        Args:
            api_key: Key the request was made with.
            success: Whether the request counts as successful or failed.
        """
        try:
            user_stats_key = f"usage:{api_key.id}"
            current_stats = self.get_usage_stats(api_key)
            # Single timestamp so last_request_at and the day/month buckets
            # agree even when the call straddles midnight.
            now = datetime.now()
            # Update stats
            current_stats.total_requests += 1
            if success:
                current_stats.successful_requests += 1
            else:
                current_stats.failed_requests += 1
            current_stats.last_request_at = now
            # Update daily/monthly counters
            today = now.strftime("%Y-%m-%d")
            this_month = now.strftime("%Y-%m")
            current_stats.daily_usage[today] = current_stats.daily_usage.get(today, 0) + 1
            current_stats.monthly_usage[this_month] = current_stats.monthly_usage.get(this_month, 0) + 1
            # Store updated stats
            stats_dict = asdict(current_stats)
            stats_dict["last_request_at"] = now.isoformat()
            self.redis_client.set(user_stats_key, json.dumps(stats_dict), ex=86400 * 30)  # 30 days TTL
        except Exception as e:
            logger.error(f"Failed to record request: {e}")

    def revoke_api_key(self, key_id: str) -> bool:
        """
        Revoke (deactivate) an API key

        Returns:
            True when the key was found and deactivated, False when it does
            not exist or the update failed.
        """
        try:
            api_key = self._get_api_key(key_id)
            if api_key:
                api_key.is_active = False
                self._update_api_key(api_key)
                logger.info(f"Revoked API key {key_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to revoke API key {key_id}: {e}")
            return False

    def list_api_keys(self, user_id: str) -> List[APIKey]:
        """
        List all API keys for a user

        Not implemented yet: no per-user index is maintained, so this
        always returns an empty list.
        """
        # In production, this would query the database
        # For now, return mock data
        return []

    def _store_api_key(self, api_key: APIKey):
        """
        Store API key in Redis/database

        Raises:
            Exception: Serialisation or storage failures are logged and
                re-raised unchanged.
        """
        try:
            key_data = asdict(api_key)
            # Write enums as their plain string values so the stored JSON
            # does not depend on how the encoder treats str subclasses.
            key_data["tier"] = APITier(api_key.tier).value
            key_data["rate_limits"] = {
                RateLimitPeriod(period).value: limit
                for period, limit in api_key.rate_limits.items()
            }
            # Convert datetime objects to ISO format
            key_data["created_at"] = api_key.created_at.isoformat()
            key_data["last_used_at"] = api_key.last_used_at.isoformat() if api_key.last_used_at else None
            key_data["expires_at"] = api_key.expires_at.isoformat() if api_key.expires_at else None
            redis_key = f"api_key:{api_key.id}"
            self.redis_client.set(redis_key, json.dumps(key_data), ex=86400 * 365)  # 1 year TTL
        except Exception as e:
            logger.error(f"Failed to store API key: {e}")
            raise

    def _get_api_key(self, key_id: str) -> Optional[APIKey]:
        """
        Retrieve API key from Redis/database

        Returns:
            The stored APIKey, or None when the key does not exist or its
            stored data cannot be decoded (the latter is logged).
        """
        try:
            redis_key = f"api_key:{key_id}"
            key_data = self.redis_client.get(redis_key)
            if not key_data:
                return None
            data = json.loads(key_data)
            # Convert ISO format back to datetime
            data["created_at"] = datetime.fromisoformat(data["created_at"])
            for optional_field in ("last_used_at", "expires_at"):
                raw_value = data.get(optional_field)
                data[optional_field] = datetime.fromisoformat(raw_value) if raw_value else None
            # Convert tier back to enum
            data["tier"] = APITier(data["tier"])
            # Convert rate_limits keys back to enums
            data["rate_limits"] = {
                RateLimitPeriod(period_str): limit
                for period_str, limit in data["rate_limits"].items()
            }
            return APIKey(**data)
        except Exception as e:
            logger.error(f"Failed to get API key {key_id}: {e}")
            return None

    def _update_api_key(self, api_key: APIKey):
        """
        Update API key in Redis/database

        Rewrites the whole record via _store_api_key(), which also resets
        the record's TTL.

        Raises:
            Exception: Storage failures are logged and re-raised unchanged.
        """
        try:
            self._store_api_key(api_key)
        except Exception as e:
            logger.error(f"Failed to update API key: {e}")
            raise
# Global service instance, created when this module is imported.
# No client is passed, so the constructor calls _create_redis_client(), which
# pings Redis on localhost:6379 (db 1) and falls back to the in-memory mock
# if that fails.
api_key_service = APIKeyService()