485 lines
17 KiB
Python
485 lines
17 KiB
Python
"""
|
|
API Key Management and Rate Limiting Service
|
|
Handles API key generation, validation, and rate limiting for the developer platform
|
|
"""
|
|
|
|
import secrets
|
|
import hashlib
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, Any, List
|
|
from dataclasses import dataclass, asdict
|
|
from enum import Enum
|
|
import redis
|
|
import json
|
|
|
|
# Module-level logger, named after this module so handlers/levels can be
# configured per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class APITier(str, Enum):
    """Subscription tier of an API key.

    Members are ``str`` subclasses; the value is the string embedded in
    generated keys and stored in serialized records.
    """

    FREE = "free"
    PRO = "pro"
    ENTERPRISE = "enterprise"
|
|
|
|
class RateLimitPeriod(str, Enum):
    """Length of a rate-limiting window.

    The value is used both in rate-limit counter keys and as the key of the
    serialized ``rate_limits`` mapping.
    """

    MINUTE = "minute"
    HOUR = "hour"
    DAY = "day"
    MONTH = "month"
|
|
|
|
@dataclass()
class APIKey:
    """Stored metadata for one API key (never the secret itself)."""

    id: str  # key id embedded in the full key string
    user_id: str  # owner of the key
    name: str  # human-readable label
    key_prefix: str  # Only store prefix, not full key
    key_hash: str  # Hash of full key for validation
    tier: APITier
    rate_limits: Dict[RateLimitPeriod, int]  # max requests per window
    is_active: bool  # False once revoked
    created_at: datetime
    last_used_at: Optional[datetime]  # None until first successful validation
    expires_at: Optional[datetime]  # None means the key never expires
    metadata: Dict[str, Any]  # free-form extra data
|
|
|
|
@dataclass()
class RateLimitStatus:
    """Outcome of a single rate-limit check for one window."""

    allowed: bool  # whether the request may proceed
    remaining: int  # requests left in the current window
    reset_time: datetime  # when the current window ends
    total_limit: int  # configured maximum for the window
    period: RateLimitPeriod  # which window was checked
|
|
|
|
@dataclass()
class APIUsage:
    """Aggregated request statistics for one API key."""

    total_requests: int
    successful_requests: int
    failed_requests: int
    last_request_at: Optional[datetime]  # None if no request was recorded
    daily_usage: Dict[str, int]  # "YYYY-MM-DD" -> request count
    monthly_usage: Dict[str, int]  # "YYYY-MM" -> request count
