"""Authentication helpers: single-user password check, signed session cookies,
CSRF tokens, and an in-memory per-IP login rate limiter."""

import hmac
import secrets
import threading
import time
from typing import Optional

from argon2 import PasswordHasher
from argon2.exceptions import InvalidHash, VerificationError, VerifyMismatchError
from fastapi import HTTPException, Request, Response, status
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer

from settings import (
    AUTH_PASSWORD_HASH,
    AUTH_USERNAME,
    COOKIE_SECURE,
    LOGIN_RATE_LIMIT,
    LOGIN_RATE_WINDOW_SECONDS,
    SESSION_COOKIE_NAME,
    SESSION_MAX_AGE_SECONDS,
    SESSION_SECRET,
)

_hasher = PasswordHasher()
# Same secret, different salts: a session token's signature is not valid as a
# CSRF token and vice versa.
_serializer = URLSafeTimedSerializer(SESSION_SECRET, salt="session")
_csrf_serializer = URLSafeTimedSerializer(SESSION_SECRET, salt="csrf")


def _ct_equal(a: Optional[str], b: Optional[str]) -> bool:
    """Compare two strings in constant time, treating ``None`` as ``""``.

    Both sides are UTF-8 encoded first, because ``hmac.compare_digest`` raises
    ``TypeError`` for ``str`` arguments containing non-ASCII characters. Without
    the encoding, attacker-supplied input would turn into a server error.
    """
    return hmac.compare_digest((a or "").encode("utf-8"), (b or "").encode("utf-8"))


# ---------- Password & session ----------


def _password_matches(password: str) -> bool:
    """Return True if *password* verifies against ``AUTH_PASSWORD_HASH``.

    Any argon2 verification failure counts as "no match":
    - ``VerificationError`` is the base of ``VerifyMismatchError`` (wrong password).
    - ``InvalidHash`` means the configured hash is malformed.
    """
    try:
        return _hasher.verify(AUTH_PASSWORD_HASH, password)
    except (VerificationError, InvalidHash):
        return False


def verify_password(username: str, password: str) -> bool:
    """Return True only if both *username* and *password* match the configured credentials.

    The username is compared in constant time. The argon2 hash is always
    computed, even when the username is wrong, so response timing does not
    reveal whether the username was correct. A ``None`` password is checked
    as ``""``.
    """
    username_ok = _ct_equal(username, AUTH_USERNAME)
    password_ok = _password_matches(password or "")
    return username_ok and password_ok


def issue_session_cookie(response: Response, username: str) -> None:
    """Set a signed session cookie for *username* on *response*.

    The cookie holds ``{"u": username, "iat": <unix seconds>}``, signed with the
    session serializer. It is HttpOnly, SameSite=Strict, and Secure when
    ``COOKIE_SECURE`` is set.
    """
    token = _serializer.dumps({"u": username, "iat": int(time.time())})
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=token,
        max_age=SESSION_MAX_AGE_SECONDS,
        httponly=True,
        secure=COOKIE_SECURE,
        samesite="strict",
        path="/",
    )


def clear_session_cookie(response: Response) -> None:
    """Instruct the client to delete the session cookie.

    The attributes mirror those used by ``issue_session_cookie``. Note that this
    only removes the client's copy: a previously issued token stays valid until
    its max age elapses.
    """
    response.delete_cookie(
        SESSION_COOKIE_NAME,
        path="/",
        secure=COOKIE_SECURE,
        httponly=True,
        samesite="strict",
    )


def current_user(request: Request) -> Optional[str]:
    """Return the username from a valid, unexpired session cookie.

    Returns ``None`` if the cookie is absent, tampered with, or older than
    ``SESSION_MAX_AGE_SECONDS``.
    """
    token = request.cookies.get(SESSION_COOKIE_NAME)
    if not token:
        return None
    try:
        data = _serializer.loads(token, max_age=SESSION_MAX_AGE_SECONDS)
    except (BadSignature, SignatureExpired):
        return None
    return data.get("u")


def require_user(request: Request) -> str:
    """Return the logged-in username.

    Raises:
        HTTPException: 401 if there is no valid session.
    """
    user = current_user(request)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="login required")
    return user


# ---------- CSRF (signed token bound to the session's username) ----------
# NOTE(review): the token carries only the username, not a per-session value.
# Any unexpired token issued for the same user is therefore accepted by any of
# that user's sessions.


def issue_csrf_token(username: str) -> str:
    """Return a signed CSRF token for *username*."""
    return _csrf_serializer.dumps({"u": username})


def verify_csrf(request: Request, submitted: str) -> bool:
    """Return True if *submitted* is a valid CSRF token for the current session user.

    The token must:
    - carry a valid signature;
    - be younger than ``SESSION_MAX_AGE_SECONDS``;
    - name the same user as the session cookie.
    """
    user = current_user(request)
    if not user or not submitted:
        return False
    try:
        data = _csrf_serializer.loads(submitted, max_age=SESSION_MAX_AGE_SECONDS)
    except (BadSignature, SignatureExpired):
        return False
    return _ct_equal(str(data.get("u", "")), user)


def require_csrf(request: Request, token: str) -> None:
    """Raise ``HTTPException`` (403) unless *token* passes ``verify_csrf``."""
    if not verify_csrf(request, token):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="bad csrf")


# ---------- Login rate limiting (in-memory, per IP) ----------

_rate_lock = threading.Lock()
# ip -> timestamps (time.monotonic()) of recorded attempts, oldest first.
_rate_log: dict[str, list[float]] = {}
# Once this many IPs are tracked, stale entries are pruned on each allowed attempt.
_RATE_LOG_CLEANUP_THRESHOLD = 1024


def rate_limit_login(ip: str) -> bool:
    """Record a login attempt from *ip*; return True if it is allowed.

    At most ``LOGIN_RATE_LIMIT`` attempts are allowed per
    ``LOGIN_RATE_WINDOW_SECONDS`` sliding window. Denied attempts are not
    recorded, so retrying while blocked does not extend the block.

    ``time.monotonic()`` is used so that wall-clock adjustments cannot shrink or
    stretch the window. The timestamps are only compared within this process.
    """
    now = time.monotonic()
    cutoff = now - LOGIN_RATE_WINDOW_SECONDS
    with _rate_lock:
        attempts = [t for t in _rate_log.get(ip, []) if t > cutoff]
        if len(attempts) >= LOGIN_RATE_LIMIT:
            _rate_log[ip] = attempts
            return False
        attempts.append(now)
        _rate_log[ip] = attempts
        # Opportunistic cleanup. The staleness test mirrors the `t > cutoff`
        # filter above: an IP whose newest attempt is at or before `cutoff`
        # has nothing left in its window.
        if len(_rate_log) > _RATE_LOG_CLEANUP_THRESHOLD:
            stale = [k for k, ts in _rate_log.items() if not ts or ts[-1] <= cutoff]
            for k in stale:
                _rate_log.pop(k, None)
        return True


def constant_time_compare(a: str, b: str) -> bool:
    """Compare *a* and *b* in constant time.

    ``None`` is treated as ``""``. Non-ASCII strings are supported because both
    sides are UTF-8 encoded before comparison.
    """
    return _ct_equal(a, b)