✨ New Features: - AI-powered habit creation with natural language processing - HuggingFace transformers integration for sentiment analysis (tracked via Git LFS) - Advanced predictive analytics and behavioral insights - Voice & image input capabilities for hands-free habit tracking - Real-time notifications and community features - Plugin system with extensible architecture 🔧 Technical Improvements: - Comprehensive FastAPI backend with 30+ endpoints - React frontend with PWA capabilities - Advanced authentication with 2FA support - RBAC authorization system - Comprehensive security features (CSRF, rate limiting, audit logging) - Database migrations and health monitoring - Docker containerization support - Git LFS configured for large AI model files (2+ GB) 📚 Documentation & DevOps: - Complete deployment guides for multiple platforms - Professional README with feature highlights - GitHub Actions CI/CD workflows - Comprehensive API documentation - Security audit roadmap and compliance framework - Setup scripts for development environment 🧪 Testing & Quality: - Comprehensive test suite with 20+ test modules - Setup verification scripts - Working development environment with both backend and frontend - Health checks and monitoring systems 🌟 Ready for: - Portfolio showcasing - Community contributions - Production deployment - Professional presentation
425 lines
16 KiB
Python
425 lines
16 KiB
Python
import os
|
|
import time
|
|
from typing import Optional
|
|
|
|
import os
|
|
import time
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, Request, Depends, HTTPException
|
|
from authlib.integrations.starlette_client import OAuth
|
|
from sqlalchemy.orm import Session
|
|
import requests
|
|
|
|
import models
|
|
from db import get_db
|
|
from transaction import transactional
|
|
|
|
# Router that carries every OAuth/OIDC endpoint in this module, plus the
# Authlib client registry the endpoints look providers up in.
router = APIRouter()
oauth = OAuth()

# Load config from env at runtime
GOOGLE_CLIENT_ID = os.getenv('GOOGLE_CLIENT_ID')
GOOGLE_CLIENT_SECRET = os.getenv('GOOGLE_CLIENT_SECRET')
BASE_URL = os.getenv('BASE_URL', 'http://localhost:8000')
OIDC_STATE_MODE = os.getenv('OIDC_STATE_MODE', 'db').strip().lower()  # 'db' | 'jwt'
OIDC_STATE_SECRET = os.getenv('OIDC_STATE_SECRET')  # fallback to JWT secret below if not set
OIDC_VALIDATE_CLAIMS = os.getenv('OIDC_VALIDATE_CLAIMS', 'false').strip().lower() in ('1', 'true', 'yes', 'on')

# Google is only registered when both credentials are present; the endpoints
# check for the registered client before using it.
if GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET:
    oauth.register(
        name='google',
        client_id=GOOGLE_CLIENT_ID,
        client_secret=GOOGLE_CLIENT_SECRET,
        # Discovery document lives under the `/.well-known/` path segment.
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile https://www.googleapis.com/auth/calendar.events'},
    )
|
|
|
|
|
|
@router.get('/oauth/google/login')
async def google_login(request: Request):
    """Start the Google OAuth flow by redirecting to Google's authorization page.

    Returns an ``{'error': ...}`` payload instead of redirecting when the
    Google client was not registered (credentials missing from the environment).
    """
    # The registry resolves clients through attribute access (the same lookup
    # `_ensure_registered` relies on); it is not documented as supporting `in`,
    # so look the client up as an attribute instead of testing membership.
    if getattr(oauth, 'google', None) is None:
        return {'error': 'google oauth not configured; set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET'}
    redirect_uri = BASE_URL + '/api/v1/oauth/google/callback'
    return await oauth.google.authorize_redirect(request, redirect_uri)
