✨ New Features: - AI-powered habit creation with natural language processing - HuggingFace transformers integration for sentiment analysis (tracked via Git LFS) - Advanced predictive analytics and behavioral insights - Voice & image input capabilities for hands-free habit tracking - Real-time notifications and community features - Plugin system with extensible architecture 🔧 Technical Improvements: - Comprehensive FastAPI backend with 30+ endpoints - React frontend with PWA capabilities - Advanced authentication with 2FA support - RBAC authorization system - Comprehensive security features (CSRF, rate limiting, audit logging) - Database migrations and health monitoring - Docker containerization support - Git LFS configured for large AI model files (2+ GB) 📚 Documentation & DevOps: - Complete deployment guides for multiple platforms - Professional README with feature highlights - GitHub Actions CI/CD workflows - Comprehensive API documentation - Security audit roadmap and compliance framework - Setup scripts for development environment 🧪 Testing & Quality: - Comprehensive test suite with 20+ test modules - Setup verification scripts - Working development environment with both backend and frontend - Health checks and monitoring systems 🌟 Ready for: - Portfolio showcasing - Community contributions - Production deployment - Professional presentation
543 lines
19 KiB
Python
543 lines
19 KiB
Python
"""
|
|
Advanced rate limiting with user-based and IP-based controls
|
|
Provides comprehensive protection against abuse with flexible configuration
|
|
"""
|
|
|
|
import time
|
|
import asyncio
|
|
import json
|
|
from typing import Dict, Optional, List, Union
|
|
from datetime import datetime, timedelta
|
|
from collections import defaultdict, deque
|
|
from dataclasses import dataclass, asdict
|
|
from enum import Enum
|
|
import redis
|
|
from fastapi import HTTPException, Request
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimitType(Enum):
|
|
"""Types of rate limits"""
|
|
USER_BASED = "user"
|
|
IP_BASED = "ip"
|
|
GLOBAL = "global"
|
|
ENDPOINT_SPECIFIC = "endpoint"
|
|
SLIDING_WINDOW = "sliding"
|
|
FIXED_WINDOW = "fixed"
|
|
|
|
|
|
class RateLimitAction(Enum):
|
|
"""Actions to take when rate limit is exceeded"""
|
|
BLOCK = "block"
|
|
THROTTLE = "throttle"
|
|
WARNING = "warning"
|
|
LOG_ONLY = "log"
|
|
|
|
|
|
@dataclass
|
|
class RateLimitRule:
|
|
"""Configuration for a rate limit rule"""
|
|
max_requests: int
|
|
window_seconds: int
|
|
limit_type: RateLimitType
|
|
action: RateLimitAction = RateLimitAction.BLOCK
|
|
burst_allowance: int = 0
|
|
throttle_delay: float = 1.0
|
|
endpoints: List[str] = None
|
|
user_tiers: List[str] = None # premium, basic, free
|
|
|
|
def __post_init__(self):
|
|
if self.endpoints is None:
|
|
self.endpoints = ["*"] # Apply to all endpoints
|
|
if self.user_tiers is None:
|
|
self.user_tiers = ["*"] # Apply to all user tiers
|
|
|
|
|
|
@dataclass
|
|
class RateLimitStatus:
|
|
"""Current rate limit status for a key"""
|
|
requests_made: int
|
|
requests_remaining: int
|
|
reset_time: datetime
|
|
is_limited: bool
|
|
action: RateLimitAction
|
|
retry_after: Optional[int] = None
|
|
|
|
|
|
class AdvancedRateLimiter:
|
|
"""
|
|
Advanced rate limiting with multiple strategies and storage backends
|
|
"""
|
|
|
|
def __init__(self, redis_client: Optional[redis.Redis] = None):
|
|
self.redis = redis_client
|
|
self.local_cache = defaultdict(lambda: defaultdict(deque))
|
|
self.rules: Dict[str, RateLimitRule] = {}
|
|
self.user_tiers = {} # user_id -> tier mapping
|
|
|
|
# Default rules
|
|
self._setup_default_rules()
|
|
|
|
def _setup_default_rules(self):
|
|
"""Setup default rate limiting rules"""
|
|
|
|
# General API rate limits by user tier
|
|
self.add_rule("user_basic", RateLimitRule(
|
|
max_requests=1000,
|
|
window_seconds=3600, # 1 hour
|
|
limit_type=RateLimitType.USER_BASED,
|
|
user_tiers=["basic", "free"]
|
|
))
|
|
|
|
self.add_rule("user_premium", RateLimitRule(
|
|
max_requests=5000,
|
|
window_seconds=3600, # 1 hour
|
|
limit_type=RateLimitType.USER_BASED,
|
|
user_tiers=["premium", "pro"]
|
|
))
|
|
|
|
# IP-based limits for anonymous users
|
|
self.add_rule("ip_anonymous", RateLimitRule(
|
|
max_requests=100,
|
|
window_seconds=3600, # 1 hour
|
|
limit_type=RateLimitType.IP_BASED
|
|
))
|
|
|
|
# Strict limits for authentication endpoints
|
|
self.add_rule("auth_endpoints", RateLimitRule(
|
|
max_requests=5,
|
|
window_seconds=300, # 5 minutes
|
|
limit_type=RateLimitType.USER_BASED,
|
|
endpoints=["/auth/login", "/auth/register", "/auth/reset-password"],
|
|
action=RateLimitAction.BLOCK
|
|
))
|
|
|
|
# More lenient limits for read operations
|
|
self.add_rule("read_operations", RateLimitRule(
|
|
max_requests=500,
|
|
window_seconds=300, # 5 minutes
|
|
limit_type=RateLimitType.USER_BASED,
|
|
endpoints=["/habits", "/analytics", "/profile"],
|
|
burst_allowance=50
|
|
))
|
|
|
|
# Strict limits for write operations
|
|
self.add_rule("write_operations", RateLimitRule(
|
|
max_requests=100,
|
|
window_seconds=300, # 5 minutes
|
|
limit_type=RateLimitType.USER_BASED,
|
|
endpoints=["/habits/create", "/habits/*/complete", "/habits/*/update"],
|
|
action=RateLimitAction.THROTTLE,
|
|
throttle_delay=2.0
|
|
))
|
|
|
|
# Global rate limits for server protection
|
|
self.add_rule("global_protection", RateLimitRule(
|
|
max_requests=10000,
|
|
window_seconds=60, # 1 minute
|
|
limit_type=RateLimitType.GLOBAL
|
|
))
|
|
|
|
def add_rule(self, rule_id: str, rule: RateLimitRule):
|
|
"""Add a new rate limiting rule"""
|
|
self.rules[rule_id] = rule
|
|
logger.info(f"Added rate limit rule: {rule_id}")
|
|
|
|
def set_user_tier(self, user_id: str, tier: str):
|
|
"""Set the tier for a user (basic, premium, pro, etc.)"""
|
|
self.user_tiers[user_id] = tier
|
|
|
|
def get_user_tier(self, user_id: str) -> str:
|
|
"""Get the tier for a user, default to 'basic'"""
|
|
return self.user_tiers.get(user_id, "basic")
|
|
|
|
async def check_rate_limit(self,
|
|
request: Request,
|
|
user_id: Optional[str] = None,
|
|
endpoint: Optional[str] = None) -> RateLimitStatus:
|
|
"""
|
|
Check if a request should be rate limited
|
|
Returns RateLimitStatus with current status
|
|
"""
|
|
|
|
# Determine applicable rules
|
|
applicable_rules = self._get_applicable_rules(
|
|
user_id=user_id,
|
|
endpoint=endpoint,
|
|
ip=self._get_client_ip(request)
|
|
)
|
|
|
|
# Check each applicable rule
|
|
most_restrictive_status = None
|
|
|
|
for rule_id, rule in applicable_rules:
|
|
key = self._generate_key(rule, user_id, self._get_client_ip(request))
|
|
status = await self._check_rule(rule, key)
|
|
|
|
# Track most restrictive limit
|
|
if (most_restrictive_status is None or
|
|
status.is_limited or
|
|
status.requests_remaining < most_restrictive_status.requests_remaining):
|
|
most_restrictive_status = status
|
|
|
|
return most_restrictive_status or RateLimitStatus(
|
|
requests_made=0,
|
|
requests_remaining=float('inf'),
|
|
reset_time=datetime.now() + timedelta(hours=1),
|
|
is_limited=False,
|
|
action=RateLimitAction.LOG_ONLY
|
|
)
|
|
|
|
async def record_request(self,
|
|
request: Request,
|
|
user_id: Optional[str] = None,
|
|
endpoint: Optional[str] = None):
|
|
"""Record a request for rate limiting purposes"""
|
|
|
|
applicable_rules = self._get_applicable_rules(
|
|
user_id=user_id,
|
|
endpoint=endpoint,
|
|
ip=self._get_client_ip(request)
|
|
)
|
|
|
|
# Record request for each applicable rule
|
|
for rule_id, rule in applicable_rules:
|
|
key = self._generate_key(rule, user_id, self._get_client_ip(request))
|
|
await self._record_request_for_rule(rule, key)
|
|
|
|
def _get_applicable_rules(self,
|
|
user_id: Optional[str],
|
|
endpoint: Optional[str],
|
|
ip: str) -> List[tuple]:
|
|
"""Get rules that apply to this request"""
|
|
|
|
applicable_rules = []
|
|
user_tier = self.get_user_tier(user_id) if user_id else "anonymous"
|
|
|
|
for rule_id, rule in self.rules.items():
|
|
# Check if rule applies to this user tier
|
|
if "*" not in rule.user_tiers and user_tier not in rule.user_tiers:
|
|
continue
|
|
|
|
# Check if rule applies to this endpoint
|
|
if endpoint and "*" not in rule.endpoints:
|
|
endpoint_matches = False
|
|
for pattern in rule.endpoints:
|
|
if self._endpoint_matches(endpoint, pattern):
|
|
endpoint_matches = True
|
|
break
|
|
if not endpoint_matches:
|
|
continue
|
|
|
|
# Check if we have required identifiers for the rule type
|
|
if rule.limit_type == RateLimitType.USER_BASED and not user_id:
|
|
continue
|
|
|
|
applicable_rules.append((rule_id, rule))
|
|
|
|
return applicable_rules
|
|
|
|
def _endpoint_matches(self, endpoint: str, pattern: str) -> bool:
|
|
"""Check if an endpoint matches a pattern (supports wildcards)"""
|
|
if pattern == "*":
|
|
return True
|
|
|
|
# Simple wildcard matching
|
|
if "*" in pattern:
|
|
parts = pattern.split("*")
|
|
if len(parts) == 2:
|
|
prefix, suffix = parts
|
|
return endpoint.startswith(prefix) and endpoint.endswith(suffix)
|
|
|
|
return endpoint == pattern
|
|
|
|
def _generate_key(self,
|
|
rule: RateLimitRule,
|
|
user_id: Optional[str],
|
|
ip: str) -> str:
|
|
"""Generate a cache key for the rate limit"""
|
|
|
|
if rule.limit_type == RateLimitType.USER_BASED:
|
|
return f"rate_limit:user:{user_id}:{rule.window_seconds}"
|
|
elif rule.limit_type == RateLimitType.IP_BASED:
|
|
return f"rate_limit:ip:{ip}:{rule.window_seconds}"
|
|
elif rule.limit_type == RateLimitType.GLOBAL:
|
|
return f"rate_limit:global:{rule.window_seconds}"
|
|
else:
|
|
return f"rate_limit:custom:{user_id or ip}:{rule.window_seconds}"
|
|
|
|
async def _check_rule(self, rule: RateLimitRule, key: str) -> RateLimitStatus:
|
|
"""Check rate limit for a specific rule"""
|
|
|
|
now = time.time()
|
|
window_start = now - rule.window_seconds
|
|
|
|
if self.redis:
|
|
return await self._check_rule_redis(rule, key, now, window_start)
|
|
else:
|
|
return await self._check_rule_memory(rule, key, now, window_start)
|
|
|
|
async def _check_rule_redis(self,
|
|
rule: RateLimitRule,
|
|
key: str,
|
|
now: float,
|
|
window_start: float) -> RateLimitStatus:
|
|
"""Check rate limit using Redis storage"""
|
|
|
|
pipe = self.redis.pipeline()
|
|
|
|
# Remove old entries
|
|
pipe.zremrangebyscore(key, 0, window_start)
|
|
|
|
# Count current entries
|
|
pipe.zcard(key)
|
|
|
|
# Set expiration
|
|
pipe.expire(key, rule.window_seconds)
|
|
|
|
results = await pipe.execute()
|
|
current_count = results[1]
|
|
|
|
requests_remaining = max(0, rule.max_requests - current_count)
|
|
is_limited = current_count >= rule.max_requests
|
|
|
|
return RateLimitStatus(
|
|
requests_made=current_count,
|
|
requests_remaining=requests_remaining,
|
|
reset_time=datetime.fromtimestamp(now + rule.window_seconds),
|
|
is_limited=is_limited,
|
|
action=rule.action,
|
|
retry_after=rule.window_seconds if is_limited else None
|
|
)
|
|
|
|
async def _check_rule_memory(self,
|
|
rule: RateLimitRule,
|
|
key: str,
|
|
now: float,
|
|
window_start: float) -> RateLimitStatus:
|
|
"""Check rate limit using in-memory storage"""
|
|
|
|
requests = self.local_cache[key]['requests']
|
|
|
|
# Remove old requests
|
|
while requests and requests[0] < window_start:
|
|
requests.popleft()
|
|
|
|
current_count = len(requests)
|
|
requests_remaining = max(0, rule.max_requests - current_count)
|
|
is_limited = current_count >= rule.max_requests
|
|
|
|
return RateLimitStatus(
|
|
requests_made=current_count,
|
|
requests_remaining=requests_remaining,
|
|
reset_time=datetime.fromtimestamp(now + rule.window_seconds),
|
|
is_limited=is_limited,
|
|
action=rule.action,
|
|
retry_after=rule.window_seconds if is_limited else None
|
|
)
|
|
|
|
async def _record_request_for_rule(self, rule: RateLimitRule, key: str):
|
|
"""Record a request for a specific rule"""
|
|
|
|
now = time.time()
|
|
|
|
if self.redis:
|
|
await self._record_request_redis(key, now, rule.window_seconds)
|
|
else:
|
|
await self._record_request_memory(key, now)
|
|
|
|
async def _record_request_redis(self, key: str, timestamp: float, window_seconds: int):
|
|
"""Record a request using Redis"""
|
|
|
|
pipe = self.redis.pipeline()
|
|
pipe.zadd(key, {str(timestamp): timestamp})
|
|
pipe.expire(key, window_seconds)
|
|
await pipe.execute()
|
|
|
|
async def _record_request_memory(self, key: str, timestamp: float):
|
|
"""Record a request using in-memory storage"""
|
|
|
|
requests = self.local_cache[key]['requests']
|
|
requests.append(timestamp)
|
|
|
|
def _get_client_ip(self, request: Request) -> str:
|
|
"""Extract client IP from request"""
|
|
|
|
# Check for forwarded headers
|
|
forwarded = request.headers.get("X-Forwarded-For")
|
|
if forwarded:
|
|
return forwarded.split(",")[0].strip()
|
|
|
|
real_ip = request.headers.get("X-Real-IP")
|
|
if real_ip:
|
|
return real_ip
|
|
|
|
# Fallback to direct connection
|
|
return request.client.host if request.client else "unknown"
|
|
|
|
async def get_rate_limit_info(self,
|
|
request: Request,
|
|
user_id: Optional[str] = None) -> Dict:
|
|
"""Get comprehensive rate limit information"""
|
|
|
|
info = {
|
|
"user_id": user_id,
|
|
"ip": self._get_client_ip(request),
|
|
"user_tier": self.get_user_tier(user_id) if user_id else "anonymous",
|
|
"limits": {}
|
|
}
|
|
|
|
applicable_rules = self._get_applicable_rules(
|
|
user_id=user_id,
|
|
endpoint=None, # Get all rules
|
|
ip=info["ip"]
|
|
)
|
|
|
|
for rule_id, rule in applicable_rules:
|
|
key = self._generate_key(rule, user_id, info["ip"])
|
|
status = await self._check_rule(rule, key)
|
|
|
|
info["limits"][rule_id] = {
|
|
"max_requests": rule.max_requests,
|
|
"window_seconds": rule.window_seconds,
|
|
"requests_made": status.requests_made,
|
|
"requests_remaining": status.requests_remaining,
|
|
"reset_time": status.reset_time.isoformat(),
|
|
"is_limited": status.is_limited
|
|
}
|
|
|
|
return info
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
|
"""
|
|
FastAPI middleware for automatic rate limiting
|
|
"""
|
|
|
|
def __init__(self, app, rate_limiter: AdvancedRateLimiter):
|
|
super().__init__(app)
|
|
self.rate_limiter = rate_limiter
|
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
# Extract user ID from JWT token or session
|
|
user_id = await self._extract_user_id(request)
|
|
endpoint = request.url.path
|
|
|
|
# Check rate limits
|
|
try:
|
|
status = await self.rate_limiter.check_rate_limit(
|
|
request=request,
|
|
user_id=user_id,
|
|
endpoint=endpoint
|
|
)
|
|
|
|
# Handle rate limit exceeded
|
|
if status.is_limited:
|
|
if status.action == RateLimitAction.BLOCK:
|
|
raise HTTPException(
|
|
status_code=429,
|
|
detail={
|
|
"error": "Rate limit exceeded",
|
|
"retry_after": status.retry_after,
|
|
"requests_remaining": status.requests_remaining,
|
|
"reset_time": status.reset_time.isoformat()
|
|
},
|
|
headers={"Retry-After": str(status.retry_after)}
|
|
)
|
|
elif status.action == RateLimitAction.THROTTLE:
|
|
# Add artificial delay
|
|
rule = next((r for r in self.rate_limiter.rules.values()
|
|
if r.action == RateLimitAction.THROTTLE), None)
|
|
if rule:
|
|
await asyncio.sleep(rule.throttle_delay)
|
|
|
|
# Process request
|
|
response = await call_next(request)
|
|
|
|
# Record successful request
|
|
await self.rate_limiter.record_request(
|
|
request=request,
|
|
user_id=user_id,
|
|
endpoint=endpoint
|
|
)
|
|
|
|
# Add rate limit headers to response
|
|
response.headers["X-RateLimit-Remaining"] = str(status.requests_remaining)
|
|
response.headers["X-RateLimit-Reset"] = str(int(status.reset_time.timestamp()))
|
|
|
|
return response
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Rate limiting error: {e}")
|
|
# Continue processing on rate limiter errors
|
|
return await call_next(request)
|
|
|
|
async def _extract_user_id(self, request: Request) -> Optional[str]:
|
|
"""Extract user ID from request (implement based on your auth system)"""
|
|
|
|
# Check for JWT token in Authorization header
|
|
auth_header = request.headers.get("Authorization")
|
|
if auth_header and auth_header.startswith("Bearer "):
|
|
token = auth_header[7:]
|
|
# Decode JWT token to get user ID
|
|
# This is a simplified example - implement proper JWT validation
|
|
try:
|
|
import jwt
|
|
payload = jwt.decode(token, options={"verify_signature": False})
|
|
return payload.get("user_id")
|
|
except:
|
|
pass
|
|
|
|
# Check for session cookie
|
|
session_id = request.cookies.get("session_id")
|
|
if session_id:
|
|
# Look up user ID from session store
|
|
# Implement based on your session management
|
|
pass
|
|
|
|
return None
|
|
|
|
|
|
# Usage example and configuration
|
|
def create_rate_limiter(redis_url: Optional[str] = None) -> AdvancedRateLimiter:
|
|
"""Create and configure a rate limiter instance"""
|
|
|
|
redis_client = None
|
|
if redis_url:
|
|
redis_client = redis.from_url(redis_url)
|
|
|
|
limiter = AdvancedRateLimiter(redis_client)
|
|
|
|
# Add custom rules for specific use cases
|
|
limiter.add_rule("habit_completion", RateLimitRule(
|
|
max_requests=50, # Max 50 habit completions per hour
|
|
window_seconds=3600,
|
|
limit_type=RateLimitType.USER_BASED,
|
|
endpoints=["/habits/*/complete"],
|
|
action=RateLimitAction.THROTTLE,
|
|
throttle_delay=1.0
|
|
))
|
|
|
|
limiter.add_rule("analytics_queries", RateLimitRule(
|
|
max_requests=200, # Max 200 analytics queries per hour
|
|
window_seconds=3600,
|
|
limit_type=RateLimitType.USER_BASED,
|
|
endpoints=["/analytics/*"],
|
|
burst_allowance=20
|
|
))
|
|
|
|
return limiter
|
|
|
|
|
|
# FastAPI dependency for rate limiting
|
|
async def get_rate_limit_info(request: Request,
|
|
rate_limiter: AdvancedRateLimiter) -> Dict:
|
|
"""FastAPI dependency to get rate limit information"""
|
|
|
|
user_id = await RateLimitMiddleware(None, rate_limiter)._extract_user_id(request)
|
|
return await rate_limiter.get_rate_limit_info(request, user_id) |