import uuid
from datetime import datetime, timedelta, timezone
from typing import Any

from config import OIDC_ENABLED, ROMM_AUTH_SECRET_KEY, ROMM_BASE_URL
from decorators.auth import oauth
from exceptions.auth_exceptions import OAuthCredentialsException, UserDisabledException
from fastapi import HTTPException, status
from handler.auth.constants import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY, TokenPurpose
from handler.redis_handler import redis_client
from joserfc import jwt
from joserfc.errors import BadSignatureError, DecodeError
from joserfc.jwk import OctKey
from logger.formatter import CYAN
from logger.formatter import highlight as hl
from logger.logger import log
from passlib.context import CryptContext
from starlette.requests import HTTPConnection


class AuthHandler:
    """Password, session and one-time-token (reset / invite) authentication helpers."""

    def __init__(self) -> None:
        self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
        # Lifetimes in minutes; used both for the JWT `exp` claim and for the TTL
        # of the matching JTI marker stored in Redis.
        self.reset_passwd_token_expires_in_minutes = 10
        self.invite_link_token_expires_in_minutes = 10

    def verify_password(self, plain_password, hashed_password) -> bool:
        """Return whether `plain_password` matches `hashed_password`.

        A stored value that is not a usable hash (for example the random UUID
        placeholder that `OpenIDHandler` assigns to the users it creates) can
        never match a password, so it is reported as a mismatch instead of
        letting the hashing library's error escape to the caller.
        """
        try:
            return self.pwd_context.verify(plain_password, hashed_password)
        except ValueError:
            # passlib signals an unidentifiable hash with a ValueError subclass.
            return False

    def get_password_hash(self, password) -> str:
        """Hash `password` with the configured password context."""
        return self.pwd_context.hash(password)

    def authenticate_user(self, username: str, password: str):
        """Return the user matching `username` if `password` is correct, else None."""
        from handler.database import db_user_handler

        user = db_user_handler.get_user_by_username(username)
        if not user:
            return None

        if not self.verify_password(password, user.hashed_password):
            return None

        return user

    async def get_current_active_user_from_session(self, conn: HTTPConnection):
        """Resolve the enabled user referenced by the connection's session.

        Returns None when the session was not issued by "romm:auth", carries no
        "sub", or refers to a user that is missing or disabled. In the latter
        case the session is cleared as well.
        """
        from handler.database import db_user_handler

        if conn.session.get("iss") != "romm:auth":
            return None

        username = conn.session.get("sub")
        if not username:
            return None

        # Key exists therefore user is probably authenticated
        user = db_user_handler.get_user_by_username(username)
        if user is None or not user.enabled:
            conn.session.clear()
            log.error(
                "User '%s' %s",
                hl(username, color=CYAN),
                "not found" if user is None else "not enabled",
            )
            return None

        return user

    def generate_password_reset_token(self, user: Any) -> None:
        """Create a one-time password reset token for `user`.

        The signed token is not returned: the reset link that embeds it is
        written to the log, and the token's JTI is stored in Redis for the
        token lifetime so that it can be redeemed once.

        Args:
            user (Any): The user object; `username` and `email` are read.
        """
        now = datetime.now(timezone.utc)
        jti = str(uuid.uuid4())
        to_encode = {
            "sub": user.username,
            "email": user.email,
            "type": TokenPurpose.RESET,
            "iat": int(now.timestamp()),
            "exp": int(
                (
                    now + timedelta(minutes=self.reset_passwd_token_expires_in_minutes)
                ).timestamp()
            ),
            "jti": jti,
        }
        token = jwt.encode(
            {"alg": ALGORITHM}, to_encode, OctKey.import_key(ROMM_AUTH_SECRET_KEY)
        )
        log.info(
            f"Reset password link requested for {hl(user.username, color=CYAN)}. Reset link: {hl(f'{ROMM_BASE_URL}/reset-password?token={token}')}"
        )
        redis_client.setex(
            f"reset-jti:{jti}", self.reset_passwd_token_expires_in_minutes * 60, "valid"
        )

    def verify_password_reset_token(self, token: str) -> Any:
        """Verify a password reset token, consume it, and return its user.

        Args:
            token (str): The token to verify.

        Returns:
            Any: The user named by the token's "sub" claim.

        Raises:
            HTTPException: 400 if the token cannot be decoded or verified.
            HTTPException: 400 if the token is not a password reset token.
            HTTPException: 400 if "sub", "jti" or a numeric "exp" is missing.
            HTTPException: 400 if the token was already used or is unknown.
            HTTPException: 404 if the user is not found.
            HTTPException: 400 if the token has expired.
        """
        from handler.database import db_user_handler

        try:
            payload = jwt.decode(
                token,
                OctKey.import_key(ROMM_AUTH_SECRET_KEY),
                algorithms=[ALGORITHM],
            )
        except (BadSignatureError, DecodeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid token") from exc

        claims = payload.claims
        if claims.get("type") != TokenPurpose.RESET:
            raise HTTPException(status_code=400, detail="Invalid token purpose")

        username = claims.get("sub")
        jti = claims.get("jti")
        expires_at = claims.get("exp")
        # Reject a missing / non-numeric "exp" here, before the token is consumed,
        # so the expiry comparison below cannot fail with a TypeError.
        if not username or not jti or not isinstance(expires_at, (int, float)):
            raise HTTPException(status_code=400, detail="Invalid token payload")

        # Enforce one-time use with a single call: delete reports how many keys it
        # removed, so only one of several concurrent requests can consume the JTI.
        if not redis_client.delete(f"reset-jti:{jti}"):
            raise HTTPException(
                status_code=400, detail="This token has already been used or is invalid"
            )

        user = db_user_handler.get_user_by_username(username)
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        if datetime.now(timezone.utc).timestamp() > expires_at:
            raise HTTPException(status_code=400, detail="Token has expired")

        return user

    def set_user_new_password(self, user: Any, new_password: str) -> None:
        """
        Set the new password for the user.

        Args:
            user (Any): The user object.
            new_password (str): The new password to set.
        """
        from handler.database import db_user_handler

        db_user_handler.update_user(
            user.id, {"hashed_password": self.get_password_hash(new_password)}
        )

    def generate_invite_link_token(self, user: Any, role: str) -> str:
        """
        Generate an invite link token for the user.

        The invite link is also written to the log, and the token's JTI is
        stored in Redis for the token lifetime.

        Args:
            user (Any): The user object.
            role (str): The role of the user; stored upper-cased in the token.

        Returns:
            str: The generated invite link token.
        """
        now = datetime.now(timezone.utc)
        jti = str(uuid.uuid4())
        to_encode = {
            "sub": user.username,
            "type": TokenPurpose.INVITE,
            "role": role.upper(),
            "iat": int(now.timestamp()),
            "exp": int(
                (
                    now + timedelta(minutes=self.invite_link_token_expires_in_minutes)
                ).timestamp()
            ),
            "jti": jti,
        }
        token = jwt.encode(
            {"alg": ALGORITHM}, to_encode, OctKey.import_key(ROMM_AUTH_SECRET_KEY)
        )
        invite_link = f"{ROMM_BASE_URL}/register?token={token}"
        log.info(
            f"Invite link created by {hl(user.username, color=CYAN)}: {hl(invite_link)}"
        )
        redis_client.setex(
            f"invite-jti:{jti}", self.invite_link_token_expires_in_minutes * 60, "valid"
        )
        return token

    def verify_invite_link_token(self, token: str) -> tuple[str, str]:
        """
        Verify the invite link token.

        The token is not consumed here; see `invalidate_invite_link_token`.

        Args:
            token (str): The token to verify.

        Returns:
            tuple[str, str]: The JTI (JWT ID) of the token and the upper-cased
            role it carries ("USER" when the token has no "role" claim).

        Raises:
            HTTPException: 400 if the token cannot be decoded or verified, is
                not an invite token, or its JTI is not marked valid in Redis.
        """
        try:
            payload = jwt.decode(token, OctKey.import_key(ROMM_AUTH_SECRET_KEY))
        except (BadSignatureError, DecodeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail="Invalid token") from exc

        if payload.claims.get("type") != TokenPurpose.INVITE:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token type.",
            )

        jti = payload.claims.get("jti")
        role = payload.claims.get("role", "USER").upper()
        if not jti or redis_client.get(f"invite-jti:{jti}") != b"valid":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invite token has already been used or is invalid.",
            )

        return jti, role

    def invalidate_invite_link_token(self, jti: str) -> None:
        """
        Invalidate the invite link token.

        Args:
            jti (str): The JTI (JWT ID) of the token to invalidate.
        """
        redis_client.delete(f"invite-jti:{jti}")