|
|
|
|
|
|
@router.get('/oauth/google/callback')
async def google_callback(request: Request, db: Session = Depends(get_db)):
    """Handle Google's OAuth callback and persist Integration + OAuthToken rows.

    Associates integration with `user_id` query param (or 1) and stores encrypted tokens.

    Returns:
        ``{'ok': True, 'integration_id': ..., 'token_saved': ...}`` on success, or
        an ``{'error': ...}`` payload when the Google client is not registered.

    Raises:
        HTTPException: 400 when `user_id` is present but not an integer.
    """
    # Attribute lookup instead of `in`: the registry exposes clients as attributes.
    if getattr(oauth, 'google', None) is None:
        return {'error': 'google oauth not configured'}

    token = await oauth.google.authorize_access_token(request)

    # Prefer claims from the ID token; fall back to the userinfo endpoint, and
    # finally to an empty mapping so the lookups below never fail.
    userinfo = None
    try:
        userinfo = await oauth.google.parse_id_token(request, token)
    except Exception:
        try:
            resp = await oauth.google.get('userinfo', token=token)
            userinfo = resp.json()
        except Exception:
            userinfo = {}
    # Either source may yield None/empty; normalise before calling .get().
    userinfo = userinfo or {}

    # NOTE(review): the target user comes from an unauthenticated query
    # parameter (default 1) — confirm this is acceptable outside development.
    raw_user_id = request.query_params.get('user_id')
    if raw_user_id:
        try:
            user_id = int(raw_user_id)
        except ValueError:
            raise HTTPException(status_code=400, detail='invalid user_id') from None
    else:
        user_id = 1

    ext_id = userinfo.get('sub') or userinfo.get('id') or None

    from .crypto import encrypt_text

    # Convert the relative lifetime into an absolute epoch-seconds deadline.
    expires_at = None
    if token.get('expires_in'):
        expires_at = int(time.time()) + int(token.get('expires_in'))

    # persist integration + token in a transaction so audit logs can participate
    with transactional(db):
        integration = db.query(models.Integration).filter_by(user_id=user_id, provider='google').first()
        if not integration:
            integration = models.Integration(user_id=user_id, provider='google', external_id=ext_id, config='{}')
            db.add(integration)
            db.flush()
            db.refresh(integration)

        oauth_token = models.OAuthToken(
            integration_id=integration.id,
            access_token=encrypt_text(token.get('access_token') or ''),
            refresh_token=encrypt_text(token.get('refresh_token') or ''),
            scope=token.get('scope'),
            expires_at=expires_at,
        )
        db.add(oauth_token)
        db.flush()
        db.refresh(oauth_token)

    return {'ok': True, 'integration_id': integration.id, 'token_saved': bool(oauth_token.id)}
|
|
|
|
|
|
# --- Generic OIDC with PKCE (multi-provider) ---
# Public base URL used to build the redirect URIs below; defaults to the
# local development server when the variable is unset.
BASE_URL = os.environ.get('BASE_URL', 'http://localhost:8000')
|
|
|
|
def _load_provider_configs():
|
|
"""Load provider configs from env JSON OIDC_PROVIDERS or legacy single-provider vars."""
|
|
import json
|
|
raw = os.getenv('OIDC_PROVIDERS')
|
|
if raw:
|
|
try:
|
|
data = json.loads(raw)
|
|
if isinstance(data, dict):
|
|
return data
|
|
except Exception:
|
|
pass
|
|
# Legacy single provider
|
|
issuer = os.getenv('OIDC_ISSUER')
|
|
client_id = os.getenv('OIDC_CLIENT_ID')
|
|
client_secret = os.getenv('OIDC_CLIENT_SECRET')
|
|
scope = os.getenv('OIDC_SCOPE', 'openid email profile')
|
|
if issuer and client_id:
|
|
return {'oidc': {'issuer': issuer, 'client_id': client_id, 'client_secret': client_secret, 'scope': scope}}
|
|
return {}
|
|
|
|
|
|
def _ensure_registered(provider: str):
    """Make sure `provider` has a client on the `oauth` registry.

    Args:
        provider: Provider key, as used in the provider config mapping.

    Returns:
        True when a usable client exists (already registered, or registered
        now from config); False when the provider is unknown or its config
        entry is incomplete.
    """
    cfgs = _load_provider_configs()
    # Already registered as an attribute on the OAuth registry
    if hasattr(oauth, provider):
        # hasattr is also true for ordinary attributes of the registry object,
        # so only accept names that are configured (or the built-in Google client).
        return provider in cfgs or provider == 'google'

    info = cfgs.get(provider)
    if not isinstance(info, dict):
        return False

    issuer = info.get('issuer')
    client_id = info.get('client_id')
    # Without an issuer the discovery URL cannot be built, and a client without
    # an id cannot complete a flow; treat both as "not configured" rather than
    # failing with an AttributeError.
    if not isinstance(issuer, str) or not issuer or not client_id:
        return False

    oauth.register(
        name=provider,
        client_id=client_id,
        client_secret=info.get('client_secret'),
        server_metadata_url=f"{issuer.rstrip('/')}/.well-known/openid-configuration",
        client_kwargs={'scope': info.get('scope', 'openid email profile')},
        code_challenge_method='S256',
    )
    return True
