youtube-summarizer/backend/services/auth_service.py

383 lines
10 KiB
Python

"""JWT Authentication Service."""
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import secrets
import hashlib
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from core.config import settings, auth_settings
from models.user import User, RefreshToken
from core.database import get_db_context
# Module-wide passlib context used by AuthService for hashing and verifying
# passwords. Only bcrypt is registered; deprecated="auto" asks passlib to treat
# every scheme other than the default one as deprecated. The bcrypt cost factor
# is read from auth_settings once, when this module is imported.
pwd_context = CryptContext(
    bcrypt__rounds=auth_settings.get_password_hash_rounds(),
    deprecated="auto",
    schemes=["bcrypt"],
)
class AuthService:
    """Service for authentication and authorization operations.

    Every method is a ``@staticmethod``; the class only serves as a namespace.

    Access, email-verification and password-reset tokens are JWTs signed with
    ``auth_settings.get_jwt_secret_key()`` using ``settings.JWT_ALGORITHM``.
    They share the same key and are told apart only by their ``"type"`` claim.

    Refresh tokens are opaque random strings. Only their SHA-256 digest is
    stored, in the ``RefreshToken`` table.
    """

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _hash_token(token: str) -> str:
        """Return the hex SHA-256 digest under which a refresh token is stored.

        The raw token is handed to the client and never persisted. Lookups
        and revocations therefore always go through this digest.
        """
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    def _encode_token(payload: Dict[str, Any]) -> str:
        """Sign *payload* with the configured JWT secret and algorithm."""
        return jwt.encode(
            payload,
            auth_settings.get_jwt_secret_key(),
            algorithm=settings.JWT_ALGORITHM,
        )

    @staticmethod
    def _decode_token(token: str, expected_type: str) -> Optional[Dict[str, Any]]:
        """Decode a JWT and return its claims if its ``"type"`` claim matches.

        Args:
            token: Encoded JWT.
            expected_type: Required value of the ``"type"`` claim.

        Returns:
            The decoded claims. Returns None if ``jwt.decode`` raises a
            ``JWTError`` (e.g. bad signature, expired token). Also returns
            None if the token is of a different type, so one kind of token
            cannot be used in place of another.
        """
        try:
            payload = jwt.decode(
                token,
                auth_settings.get_jwt_secret_key(),
                algorithms=[settings.JWT_ALGORITHM],
            )
        except JWTError:
            return None
        if payload.get("type") != expected_type:
            return None
        return payload

    # ------------------------------------------------------------------ #
    # Passwords
    # ------------------------------------------------------------------ #

    @staticmethod
    def hash_password(password: str) -> str:
        """
        Hash a password using bcrypt.

        Args:
            password: Plain text password

        Returns:
            Hashed password
        """
        return pwd_context.hash(password)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """
        Verify a password against its hash.

        Args:
            plain_password: Plain text password
            hashed_password: Hashed password

        Returns:
            True if password matches, False otherwise
        """
        return pwd_context.verify(plain_password, hashed_password)

    # ------------------------------------------------------------------ #
    # Access tokens
    # ------------------------------------------------------------------ #

    @staticmethod
    def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a JWT access token.

        Args:
            data: Claims to encode in the token. ``get_current_user`` reads
                the user ID from the ``"sub"`` claim. Any ``"exp"`` or
                ``"type"`` key in *data* is overwritten.
            expires_delta: Token lifetime. When None, the default of
                ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` is used. An explicit
                ``timedelta(0)`` is honoured rather than treated as "unset".

        Returns:
            Encoded JWT token
        """
        # Compare against None: a zero timedelta is falsy and must not fall
        # back to the default lifetime.
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
        # "exp"/"type" are placed after *data* so they always win.
        to_encode = {
            **data,
            "exp": datetime.utcnow() + expires_delta,
            "type": "access",
        }
        return AuthService._encode_token(to_encode)

    @staticmethod
    def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
        """
        Decode and verify a JWT access token.

        Args:
            token: JWT token

        Returns:
            Token payload if valid and of type "access", None otherwise
        """
        return AuthService._decode_token(token, "access")

    # ------------------------------------------------------------------ #
    # Refresh tokens
    # ------------------------------------------------------------------ #

    @staticmethod
    def create_refresh_token(user_id: str, db: Session) -> str:
        """
        Create a refresh token and store its hash in the database.

        Args:
            user_id: User ID
            db: Database session (committed by this call)

        Returns:
            The raw refresh token. Only its SHA-256 digest is stored.
        """
        # 32 random bytes from the CSPRNG, base64url-encoded.
        token = secrets.token_urlsafe(32)
        expires_at = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
        db.add(
            RefreshToken(
                user_id=user_id,
                token_hash=AuthService._hash_token(token),
                expires_at=expires_at,
            )
        )
        db.commit()
        return token

    @staticmethod
    def verify_refresh_token(token: str, db: Session) -> Optional[RefreshToken]:
        """
        Verify a refresh token.

        Args:
            token: Raw refresh token
            db: Database session

        Returns:
            The matching RefreshToken row, or None. None is returned if the
            token is unknown, revoked, or past ``expires_at``.
        """
        return db.query(RefreshToken).filter(
            RefreshToken.token_hash == AuthService._hash_token(token),
            # SQLAlchemy column comparison needs ==, not `is`.
            RefreshToken.revoked == False,  # noqa: E712
            RefreshToken.expires_at > datetime.utcnow(),
        ).first()

    @staticmethod
    def revoke_refresh_token(token: str, db: Session) -> bool:
        """
        Revoke a refresh token.

        Args:
            token: Raw refresh token
            db: Database session (committed when a row is found)

        Returns:
            True if a matching row was found and marked revoked, otherwise
            False. This includes rows that were already revoked or expired.
        """
        refresh_token = db.query(RefreshToken).filter(
            RefreshToken.token_hash == AuthService._hash_token(token)
        ).first()
        if refresh_token is None:
            return False
        refresh_token.revoked = True
        db.commit()
        return True

    @staticmethod
    def revoke_all_user_tokens(user_id: str, db: Session) -> int:
        """
        Revoke all refresh tokens for a user.

        Args:
            user_id: User ID
            db: Database session (committed by this call)

        Returns:
            Number of rows updated, i.e. tokens that were not yet revoked
        """
        count = db.query(RefreshToken).filter(
            RefreshToken.user_id == user_id,
            RefreshToken.revoked == False,  # noqa: E712
        ).update({"revoked": True})
        db.commit()
        return count

    # ------------------------------------------------------------------ #
    # Email verification / password reset tokens
    # ------------------------------------------------------------------ #

    @staticmethod
    def create_email_verification_token(user_id: str) -> str:
        """
        Create an email verification token.

        Args:
            user_id: User ID, stored in the "user_id" claim

        Returns:
            JWT valid for settings.EMAIL_VERIFICATION_EXPIRE_HOURS hours
        """
        expires_at = datetime.utcnow() + timedelta(hours=settings.EMAIL_VERIFICATION_EXPIRE_HOURS)
        return AuthService._encode_token(
            {"user_id": user_id, "type": "email_verification", "exp": expires_at}
        )

    @staticmethod
    def verify_email_token(token: str) -> Optional[str]:
        """
        Verify an email verification token.

        Args:
            token: Email verification token

        Returns:
            User ID if valid, None otherwise
        """
        payload = AuthService._decode_token(token, "email_verification")
        return payload.get("user_id") if payload is not None else None

    @staticmethod
    def create_password_reset_token(user_id: str) -> str:
        """
        Create a password reset token.

        NOTE(review): nothing here records token use. A reset token therefore
        stays valid until it expires, even after the password has been
        changed. Confirm whether callers enforce single use.

        Args:
            user_id: User ID, stored in the "user_id" claim

        Returns:
            JWT valid for settings.PASSWORD_RESET_EXPIRE_MINUTES minutes
        """
        expires_at = datetime.utcnow() + timedelta(minutes=settings.PASSWORD_RESET_EXPIRE_MINUTES)
        return AuthService._encode_token(
            {"user_id": user_id, "type": "password_reset", "exp": expires_at}
        )

    @staticmethod
    def verify_password_reset_token(token: str) -> Optional[str]:
        """
        Verify a password reset token.

        Args:
            token: Password reset token

        Returns:
            User ID if valid, None otherwise
        """
        payload = AuthService._decode_token(token, "password_reset")
        return payload.get("user_id") if payload is not None else None

    # ------------------------------------------------------------------ #
    # Users
    # ------------------------------------------------------------------ #

    @staticmethod
    def authenticate_user(email: str, password: str, db: Session) -> Optional[User]:
        """
        Authenticate a user by email and password.

        On success, ``user.last_login`` is updated and the session committed.

        NOTE(review): unlike get_current_user, this does not check
        ``User.is_active``. Confirm that callers reject inactive accounts.

        Args:
            email: User email
            password: User password
            db: Database session

        Returns:
            User object if authenticated, None otherwise
        """
        user = db.query(User).filter(User.email == email).first()
        if user is None:
            # Spend roughly the same time as a real bcrypt check so that an
            # unknown email is not distinguishable from a wrong password by
            # response time (account enumeration).
            pwd_context.dummy_verify()
            return None
        if not AuthService.verify_password(password, user.password_hash):
            return None
        user.last_login = datetime.utcnow()
        db.commit()
        return user

    @staticmethod
    def get_current_user(token: str, db: Session) -> Optional[User]:
        """
        Get the current user from an access token.

        Args:
            token: JWT access token
            db: Database session

        Returns:
            The active User named by the token's "sub" claim. Returns None if
            the token is invalid, has no subject, or the user is missing or
            inactive.
        """
        payload = AuthService.decode_access_token(token)
        if not payload:
            return None
        user_id = payload.get("sub")  # Subject claim contains user ID
        if not user_id:
            return None
        return db.query(User).filter(
            User.id == user_id,
            User.is_active == True,  # noqa: E712
        ).first()