class OAuthHandler:
    """Issues and resolves the application's own signed bearer tokens."""

    def __init__(self) -> None:
        pass

    def create_oauth_token(
        self, data: dict, expires_delta: timedelta = DEFAULT_OAUTH_TOKEN_EXPIRY
    ) -> str:
        """Return a signed token carrying `data` plus an "exp" claim.

        Args:
            data (dict): Claims to embed; copied, the caller's dict is untouched.
            expires_delta (timedelta): Lifetime added to the current UTC time.
        """
        to_encode = data.copy()
        to_encode["exp"] = datetime.now(timezone.utc) + expires_delta

        return jwt.encode(
            {"alg": ALGORITHM}, to_encode, OctKey.import_key(ROMM_AUTH_SECRET_KEY)
        )

    async def get_current_active_user_from_bearer_token(self, token: str):
        """Resolve the enabled user and the claims carried by a bearer token.

        Returns:
            A `(user, claims)` pair, or `(None, None)` when the token's issuer
            is not "romm:oauth".

        Raises:
            OAuthCredentialsException: If the token cannot be decoded or
                verified, has no "sub" claim, or names an unknown user.
            UserDisabledException: If the user is not enabled.
        """
        from handler.database import db_user_handler

        try:
            payload = jwt.decode(token, OctKey.import_key(ROMM_AUTH_SECRET_KEY))
        except (BadSignatureError, DecodeError, ValueError) as exc:
            raise OAuthCredentialsException from exc

        # NOTE(review): nothing in this method compares the "exp" claim with the
        # current time -- confirm that expiry is enforced by the callers.
        if payload.claims.get("iss") != "romm:oauth":
            return None, None

        username = payload.claims.get("sub")
        if username is None:
            raise OAuthCredentialsException

        user = db_user_handler.get_user_by_username(username)
        if user is None:
            raise OAuthCredentialsException

        if not user.enabled:
            raise UserDisabledException

        return user, payload.claims