|
|
|
|
|
|
@router.get('/auth/oidc/providers')
def list_oidc_providers():
    """Return the keys of the configured OIDC providers under ``'providers'``."""
    provider_names = [name for name in _load_provider_configs()]
    return {'providers': provider_names}
|
|
|
|
|
|
@router.get('/auth/oidc/{provider}/login')
async def oidc_login_provider(provider: str, request: Request, db: Session = Depends(get_db)):
    """Begin an OIDC authorization-code + PKCE login for `provider`.

    Generates a PKCE verifier/challenge and a state value valid for 600
    seconds, records them according to ``OIDC_STATE_MODE`` and redirects the
    browser to the provider's authorization endpoint.

    Raises:
        HTTPException: 400 when the provider is not configured.
    """
    if not _ensure_registered(provider):
        raise HTTPException(status_code=400, detail='OIDC provider not configured')

    callback_url = BASE_URL + '/api/v1/auth/oidc/callback'

    # Create PKCE params ourselves so state and verifier can be tracked together.
    from authlib.oauth2.rfc7636 import create_s256_code_challenge
    import secrets

    verifier = secrets.token_urlsafe(64)
    challenge = create_s256_code_challenge(verifier)
    random_state = secrets.token_urlsafe(32)
    issued_at = int(time.time())
    exp = issued_at + 600

    # Two modes: DB-backed or signed JWT state
    if OIDC_STATE_MODE != 'db':
        # NOTE(review): an HS256 JWT is signed but not encrypted, so the PKCE
        # verifier carried in 'cv' is readable by anyone who sees the state
        # value in the authorization URL — confirm this is acceptable.
        import jwt as pyjwt
        from .auth import JWT_SECRET as DEFAULT_SECRET

        signing_key = OIDC_STATE_SECRET or DEFAULT_SECRET
        claims = {'pv': provider, 'cv': verifier, 'exp': exp, 'iat': int(time.time())}
        state_out = pyjwt.encode(claims, signing_key, algorithm='HS256')
    else:
        from datetime import datetime, timezone

        expires_dt = datetime.fromtimestamp(exp, timezone.utc)
        row = models.OIDCLoginState(
            state=random_state,
            provider=provider,
            code_verifier=verifier,
            expires_at=expires_dt,
        )
        db.add(row)
        db.commit()
        state_out = random_state

    # Use authorize_redirect with extra PKCE params
    oidc_client = getattr(oauth, provider)
    redirect = await oidc_client.authorize_redirect(
        request,
        callback_url,
        code_challenge=challenge,
        code_challenge_method='S256',
        state=state_out,
    )
    return redirect
