"""JWT Authentication Service."""
from typing import Optional, Dict, Any
from datetime import datetime, timedelta, timezone
import secrets
import hashlib
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
import sys
from pathlib import Path

# Make the grandparent directory importable so the `core` / `models`
# imports below resolve.
sys.path.insert(0, str(Path(__file__).parent.parent))

from core.config import settings, auth_settings
from models.user import User, RefreshToken
from core.database import get_db_context


# Password hashing context
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
    bcrypt__rounds=auth_settings.get_password_hash_rounds(),
)


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Drop-in replacement for ``datetime.utcnow()`` (deprecated since Python
    3.12) that yields the same value. It stays naive because the timestamps
    written below (``RefreshToken.expires_at``, ``User.last_login``) and the
    comparison in ``verify_refresh_token`` were all built on naive UTC values.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _hash_token(token: str) -> str:
    """Return the hex SHA-256 digest under which a refresh token is stored."""
    return hashlib.sha256(token.encode()).hexdigest()


class AuthService:
    """Service for authentication and authorization operations."""

    # ------------------------------------------------------------------
    # Passwords
    # ------------------------------------------------------------------

    @staticmethod
    def hash_password(password: str) -> str:
        """
        Hash a password using bcrypt.

        Args:
            password: Plain text password

        Returns:
            Hashed password
        """
        return pwd_context.hash(password)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """
        Verify a password against its hash.

        Args:
            plain_password: Plain text password
            hashed_password: Hashed password

        Returns:
            True if password matches, False otherwise (including when the
            stored hash cannot be identified/parsed by passlib).
        """
        try:
            return pwd_context.verify(plain_password, hashed_password)
        except ValueError:
            # passlib raises ValueError subclasses (e.g. UnknownHashError) for
            # an unrecognised/malformed stored hash; that is a failed check,
            # not a reason to surface a server error during login.
            return False

    # ------------------------------------------------------------------
    # Internal JWT helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _encode_token(claims: Dict[str, Any]) -> str:
        """Sign *claims* with the configured JWT secret and algorithm."""
        return jwt.encode(
            claims,
            auth_settings.get_jwt_secret_key(),
            algorithm=settings.JWT_ALGORITHM,
        )

    @staticmethod
    def _decode_token(token: str, expected_type: str) -> Optional[Dict[str, Any]]:
        """
        Decode a signed JWT and check its ``type`` claim.

        Args:
            token: Encoded JWT
            expected_type: Required value of the ``type`` claim

        Returns:
            The payload if the signature/claims validate and ``type`` matches,
            None otherwise.
        """
        try:
            payload = jwt.decode(
                token,
                auth_settings.get_jwt_secret_key(),
                algorithms=[settings.JWT_ALGORITHM],
            )
        except JWTError:
            return None
        # The ``type`` claim stops one kind of token (e.g. a password-reset
        # token) from being accepted where another kind is expected.
        if payload.get("type") != expected_type:
            return None
        return payload

    @staticmethod
    def _create_user_token(user_id: str, token_type: str, expires_delta: timedelta) -> str:
        """Create a signed token carrying ``user_id`` and ``type`` claims."""
        return AuthService._encode_token({
            "user_id": user_id,
            "type": token_type,
            "exp": _utcnow() + expires_delta,
        })

    # ------------------------------------------------------------------
    # Access tokens
    # ------------------------------------------------------------------

    @staticmethod
    def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a JWT access token.

        Args:
            data: Data to encode in the token. ``get_current_user`` reads the
                user ID from the ``sub`` claim, so callers should include it.
                Any ``exp`` / ``type`` keys are overwritten.
            expires_delta: Token lifetime; when None, defaults to
                ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

        Returns:
            Encoded JWT token
        """
        # `is None` rather than truthiness: an explicit timedelta(0) is a
        # real value and must not silently fall back to the default lifetime.
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

        to_encode = data.copy()  # never mutate the caller's dict
        to_encode.update({
            "exp": _utcnow() + expires_delta,
            "type": "access",
        })
        return AuthService._encode_token(to_encode)

    @staticmethod
    def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
        """
        Decode and verify a JWT access token.

        Args:
            token: JWT token

        Returns:
            Token payload if valid and of type ``access``, None otherwise
        """
        return AuthService._decode_token(token, "access")

    # ------------------------------------------------------------------
    # Refresh tokens (opaque, stored hashed in the database)
    # ------------------------------------------------------------------

    @staticmethod
    def create_refresh_token(user_id: str, db: Session) -> str:
        """
        Create a refresh token and store it in the database.

        Only the SHA-256 hash of the token is persisted; the plain token is
        returned once to the caller.

        Args:
            user_id: User ID
            db: Database session (committed by this method)

        Returns:
            Refresh token
        """
        token = secrets.token_urlsafe(32)
        refresh_token = RefreshToken(
            user_id=user_id,
            token_hash=_hash_token(token),
            expires_at=_utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
        )
        db.add(refresh_token)
        db.commit()
        return token

    @staticmethod
    def verify_refresh_token(token: str, db: Session) -> Optional[RefreshToken]:
        """
        Verify a refresh token.

        Args:
            token: Refresh token
            db: Database session

        Returns:
            RefreshToken object if it exists, is not revoked and has not
            expired; None otherwise
        """
        return db.query(RefreshToken).filter(
            RefreshToken.token_hash == _hash_token(token),
            RefreshToken.revoked == False,  # noqa: E712 - SQL expression
            RefreshToken.expires_at > _utcnow(),
        ).first()

    @staticmethod
    def revoke_refresh_token(token: str, db: Session) -> bool:
        """
        Revoke a refresh token.

        Args:
            token: Refresh token
            db: Database session (committed when a token is found)

        Returns:
            True if a matching token was found (and is now marked revoked,
            even if it already was), False if no token matched
        """
        refresh_token = db.query(RefreshToken).filter(
            RefreshToken.token_hash == _hash_token(token)
        ).first()

        if refresh_token is None:
            return False

        refresh_token.revoked = True
        db.commit()
        return True

    @staticmethod
    def revoke_all_user_tokens(user_id: str, db: Session) -> int:
        """
        Revoke all refresh tokens for a user.

        Args:
            user_id: User ID
            db: Database session (committed by this method)

        Returns:
            Number of tokens revoked (previously non-revoked tokens only)
        """
        count = db.query(RefreshToken).filter(
            RefreshToken.user_id == user_id,
            RefreshToken.revoked == False,  # noqa: E712 - SQL expression
        ).update({"revoked": True})
        db.commit()
        return count

    # ------------------------------------------------------------------
    # Email verification / password reset tokens
    # ------------------------------------------------------------------

    @staticmethod
    def create_email_verification_token(user_id: str) -> str:
        """
        Create an email verification token.

        Args:
            user_id: User ID

        Returns:
            Email verification token, valid for
            ``settings.EMAIL_VERIFICATION_EXPIRE_HOURS`` hours
        """
        return AuthService._create_user_token(
            user_id,
            "email_verification",
            timedelta(hours=settings.EMAIL_VERIFICATION_EXPIRE_HOURS),
        )

    @staticmethod
    def verify_email_token(token: str) -> Optional[str]:
        """
        Verify an email verification token.

        Args:
            token: Email verification token

        Returns:
            User ID if valid, None otherwise
        """
        payload = AuthService._decode_token(token, "email_verification")
        return payload.get("user_id") if payload is not None else None

    @staticmethod
    def create_password_reset_token(user_id: str) -> str:
        """
        Create a password reset token.

        NOTE(review): the token is stateless, so it stays reusable until it
        expires; single-use semantics would need server-side tracking.

        Args:
            user_id: User ID

        Returns:
            Password reset token, valid for
            ``settings.PASSWORD_RESET_EXPIRE_MINUTES`` minutes
        """
        return AuthService._create_user_token(
            user_id,
            "password_reset",
            timedelta(minutes=settings.PASSWORD_RESET_EXPIRE_MINUTES),
        )

    @staticmethod
    def verify_password_reset_token(token: str) -> Optional[str]:
        """
        Verify a password reset token.

        Args:
            token: Password reset token

        Returns:
            User ID if valid, None otherwise
        """
        payload = AuthService._decode_token(token, "password_reset")
        return payload.get("user_id") if payload is not None else None

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    @staticmethod
    def authenticate_user(email: str, password: str, db: Session) -> Optional[User]:
        """
        Authenticate a user by email and password.

        On success the user's ``last_login`` is set and the session committed.
        NOTE(review): unlike ``get_current_user``, this does not check
        ``is_active``; confirm callers handle inactive accounts.

        Args:
            email: User email
            password: User password
            db: Database session

        Returns:
            User object if authenticated, None otherwise
        """
        user = db.query(User).filter(User.email == email).first()
        if user is None:
            # Spend roughly the same time as a real bcrypt check so response
            # timing does not reveal whether the email is registered.
            pwd_context.dummy_verify()
            return None

        if not AuthService.verify_password(password, user.password_hash):
            return None

        user.last_login = _utcnow()
        db.commit()
        return user

    @staticmethod
    def get_current_user(token: str, db: Session) -> Optional[User]:
        """
        Get the current user from an access token.

        Args:
            token: JWT access token
            db: Database session

        Returns:
            Active User identified by the token's ``sub`` claim, or None if
            the token is invalid, lacks ``sub``, or no active user matches
        """
        payload = AuthService.decode_access_token(token)
        if not payload:
            return None

        user_id = payload.get("sub")  # Subject claim contains user ID
        if not user_id:
            return None

        return db.query(User).filter(
            User.id == user_id,
            User.is_active == True,  # noqa: E712 - SQL expression
        ).first()