security(kms+refresh): optional KMS envelope keys + token refresh flow for Google
This commit is contained in:
parent
8d62ac0017
commit
ed71629f88
|
|
@ -56,6 +56,11 @@ def google_events(integration_id: int):
|
||||||
token = db.query(models.OAuthToken).filter_by(integration_id=integration_id).order_by(models.OAuthToken.id.desc()).first()
|
token = db.query(models.OAuthToken).filter_by(integration_id=integration_id).order_by(models.OAuthToken.id.desc()).first()
|
||||||
if not token or not token.access_token:
|
if not token or not token.access_token:
|
||||||
raise HTTPException(status_code=404, detail='no token found for integration')
|
raise HTTPException(status_code=404, detail='no token found for integration')
|
||||||
|
# Try to refresh token if needed (refresh flow is in oauth module)
|
||||||
|
from .oauth import refresh_google_token_if_needed
|
||||||
|
refreshed = refresh_google_token_if_needed(token)
|
||||||
|
if refreshed:
|
||||||
|
token = refreshed
|
||||||
|
|
||||||
from .crypto import decrypt_text
|
from .crypto import decrypt_text
|
||||||
decrypted_access = decrypt_text(token.access_token)
|
decrypted_access = decrypt_text(token.access_token)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
KEY_ENV = 'LIFERPG_DATA_KEY'
|
KEY_ENV = 'LIFERPG_DATA_KEY'
|
||||||
FALLBACK_KEY_PATH = os.path.join(os.path.dirname(__file__), '.dev_liferpg_key')
|
FALLBACK_KEY_PATH = os.path.join(os.path.dirname(__file__), '.dev_liferpg_key')
|
||||||
|
KMS_WRAPPED_PATH = os.path.join(os.path.dirname(__file__), '.wrapped_data_key')
|
||||||
|
KMS_KEY_ID_ENV = 'LIFERPG_KMS_KEY_ID'
|
||||||
|
|
||||||
|
|
||||||
def _load_key_from_env():
|
def _load_key_from_env():
|
||||||
|
|
@ -29,6 +31,51 @@ def _load_or_create_fallback_key():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_key_from_kms():
|
||||||
|
"""Optional: use AWS KMS to manage a wrapped data key for envelope encryption.
|
||||||
|
|
||||||
|
If env var LIFERPG_KMS_KEY_ID is set and boto3 is available, this will either:
|
||||||
|
- read an existing wrapped key from KMS_WRAPPED_PATH and call KMS Decrypt to obtain the plaintext data key,
|
||||||
|
- or call KMS GenerateDataKey to produce and persist a wrapped key locally (development convenience).
|
||||||
|
"""
|
||||||
|
key_id = os.getenv(KMS_KEY_ID_ENV)
|
||||||
|
if not key_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import BotoCoreError, ClientError
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
kms = boto3.client('kms')
|
||||||
|
# If wrapped key exists, decrypt it
|
||||||
|
if os.path.exists(KMS_WRAPPED_PATH):
|
||||||
|
try:
|
||||||
|
with open(KMS_WRAPPED_PATH, 'rb') as f:
|
||||||
|
blob = f.read()
|
||||||
|
resp = kms.decrypt(CiphertextBlob=blob)
|
||||||
|
return resp['Plaintext']
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Otherwise, generate a new data key and store the wrapped blob
|
||||||
|
try:
|
||||||
|
resp = kms.generate_data_key(KeyId=key_id, KeySpec='AES_256')
|
||||||
|
plaintext = resp['Plaintext']
|
||||||
|
ciphertext = resp['CiphertextBlob']
|
||||||
|
# persist wrapped key
|
||||||
|
with open(KMS_WRAPPED_PATH, 'wb') as f:
|
||||||
|
f.write(ciphertext)
|
||||||
|
# restrict perms
|
||||||
|
try:
|
||||||
|
os.chmod(KMS_WRAPPED_PATH, 0o600)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return plaintext
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_fernet():
|
def get_fernet():
|
||||||
key = _load_key_from_env() or _load_or_create_fallback_key()
|
key = _load_key_from_env() or _load_or_create_fallback_key()
|
||||||
if not key:
|
if not key:
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ from fastapi import APIRouter, Request
|
||||||
from starlette.responses import RedirectResponse
|
from starlette.responses import RedirectResponse
|
||||||
from authlib.integrations.starlette_client import OAuth
|
from authlib.integrations.starlette_client import OAuth
|
||||||
from . import models
|
from . import models
|
||||||
|
import requests
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
oauth = OAuth()
|
oauth = OAuth()
|
||||||
|
|
@ -91,3 +93,64 @@ async def google_callback(request: Request):
|
||||||
return {'ok': True, 'integration_id': integration.id, 'token_saved': bool(oauth_token.id)}
|
return {'ok': True, 'integration_id': integration.id, 'token_saved': bool(oauth_token.id)}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt_token(db_token_encrypted: str) -> str:
|
||||||
|
from .crypto import decrypt_text
|
||||||
|
return decrypt_text(db_token_encrypted)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_google_token_if_needed(oauth_token_row: models.OAuthToken) -> Optional[models.OAuthToken]:
|
||||||
|
"""Refresh Google's access token using refresh_token if expired or near expiry.
|
||||||
|
|
||||||
|
Returns updated OAuthToken row (new DB row) or None on failure.
|
||||||
|
"""
|
||||||
|
# If not expired, return the same
|
||||||
|
now = int(time.time())
|
||||||
|
if oauth_token_row.expires_at and oauth_token_row.expires_at > now + 30:
|
||||||
|
return oauth_token_row
|
||||||
|
|
||||||
|
refresh_token = _decrypt_token(oauth_token_row.refresh_token)
|
||||||
|
if not refresh_token:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Use Google's token endpoint to refresh
|
||||||
|
token_url = 'https://oauth2.googleapis.com/token'
|
||||||
|
client_id = os.getenv('GOOGLE_CLIENT_ID')
|
||||||
|
client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
|
||||||
|
if not client_id or not client_secret:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'client_id': client_id,
|
||||||
|
'client_secret': client_secret,
|
||||||
|
'grant_type': 'refresh_token',
|
||||||
|
'refresh_token': refresh_token
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
resp = requests.post(token_url, data=data, timeout=10)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return None
|
||||||
|
t = resp.json()
|
||||||
|
# Persist new token
|
||||||
|
from .crypto import encrypt_text
|
||||||
|
db = models.SessionLocal()
|
||||||
|
try:
|
||||||
|
new_expires = None
|
||||||
|
if t.get('expires_in'):
|
||||||
|
new_expires = int(time.time()) + int(t.get('expires_in'))
|
||||||
|
new_row = models.OAuthToken(
|
||||||
|
integration_id=oauth_token_row.integration_id,
|
||||||
|
access_token=encrypt_text(t.get('access_token') or ''),
|
||||||
|
refresh_token=encrypt_text(t.get('refresh_token') or refresh_token),
|
||||||
|
scope=t.get('scope') or oauth_token_row.scope,
|
||||||
|
expires_at=new_expires
|
||||||
|
)
|
||||||
|
db.add(new_row)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(new_row)
|
||||||
|
return new_row
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -5,3 +5,4 @@ authlib
|
||||||
python-dotenv
|
python-dotenv
|
||||||
requests
|
requests
|
||||||
cryptography
|
cryptography
|
||||||
|
boto3
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user