|
|
|
|
|
|
@router.get('/auth/oidc/callback')
async def oidc_callback(request: Request, db: Session = Depends(get_db)):
    """Complete an OIDC login: validate state, exchange the code, sign the user in.

    Steps:
      1. Resolve `provider` and PKCE `code_verifier` from the `state` query
         parameter (DB row or signed JWT, depending on ``OIDC_STATE_MODE``).
      2. Exchange the authorization code for tokens.
      3. Read user claims from the ID token, falling back to the userinfo endpoint.
      4. Optionally validate issuer/audience (``OIDC_VALIDATE_CLAIMS``).
      5. Find or create the user by e-mail and set the session, CSRF and
         logout-related cookies on the JSON response.

    Raises:
        HTTPException: 400 for missing/invalid/expired state, an unconfigured
            provider or a missing e-mail claim; 401 when the token exchange or
            claim validation fails.
    """
    from datetime import datetime, timezone

    params = dict(request.query_params)
    state = params.get('state')
    if not state:
        raise HTTPException(status_code=400, detail='missing state')

    rec = None
    provider = None
    code_verifier = None
    # Try DB mode first (back-compat)
    if OIDC_STATE_MODE == 'db':
        rec = db.query(models.OIDCLoginState).filter_by(state=state).first()
        if not rec:
            raise HTTPException(status_code=400, detail='invalid state')
        provider = rec.provider
        code_verifier = rec.code_verifier
    else:
        # JWT state
        import jwt as pyjwt
        from .auth import JWT_SECRET as DEFAULT_SECRET

        secret = OIDC_STATE_SECRET or DEFAULT_SECRET
        try:
            data = pyjwt.decode(state, secret, algorithms=['HS256'])
        except Exception:
            raise HTTPException(status_code=400, detail='invalid state') from None
        provider = data.get('pv')
        code_verifier = data.get('cv')
        if not provider or not code_verifier:
            raise HTTPException(status_code=400, detail='invalid state')

    # Optional expiration check (handle naive vs aware datetimes).
    # Only DB-backed state has a row; in JWT mode `rec` stays None and the
    # expiry travels in the token's own 'exp' claim.
    if rec is not None and rec.expires_at:
        exp_at = rec.expires_at
        if getattr(exp_at, 'tzinfo', None) is None:
            exp_at = exp_at.replace(tzinfo=timezone.utc)
        if exp_at < datetime.now(timezone.utc):
            # Cleanup stale state and reject
            try:
                db.delete(rec)
                db.commit()
            except Exception:
                # Best-effort cleanup; leave the session usable.
                db.rollback()
            raise HTTPException(status_code=400, detail='state expired')

    # Exchange code for tokens using stored code_verifier
    # Ensure client is registered for the stored provider
    if not _ensure_registered(provider):
        raise HTTPException(status_code=400, detail='OIDC provider not configured')
    client = getattr(oauth, provider)
    try:
        token = await client.authorize_access_token(request, code_verifier=code_verifier)
    except Exception:
        raise HTTPException(status_code=401, detail='token exchange failed') from None

    # Get userinfo via ID token or userinfo endpoint
    userinfo = None
    try:
        userinfo = await client.parse_id_token(request, token)
    except Exception:
        try:
            resp = await client.get('userinfo', token=token)
            userinfo = resp.json()
        except Exception:
            userinfo = {}
    # Either source may yield None/empty; normalise before calling .get().
    userinfo = userinfo or {}

    # Optional audience/issuer extra validation
    if OIDC_VALIDATE_CLAIMS:
        try:
            # claims may be in parsed userinfo or in raw id_token
            claims = userinfo or {}
            idt = token.get('id_token') if isinstance(token, dict) else None
            info = _load_provider_configs().get(provider) or {}
            iss_expected = info.get('issuer')
            aud_expected = info.get('client_id')
            if idt:
                import jwt as pyjwt

                # Signature is not re-verified here (the token exchange above is
                # relied on for that); this only reads the claims.
                try:
                    claims = pyjwt.decode(idt, options={'verify_signature': False})
                except Exception:
                    claims = claims or {}
            if iss_expected and claims.get('iss') and claims.get('iss') != iss_expected:
                raise HTTPException(status_code=401, detail='issuer mismatch')
            aud = claims.get('aud')
            if aud_expected and aud:
                if isinstance(aud, str) and aud != aud_expected:
                    raise HTTPException(status_code=401, detail='audience mismatch')
                if isinstance(aud, (list, tuple)) and aud_expected not in aud:
                    raise HTTPException(status_code=401, detail='audience mismatch')
        except HTTPException:
            raise
        except Exception:
            # be conservative: if validation requested but we cannot evaluate, reject
            raise HTTPException(status_code=401, detail='claim validation failed') from None

    email = userinfo.get('email') or userinfo.get('preferred_username')
    if not email:
        raise HTTPException(status_code=400, detail='email not provided by provider')

    # Upsert user and create app session
    user = db.query(models.User).filter_by(email=email).first()
    if not user:
        user = models.User(email=email, display_name=userinfo.get('name'))
        db.add(user)
        db.flush()
        # Commit here so the new account is persisted in JWT state mode too,
        # where no state row exists and the cleanup commit below never runs.
        db.commit()
        db.refresh(user)

    # cleanup state
    if rec is not None:
        try:
            db.delete(rec)
            db.commit()
        except Exception:
            # Best-effort cleanup; leave the session usable.
            db.rollback()

    # Issue app session cookie
    from .auth import create_token
    from fastapi.responses import JSONResponse
    from .config import settings as cfg
    import secrets as _secrets

    app_token = create_token({'sub': user.id})
    resp = JSONResponse({'id': user.id, 'email': user.email})
    # set cookies
    resp.set_cookie('session', app_token, httponly=True, secure=cfg.COOKIE_SECURE, samesite=cfg.COOKIE_SAMESITE)
    # The CSRF cookie is deliberately readable by scripts (httponly=False).
    csrf = _secrets.token_urlsafe(32)
    resp.set_cookie(cfg.CSRF_COOKIE_NAME, csrf, httponly=False, secure=cfg.COOKIE_SECURE, samesite=cfg.COOKIE_SAMESITE)
    # store provider and id_token for RP-initiated logout
    id_token = token.get('id_token') if isinstance(token, dict) else None
    if id_token:
        resp.set_cookie('oidc_id_token', id_token, httponly=True, secure=cfg.COOKIE_SECURE, samesite=cfg.COOKIE_SAMESITE)
    resp.set_cookie('oidc_provider', provider, httponly=True, secure=cfg.COOKIE_SECURE, samesite=cfg.COOKIE_SAMESITE)
    return resp