|
|
|
|
class APIKeyService:
    """
    API Key Management Service

    Handles API key generation, validation, rate limiting, and usage tracking.

    Redis layout used by this class:
      * ``api_key:{key_id}`` -- JSON-serialized APIKey record
      * ``rate_limit:{key_id}:{period}:{window}`` -- request counter for one window
      * ``usage:{key_id}`` -- JSON-serialized APIUsage statistics
    """

    # Minimum Redis TTL for stored key records (refreshed on every write).
    _KEY_RECORD_TTL_SECONDS = 86400 * 365
    # Redis TTL for the usage-statistics blob (refreshed on every write).
    _USAGE_TTL_SECONDS = 86400 * 30
    # Ids produced by the legacy secrets.token_urlsafe(8) generator are always
    # 11 characters long and may themselves contain "_".
    _LEGACY_KEY_ID_LENGTH = 11

    def __init__(self, redis_client: Optional[redis.Redis] = None):
        """Create the service.

        Args:
            redis_client: Client to use. When omitted, a connection to a local
                Redis is attempted, falling back to an in-memory mock.
        """
        self.redis_client = redis_client or self._create_redis_client()

        # Default rate limits by tier
        self.default_rate_limits = {
            APITier.FREE: {
                RateLimitPeriod.MINUTE: 10,
                RateLimitPeriod.HOUR: 100,
                RateLimitPeriod.DAY: 1000,
                RateLimitPeriod.MONTH: 10000,
            },
            APITier.PRO: {
                RateLimitPeriod.MINUTE: 100,
                RateLimitPeriod.HOUR: 2000,
                RateLimitPeriod.DAY: 25000,
                RateLimitPeriod.MONTH: 500000,
            },
            APITier.ENTERPRISE: {
                RateLimitPeriod.MINUTE: 1000,
                RateLimitPeriod.HOUR: 10000,
                RateLimitPeriod.DAY: 100000,
                RateLimitPeriod.MONTH: 2000000,
            },
        }

    def _create_redis_client(self) -> redis.Redis:
        """Create Redis client with fallback to mock.

        Connects to localhost:6379, db 1, with decoded (str) responses. Any
        failure, including the ping, falls back to the in-memory mock.
        """
        try:
            client = redis.Redis(
                host='localhost',
                port=6379,
                db=1,  # Use db 1 for API keys
                decode_responses=True,
            )
            # Test connection
            client.ping()
            logger.info("Connected to Redis for API key service")
            return client
        except Exception as e:
            logger.warning("Redis connection failed, using in-memory fallback: %s", e)
            return self._create_mock_redis()

    def _create_mock_redis(self):
        """Create an in-memory stand-in for the Redis commands this class uses.

        Development fallback only: state is per-process and not thread-safe.
        Follows Redis semantics for the supported commands. Expired keys behave
        as missing for every command, and SET without ``ex`` clears any TTL.
        TTL returns -2 for a missing key and -1 for a key without expiry.
        EXPIRE on a missing key is a no-op.
        """

        class MockRedis:
            def __init__(self):
                self.data = {}
                self.expiry = {}

            def _purge_if_expired(self, key):
                # Lazily drop a key whose deadline has passed.
                deadline = self.expiry.get(key)
                if deadline is not None and datetime.now() > deadline:
                    self.data.pop(key, None)
                    self.expiry.pop(key, None)

            def ping(self):
                return True

            def get(self, key):
                self._purge_if_expired(key)
                return self.data.get(key)

            def set(self, key, value, ex=None):
                self.data[key] = value
                if ex:
                    self.expiry[key] = datetime.now() + timedelta(seconds=ex)
                else:
                    self.expiry.pop(key, None)
                return True

            def incr(self, key):
                self._purge_if_expired(key)
                value = int(self.data.get(key, 0)) + 1
                self.data[key] = str(value)
                return value

            def expire(self, key, seconds):
                self._purge_if_expired(key)
                if key not in self.data:
                    return False
                self.expiry[key] = datetime.now() + timedelta(seconds=seconds)
                return True

            def ttl(self, key):
                self._purge_if_expired(key)
                if key not in self.data:
                    return -2
                if key in self.expiry:
                    remaining = (self.expiry[key] - datetime.now()).total_seconds()
                    return max(0, int(remaining))
                return -1

            def exists(self, key):
                self._purge_if_expired(key)
                return int(key in self.data)

            def delete(self, key):
                self._purge_if_expired(key)
                existed = key in self.data
                self.data.pop(key, None)
                self.expiry.pop(key, None)
                return int(existed)

        return MockRedis()

    def generate_api_key(self, user_id: str, name: str, tier: APITier = APITier.FREE,
                         expires_in_days: Optional[int] = None) -> tuple[str, APIKey]:
        """
        Generate and store a new API key.

        Args:
            user_id: Owner of the key.
            name: Human-readable label.
            tier: An APITier or its string value (e.g. "pro"). It selects the
                default rate limits.
            expires_in_days: Lifetime in days. None or 0 means the key never
                expires.

        Returns:
            tuple: (full_api_key, api_key_object). Only the SHA-256 hash of the
            full key is stored, so the full key is available only here.

        Raises:
            ValueError: If ``tier`` is unknown or ``expires_in_days`` is negative.
        """
        try:
            tier = APITier(tier)
            if expires_in_days is not None and expires_in_days < 0:
                raise ValueError(f"expires_in_days must be non-negative, got {expires_in_days}")

            # Generate secure API key. token_hex never emits "_", so the id can
            # be recovered unambiguously from the "_"-separated key format.
            # (token_urlsafe can emit "_", which broke validation.)
            key_id = secrets.token_hex(8)
            key_secret = secrets.token_urlsafe(32)

            # Format: ys_{tier}_{key_id}_{secret}
            full_key = f"ys_{tier.value}_{key_id}_{key_secret}"
            key_prefix = f"ys_{tier.value}_{key_id}"

            # Hash the full key for storage
            key_hash = hashlib.sha256(full_key.encode()).hexdigest()

            now = datetime.now()
            expires_at = now + timedelta(days=expires_in_days) if expires_in_days else None

            api_key = APIKey(
                id=key_id,
                user_id=user_id,
                name=name,
                key_prefix=key_prefix,
                key_hash=key_hash,
                tier=tier,
                rate_limits=self.default_rate_limits[tier].copy(),
                is_active=True,
                created_at=now,
                last_used_at=None,
                expires_at=expires_at,
                metadata={},
            )

            # Store in Redis/database
            self._store_api_key(api_key)

            logger.info("Generated API key %s for user %s", key_prefix, user_id)
            return full_key, api_key

        except Exception as e:
            logger.error("Failed to generate API key: %s", e)
            raise

    def _candidate_key_ids(self, api_key: str) -> List[str]:
        """Return the possible key ids embedded in ``ys_{tier}_{key_id}_{secret}``.

        Tier values contain no "_", so everything after the second "_" is
        ``{key_id}_{secret}``. Current ids contain no "_" and are found by
        splitting at the next "_". Legacy 11-character ids may contain "_", so
        an 11-character prefix followed by "_" is also tried. The caller's hash
        check decides which candidate, if any, is genuine.
        """
        parts = api_key.split("_", 2)
        if len(parts) < 3:
            return []
        remainder = parts[2]

        candidates = []
        head, sep, _ = remainder.partition("_")
        if sep and head:
            candidates.append(head)

        n = self._LEGACY_KEY_ID_LENGTH
        legacy_id = remainder[:n]
        if len(remainder) > n and remainder[n] == "_" and legacy_id not in candidates:
            candidates.append(legacy_id)
        return candidates

    def validate_api_key(self, api_key: str) -> Optional[APIKey]:
        """
        Validate API key and return key info

        Args:
            api_key: Full API key string (``ys_{tier}_{key_id}_{secret}``)

        Returns:
            The APIKey object if the key exists, its hash matches, and it is
            active and not expired. Otherwise None, including on storage
            errors. On success, ``last_used_at`` is updated on a best-effort
            basis.
        """
        try:
            if not isinstance(api_key, str) or not api_key.startswith("ys_"):
                return None

            key_hash = hashlib.sha256(api_key.encode()).hexdigest()

            stored_key = None
            for key_id in self._candidate_key_ids(api_key):
                candidate = self._get_api_key(key_id)
                # Constant-time comparison of the hashes.
                if candidate is not None and secrets.compare_digest(key_hash, candidate.key_hash):
                    stored_key = candidate
                    break

            if stored_key is None:
                return None

            # Check if active
            if not stored_key.is_active:
                return None

            # Check expiry
            if stored_key.expires_at and datetime.now() > stored_key.expires_at:
                return None

            # Update last used; a failed bookkeeping write must not reject a valid key.
            stored_key.last_used_at = datetime.now()
            try:
                self._update_api_key(stored_key)
            except Exception as e:
                logger.warning("Failed to record last use of API key %s: %s", stored_key.key_prefix, e)

            return stored_key

        except Exception as e:
            logger.error("API key validation failed: %s", e)
            return None

    @staticmethod
    def _rate_limit_window(period: RateLimitPeriod, now: datetime):
        """Return ``(window_key, reset_time, ttl_seconds)`` for the window containing ``now``."""
        if period == RateLimitPeriod.MINUTE:
            reset_time = now.replace(second=0, microsecond=0) + timedelta(minutes=1)
            return now.strftime("%Y-%m-%d-%H-%M"), reset_time, 60
        if period == RateLimitPeriod.HOUR:
            reset_time = now.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
            return now.strftime("%Y-%m-%d-%H"), reset_time, 3600
        if period == RateLimitPeriod.DAY:
            reset_time = now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
            return now.strftime("%Y-%m-%d"), reset_time, 86400
        # MONTH: day 1 + 32 days always lands in the following month.
        next_month = now.replace(day=1) + timedelta(days=32)
        reset_time = next_month.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        ttl_seconds = max(1, int((reset_time - now).total_seconds()))
        return now.strftime("%Y-%m"), reset_time, ttl_seconds

    def check_rate_limit(self, api_key: APIKey, period: RateLimitPeriod = RateLimitPeriod.HOUR) -> RateLimitStatus:
        """
        Count one request against a fixed-window rate limit and report the result.

        Args:
            api_key: Validated API key object
            period: Time period to check (a RateLimitPeriod or its string value)

        Returns:
            RateLimitStatus with current status. If the check fails, the request
            is allowed (fail-open) and the failure is logged.
        """
        try:
            period = RateLimitPeriod(period)
            limit = api_key.rate_limits.get(period, 0)

            time_key, reset_time, ttl_seconds = self._rate_limit_window(period, datetime.now())
            redis_key = f"rate_limit:{api_key.id}:{period.value}:{time_key}"

            # Increment first, then compare. A separate GET + INCR lets
            # concurrent requests all pass the check and overshoot the limit.
            new_usage = int(self.redis_client.incr(redis_key))
            if new_usage == 1:  # First request in this period
                self.redis_client.expire(redis_key, ttl_seconds)

            if new_usage > limit:
                return RateLimitStatus(
                    allowed=False,
                    remaining=0,
                    reset_time=reset_time,
                    total_limit=limit,
                    period=period,
                )

            return RateLimitStatus(
                allowed=True,
                remaining=max(0, limit - new_usage),
                reset_time=reset_time,
                total_limit=limit,
                period=period,
            )

        except Exception as e:
            logger.error("Rate limit check failed: %s", e)
            # Allow request on error, but log it
            return RateLimitStatus(
                allowed=True,
                remaining=999,
                reset_time=datetime.now() + timedelta(hours=1),
                total_limit=1000,
                period=period,
            )

    @staticmethod
    def _empty_usage() -> APIUsage:
        """Return zeroed usage statistics."""
        return APIUsage(
            total_requests=0,
            successful_requests=0,
            failed_requests=0,
            last_request_at=None,
            daily_usage={},
            monthly_usage={},
        )

    def _load_usage_stats(self, api_key: APIKey) -> APIUsage:
        """Read usage statistics from Redis; storage/decode errors propagate."""
        stats_data = self.redis_client.get(f"usage:{api_key.id}")
        if not stats_data:
            return self._empty_usage()

        stats = json.loads(stats_data)
        last_request_at = stats.get("last_request_at")
        return APIUsage(
            total_requests=stats.get("total_requests", 0),
            successful_requests=stats.get("successful_requests", 0),
            failed_requests=stats.get("failed_requests", 0),
            last_request_at=datetime.fromisoformat(last_request_at) if last_request_at else None,
            daily_usage=stats.get("daily_usage", {}),
            monthly_usage=stats.get("monthly_usage", {}),
        )

    def get_usage_stats(self, api_key: APIKey) -> APIUsage:
        """Get detailed usage statistics for API key (zeroed stats on error)."""
        try:
            return self._load_usage_stats(api_key)
        except Exception as e:
            logger.error("Failed to get usage stats: %s", e)
            return self._empty_usage()

    def record_request(self, api_key: APIKey, success: bool = True):
        """Record API request for usage tracking (best-effort; errors are logged).

        NOTE: this is a read-modify-write of one JSON blob and is not atomic.
        Concurrent calls for the same key can lose increments.
        """
        try:
            # Use the raising loader. Falling back to empty stats on a failed
            # read would overwrite the stored history with near-zero counts.
            stats = self._load_usage_stats(api_key)
            now = datetime.now()

            stats.total_requests += 1
            if success:
                stats.successful_requests += 1
            else:
                stats.failed_requests += 1
            stats.last_request_at = now

            # Update daily/monthly counters from the same timestamp
            today = now.strftime("%Y-%m-%d")
            this_month = now.strftime("%Y-%m")
            stats.daily_usage[today] = stats.daily_usage.get(today, 0) + 1
            stats.monthly_usage[this_month] = stats.monthly_usage.get(this_month, 0) + 1

            stats_dict = asdict(stats)
            stats_dict["last_request_at"] = now.isoformat()
            self.redis_client.set(
                f"usage:{api_key.id}", json.dumps(stats_dict), ex=self._USAGE_TTL_SECONDS
            )

        except Exception as e:
            logger.error("Failed to record request: %s", e)

    def revoke_api_key(self, key_id: str) -> bool:
        """Revoke (deactivate) an API key.

        Returns:
            True if the key was found and deactivated, False if it was not
            found or storage failed.
        """
        try:
            api_key = self._get_api_key(key_id)
            if not api_key:
                return False
            api_key.is_active = False
            self._update_api_key(api_key)
            logger.info("Revoked API key %s", key_id)
            return True

        except Exception as e:
            logger.error("Failed to revoke API key %s: %s", key_id, e)
            return False

    def list_api_keys(self, user_id: str) -> List[APIKey]:
        """List all API keys for a user.

        Not implemented: records are stored only under ``api_key:{key_id}``
        with no per-user index, so this always returns an empty list.
        """
        # TODO: maintain a per-user index (or query the database) to implement this.
        return []

    def _store_api_key(self, api_key: APIKey):
        """Serialize ``api_key`` to JSON and store it under ``api_key:{id}``.

        The Redis TTL is at least one year and is extended to cover
        ``expires_at``, so a long-lived key is not evicted before it expires.
        Storage errors are logged and re-raised.
        """
        try:
            key_data = asdict(api_key)
            # Convert datetime objects to ISO format
            key_data["created_at"] = key_data["created_at"].isoformat()
            if key_data["last_used_at"]:
                key_data["last_used_at"] = key_data["last_used_at"].isoformat()
            if key_data["expires_at"]:
                key_data["expires_at"] = key_data["expires_at"].isoformat()

            ttl_seconds = self._KEY_RECORD_TTL_SECONDS
            if api_key.expires_at is not None:
                seconds_until_expiry = int((api_key.expires_at - datetime.now()).total_seconds())
                ttl_seconds = max(ttl_seconds, seconds_until_expiry)

            redis_key = f"api_key:{api_key.id}"
            self.redis_client.set(redis_key, json.dumps(key_data), ex=ttl_seconds)

        except Exception as e:
            logger.error("Failed to store API key: %s", e)
            raise

    def _get_api_key(self, key_id: str) -> Optional[APIKey]:
        """Load and deserialize the record stored under ``api_key:{key_id}``.

        Returns None if the record is missing or cannot be read or decoded.
        """
        try:
            key_data = self.redis_client.get(f"api_key:{key_id}")
            if not key_data:
                return None

            data = json.loads(key_data)

            # Convert ISO format back to datetime
            data["created_at"] = datetime.fromisoformat(data["created_at"])
            if data["last_used_at"]:
                data["last_used_at"] = datetime.fromisoformat(data["last_used_at"])
            if data["expires_at"]:
                data["expires_at"] = datetime.fromisoformat(data["expires_at"])

            # Convert tier and rate_limits keys back to enums
            data["tier"] = APITier(data["tier"])
            data["rate_limits"] = {
                RateLimitPeriod(period_str): limit
                for period_str, limit in data["rate_limits"].items()
            }

            return APIKey(**data)

        except Exception as e:
            logger.error("Failed to get API key %s: %s", key_id, e)
            return None

    def _update_api_key(self, api_key: APIKey):
        """Update API key in Redis/database (overwrites the stored record)."""
        try:
            self._store_api_key(api_key)
        except Exception as e:
            logger.error("Failed to update API key: %s", e)
            raise
|
|
|
|
# Global service instance.
# NOTE: created at import time, so importing this module attempts a Redis
# connection (localhost:6379, db 1) and falls back to the in-memory mock if
# the connection or ping fails.
api_key_service = APIKeyService()