383 lines
10 KiB
Python
383 lines
10 KiB
Python
"""JWT Authentication Service."""
|
|
|
|
from typing import Optional, Dict, Any
|
|
from datetime import datetime, timedelta
|
|
import secrets
|
|
import hashlib
|
|
from jose import JWTError, jwt
|
|
from passlib.context import CryptContext
|
|
from sqlalchemy.orm import Session
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from core.config import settings, auth_settings
|
|
from models.user import User, RefreshToken
|
|
from core.database import get_db_context
|
|
|
|
|
|
# Password hashing context. bcrypt is the only configured scheme;
# deprecated="auto" marks every scheme other than the default as deprecated.
# The bcrypt cost factor is read from the auth settings at import time.
pwd_context = CryptContext(
    bcrypt__rounds=auth_settings.get_password_hash_rounds(),
    deprecated="auto",
    schemes=["bcrypt"],
)
|
|
|
|
|
|
class AuthService:
    """Service for authentication and authorization operations.

    Every method is a ``@staticmethod``. Methods that modify the database
    commit the supplied SQLAlchemy ``Session`` themselves.

    All JWTs are signed with ``auth_settings.get_jwt_secret_key()`` using
    ``settings.JWT_ALGORITHM`` and carry a ``"type"`` claim (``"access"``,
    ``"email_verification"`` or ``"password_reset"``). Each verifier rejects
    tokens of any other type, so e.g. a password-reset token cannot be used as
    an access token. Timestamps are naive UTC values from ``datetime.utcnow()``
    (NOTE(review): presumably to match naive DB columns such as
    ``RefreshToken.expires_at`` -- confirm before switching to aware datetimes).
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _hash_token(token: str) -> str:
        """Return the hex SHA-256 digest of *token*, the form stored in the DB."""
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    def _encode_token(
        claims: Dict[str, Any], token_type: str, expires_delta: timedelta
    ) -> str:
        """
        Sign *claims* as a JWT of the given type, expiring after *expires_delta*.

        The ``exp`` and ``type`` claims are added last, so they override any
        same-named keys in *claims*. *claims* itself is not mutated.
        """
        payload = {
            **claims,
            "exp": datetime.utcnow() + expires_delta,
            "type": token_type,
        }
        return jwt.encode(
            payload,
            auth_settings.get_jwt_secret_key(),
            algorithm=settings.JWT_ALGORITHM,
        )

    @staticmethod
    def _decode_token(token: str, expected_type: str) -> Optional[Dict[str, Any]]:
        """
        Decode and verify a JWT, returning its payload only if its type matches.

        Returns:
            The payload dict, or None if the token is invalid (bad signature,
            malformed, expired) or its ``type`` claim is not *expected_type*.
        """
        try:
            payload = jwt.decode(
                token,
                auth_settings.get_jwt_secret_key(),
                algorithms=[settings.JWT_ALGORITHM],
            )
        except JWTError:
            return None

        if payload.get("type") != expected_type:
            return None

        return payload

    # ------------------------------------------------------------------
    # Passwords
    # ------------------------------------------------------------------

    @staticmethod
    def hash_password(password: str) -> str:
        """
        Hash a password using the module-level bcrypt context.

        Args:
            password: Plain text password

        Returns:
            Hashed password, suitable for storing and later passing to
            verify_password
        """
        return pwd_context.hash(password)

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """
        Verify a password against its hash.

        Args:
            plain_password: Plain text password
            hashed_password: Hashed password

        Returns:
            True if password matches, False otherwise (including when the
            stored hash is malformed or uses an unrecognised scheme)
        """
        try:
            return pwd_context.verify(plain_password, hashed_password)
        except ValueError:
            # passlib raises ValueError (e.g. UnknownHashError) for a corrupt
            # or unrecognised stored hash; fail closed instead of propagating.
            return False

    # ------------------------------------------------------------------
    # Access tokens
    # ------------------------------------------------------------------

    @staticmethod
    def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a JWT access token.

        Args:
            data: Claims to encode in the token (e.g. ``"sub"`` with the user
                ID, which get_current_user reads). Not mutated.
            expires_delta: Token lifetime. When None, uses
                ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``. An explicit
                ``timedelta(0)`` is honoured rather than replaced by the default.

        Returns:
            Encoded JWT token with ``type="access"``
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

        return AuthService._encode_token(data, "access", expires_delta)

    @staticmethod
    def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
        """
        Decode and verify a JWT access token.

        Args:
            token: JWT token

        Returns:
            Token payload if valid and of type ``"access"``, None otherwise
        """
        return AuthService._decode_token(token, "access")

    # ------------------------------------------------------------------
    # Refresh tokens (opaque, stored hashed in the database)
    # ------------------------------------------------------------------

    @staticmethod
    def create_refresh_token(user_id: str, db: Session) -> str:
        """
        Create a refresh token and store its hash in the database.

        Only the SHA-256 digest is persisted, so the plain token is returned
        exactly once, here.

        Args:
            user_id: User ID
            db: Database session (committed by this method)

        Returns:
            The plain refresh token
        """
        token = secrets.token_urlsafe(32)
        expires_at = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)

        db.add(
            RefreshToken(
                user_id=user_id,
                token_hash=AuthService._hash_token(token),
                expires_at=expires_at,
            )
        )
        db.commit()

        return token

    @staticmethod
    def verify_refresh_token(token: str, db: Session) -> Optional[RefreshToken]:
        """
        Verify a refresh token.

        Args:
            token: Refresh token
            db: Database session (not committed)

        Returns:
            RefreshToken row if it exists, is not revoked and has not expired;
            None otherwise
        """
        return (
            db.query(RefreshToken)
            .filter(
                RefreshToken.token_hash == AuthService._hash_token(token),
                # `== False` builds a SQL expression; `is False` would not.
                RefreshToken.revoked == False,  # noqa: E712
                RefreshToken.expires_at > datetime.utcnow(),
            )
            .first()
        )

    @staticmethod
    def revoke_refresh_token(token: str, db: Session) -> bool:
        """
        Revoke a refresh token.

        Args:
            token: Refresh token
            db: Database session (committed when a matching row is found)

        Returns:
            True if a matching token row was found (even if it was already
            revoked or expired), False otherwise
        """
        refresh_token = (
            db.query(RefreshToken)
            .filter(RefreshToken.token_hash == AuthService._hash_token(token))
            .first()
        )

        if refresh_token is None:
            return False

        refresh_token.revoked = True
        db.commit()
        return True

    @staticmethod
    def revoke_all_user_tokens(user_id: str, db: Session) -> int:
        """
        Revoke all not-yet-revoked refresh tokens for a user.

        Args:
            user_id: User ID
            db: Database session (committed by this method)

        Returns:
            Number of tokens revoked, as reported by the bulk UPDATE
        """
        count = (
            db.query(RefreshToken)
            .filter(
                RefreshToken.user_id == user_id,
                RefreshToken.revoked == False,  # noqa: E712
            )
            .update({"revoked": True})
        )

        db.commit()
        return count

    # ------------------------------------------------------------------
    # Email verification / password reset tokens (stateless JWTs)
    # ------------------------------------------------------------------

    @staticmethod
    def create_email_verification_token(user_id: str) -> str:
        """
        Create an email verification token.

        Args:
            user_id: User ID

        Returns:
            JWT with ``user_id`` and ``type="email_verification"``, expiring
            after ``settings.EMAIL_VERIFICATION_EXPIRE_HOURS``
        """
        return AuthService._encode_token(
            {"user_id": user_id},
            "email_verification",
            timedelta(hours=settings.EMAIL_VERIFICATION_EXPIRE_HOURS),
        )

    @staticmethod
    def verify_email_token(token: str) -> Optional[str]:
        """
        Verify an email verification token.

        Args:
            token: Email verification token

        Returns:
            User ID if valid, None otherwise
        """
        payload = AuthService._decode_token(token, "email_verification")
        if payload is None:
            return None
        return payload.get("user_id")

    @staticmethod
    def create_password_reset_token(user_id: str) -> str:
        """
        Create a password reset token.

        Args:
            user_id: User ID

        Returns:
            JWT with ``user_id`` and ``type="password_reset"``, expiring after
            ``settings.PASSWORD_RESET_EXPIRE_MINUTES``
        """
        return AuthService._encode_token(
            {"user_id": user_id},
            "password_reset",
            timedelta(minutes=settings.PASSWORD_RESET_EXPIRE_MINUTES),
        )

    @staticmethod
    def verify_password_reset_token(token: str) -> Optional[str]:
        """
        Verify a password reset token.

        The token is stateless: nothing records that it was used, so it stays
        valid until it expires. NOTE(review): if single-use reset links are
        required, bind the token to the current password hash or store it.

        Args:
            token: Password reset token

        Returns:
            User ID if valid, None otherwise
        """
        payload = AuthService._decode_token(token, "password_reset")
        if payload is None:
            return None
        return payload.get("user_id")

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    @staticmethod
    def authenticate_user(email: str, password: str, db: Session) -> Optional[User]:
        """
        Authenticate a user by email and password.

        On success, updates ``user.last_login`` and commits the session.
        Does not check ``is_active``; callers decide how to treat inactive users.

        Args:
            email: User email
            password: User password
            db: Database session

        Returns:
            User object if authenticated, None otherwise
        """
        user = db.query(User).filter(User.email == email).first()

        if user is None:
            # Burn about as much time as a real bcrypt check so response timing
            # does not reveal whether the email is registered.
            pwd_context.dummy_verify()
            return None

        if not AuthService.verify_password(password, user.password_hash):
            return None

        user.last_login = datetime.utcnow()
        db.commit()

        return user

    @staticmethod
    def get_current_user(token: str, db: Session) -> Optional[User]:
        """
        Get the current user from an access token.

        Args:
            token: JWT access token
            db: Database session (not committed)

        Returns:
            The active User named by the token's ``"sub"`` claim, or None if the
            token is invalid, lacks ``"sub"``, or the user is missing/inactive
        """
        payload = AuthService.decode_access_token(token)
        if payload is None:
            return None

        # Subject claim holds the user ID (presumably set by callers of
        # create_access_token).
        user_id = payload.get("sub")
        if not user_id:
            return None

        return (
            db.query(User)
            .filter(
                User.id == user_id,
                User.is_active == True,  # noqa: E712
            )
            .first()
        )