|
|
|
|
|
|
@router.post('/auth/oidc/logout')
async def oidc_logout(request: Request):
    """Logout locally and, if supported, redirect to the IdP end_session endpoint (RP-initiated logout).

    Reads the provider key and ID token from the `oidc_provider` and
    `oidc_id_token` cookies. When the provider's metadata advertises an
    end-session endpoint and an ID token is available, responds with a 307
    redirect to it; otherwise returns a JSON acknowledgement. The session,
    OIDC and CSRF cookies are cleared on either response.
    """
    from urllib.parse import urlencode

    from fastapi.responses import JSONResponse, RedirectResponse
    from .config import settings as cfg

    provider = request.cookies.get('oidc_provider')
    id_token = request.cookies.get('oidc_id_token')

    end_session = None
    if provider and _ensure_registered(provider):
        client = getattr(oauth, provider)
        try:
            # ensure metadata loaded
            meta = await client.load_server_metadata()
        except Exception:
            # Fall back to whatever metadata the client already holds.
            meta = getattr(client, 'server_metadata', {}) or {}
        meta = meta or {}
        end_session = meta.get('end_session_endpoint') or meta.get('end_session_endpoint_url')

    if end_session and id_token:
        post_logout = BASE_URL + '/api/v1/auth/oidc/logout/callback'
        # URL-encode both values: the redirect URI contains ':' and '/', and
        # the endpoint may already carry a query string.
        query = urlencode({'id_token_hint': id_token, 'post_logout_redirect_uri': post_logout})
        separator = '&' if '?' in end_session else '?'
        resp = RedirectResponse(f"{end_session}{separator}{query}", status_code=307)
    else:
        resp = JSONResponse({'ok': True, 'logged_out': True})

    # clear cookies
    resp.delete_cookie('session')
    resp.delete_cookie('oidc_id_token')
    resp.delete_cookie('oidc_provider')
    # also clear csrf cookie if present; tolerate settings without the name
    csrf_cookie_name = getattr(cfg, 'CSRF_COOKIE_NAME', None)
    if csrf_cookie_name:
        resp.delete_cookie(csrf_cookie_name)
    return resp