class OpenIDHandler:
    """Maps an OpenID Connect token onto a local user, creating one if needed."""

    async def get_current_active_user_from_openid_token(self, token: Any):
        """Resolve, or create, the local user identified by the token's email.

        Returns:
            A `(user, userinfo)` pair, or `(None, None)` when OIDC is disabled.

        Raises:
            HTTPException: 400 if the token has no userinfo or no email, or if
                the email is not considered verified.
            UserDisabledException: If the resolved user is not enabled.
        """
        from handler.database import db_user_handler
        from models.user import Role, User

        if not OIDC_ENABLED:
            return None, None

        userinfo = token.get("userinfo")
        if userinfo is None:
            log.error("Userinfo is missing from token.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Userinfo is missing from token.",
            )

        email = userinfo.get("email")
        if email is None:
            log.error("Email is missing from token.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email is missing from token.",
            )

        metadata = await oauth.openid.load_server_metadata()
        claims_supported = metadata.get("claims_supported")
        is_email_verified = userinfo.get("email_verified", None)

        # Fail if email is explicitly unverified, or `email_verified` is a supported
        # claim and email is not explicitly verified.
        if is_email_verified is False or (
            claims_supported
            and "email_verified" in claims_supported
            and is_email_verified is not True
        ):
            log.error("Email is not verified.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email is not verified.",
            )

        # NOTE(review): this is None when the provider omits `preferred_username`,
        # and is then used as the new user's username -- confirm that is acceptable.
        preferred_username = userinfo.get("preferred_username")

        user = db_user_handler.get_user_by_email(email)
        if user is None:
            log.info(
                "User with email '%s' not found, creating new user",
                hl(email, color=CYAN),
            )
            # The stored value is a random UUID, not a hash of any password.
            new_user = User(
                username=preferred_username,
                hashed_password=str(uuid.uuid4()),
                email=email,
                enabled=True,
                role=Role.VIEWER,
            )
            user = db_user_handler.add_user(new_user)

        if not user.enabled:
            raise UserDisabledException

        log.info("User successfully authenticated: %s", hl(email, color=CYAN))
        return user, userinfo