|
|
|
|
|
|
@router.get('/auth/oidc/logout/callback')
def oidc_logout_callback():
    """Acknowledge the post-logout redirect with a fixed JSON payload."""
    acknowledgement = dict(ok=True, logout='complete')
    return acknowledgement
|
|
|
|
|
|
def _decrypt_token(db_token_encrypted: str) -> str:
    """Decrypt an encrypted token value with the package's `crypto.decrypt_text`.

    The helper is imported lazily, at call time rather than at module import.
    """
    from .crypto import decrypt_text as _decrypt

    plaintext = _decrypt(db_token_encrypted)
    return plaintext
|
|
|
|
|
|
def refresh_google_token_if_needed(token_row: models.OAuthToken, db: Session) -> Optional[models.OAuthToken]:
    """Refresh Google's access token using refresh_token; return new OAuthToken row or None.

    This helper requires an active SQLAlchemy `db` Session so it can participate in
    the caller's transaction. It no longer creates or commits its own session.

    Args:
        token_row: Stored token whose `expires_at` (epoch seconds) and encrypted
            `refresh_token` decide whether a refresh is attempted.
        db: Session the new row is added to and flushed on; never committed here.

    Returns:
        A new, flushed ``OAuthToken`` row, or None when the token is still
        valid, no refresh token or Google credentials are available, or the
        refresh fails.
    """
    import logging

    logger = logging.getLogger(__name__)

    # still valid
    if token_row.expires_at and token_row.expires_at > int(time.time()):
        return None
    if not token_row.refresh_token:
        return None

    token_url = 'https://oauth2.googleapis.com/token'
    client_id = os.getenv('GOOGLE_CLIENT_ID')
    client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
    if not client_id or not client_secret:
        return None

    # Decrypt once and reuse. The stored value can be the encryption of an
    # empty string (the Google callback encrypts '' when no refresh token was
    # issued), in which case there is nothing to refresh with.
    current_refresh = _decrypt_token(token_row.refresh_token)
    if not current_refresh:
        return None

    data = {
        'client_id': client_id,
        'client_secret': client_secret,
        'grant_type': 'refresh_token',
        'refresh_token': current_refresh,
    }

    try:
        resp = requests.post(token_url, data=data, timeout=10)
        if resp.status_code != 200:
            logger.warning('Google token refresh rejected with HTTP %s', resp.status_code)
            return None
        t = resp.json()
        new_access = t.get('access_token')
        if not new_access:
            # Do not persist a row holding an empty access token.
            logger.warning('Google token refresh response contained no access_token')
            return None

        from .crypto import encrypt_text

        new_expires = None
        if t.get('expires_in'):
            new_expires = int(time.time()) + int(t.get('expires_in'))

        new_row = models.OAuthToken(
            integration_id=token_row.integration_id,
            access_token=encrypt_text(new_access),
            # Keep the current refresh token when the response omits a new one.
            refresh_token=encrypt_text(t.get('refresh_token') or current_refresh),
            scope=t.get('scope') or token_row.scope,
            expires_at=new_expires,
        )
        db.add(new_row)
        # caller controls commit/flush/refresh; flush so caller can see id if needed
        db.flush()
        db.refresh(new_row)
        return new_row
    except Exception:
        # Best-effort contract: callers receive None on any failure, but the
        # cause is logged rather than silently discarded.
        logger.exception('Google token refresh failed')
        return None
|
|
|
|
|
|
# Alias exposing this module's router under the name `oauth_router`.
# NOTE(review): presumably the name other modules import — verify against callers.
oauth_router = router
|
|
|