fix(security): comprehensive security hardening — TLS, HMAC WS auth, rate limiting, IP leak prevention
CRITICAL fixes: - Auto-generated self-signed TLS certs (HTTPS/WSS by default) - Removed session_key from /srp/verify response (was sent in plaintext) - Replaced with HMAC-SHA256 ws_token for WebSocket authentication HIGH fixes: - WebSocket auth now validates ws_token via hmac.compare_digest() - /clear endpoint requires Bearer admin_token (printed at server start) - Password no longer required as CLI arg — supports env var + getpass prompt - Removed user_ip from Message model (no longer broadcast to clients) MEDIUM fixes: - Rate limiter on /srp/init and /srp/verify (10 req/min/IP) - MessageStore capped at 1000 messages (prevents RAM DoS) - access_log disabled (was leaking request metadata) LOW fixes: - Username sanitization against rich markup injection - Dead code removed from helpers.py All 79 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
440b67da26
commit
e7bacc93da
|
|
@ -1,8 +1,19 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import getpass
|
||||||
|
import os
|
||||||
|
|
||||||
from cmd_chat.server.server import run_server
|
from cmd_chat.server.server import run_server
|
||||||
from cmd_chat.client.client import Client
|
from cmd_chat.client.client import Client
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_password(args_password: str | None, prompt: str = "Room password: ") -> str:
|
||||||
|
if args_password:
|
||||||
|
return args_password
|
||||||
|
if env_pw := os.environ.get("CMD_CHAT_PASSWORD"):
|
||||||
|
return env_pw
|
||||||
|
return getpass.getpass(prompt)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Command-line chat application")
|
parser = argparse.ArgumentParser(description="Command-line chat application")
|
||||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
@ -10,24 +21,46 @@ def main():
|
||||||
serve_p = subparsers.add_parser("serve", help="Run server")
|
serve_p = subparsers.add_parser("serve", help="Run server")
|
||||||
serve_p.add_argument("ip_address")
|
serve_p.add_argument("ip_address")
|
||||||
serve_p.add_argument("port")
|
serve_p.add_argument("port")
|
||||||
serve_p.add_argument("--password", "-p", required=True)
|
serve_p.add_argument("--password", "-p", default=None)
|
||||||
|
serve_p.add_argument("--cert", default=None, help="Path to TLS certificate")
|
||||||
|
serve_p.add_argument("--key", default=None, help="Path to TLS private key")
|
||||||
|
serve_p.add_argument("--no-tls", action="store_true", help="Disable TLS (insecure)")
|
||||||
|
|
||||||
connect_p = subparsers.add_parser("connect", help="Connect to server")
|
connect_p = subparsers.add_parser("connect", help="Connect to server")
|
||||||
connect_p.add_argument("ip_address")
|
connect_p.add_argument("ip_address")
|
||||||
connect_p.add_argument("port")
|
connect_p.add_argument("port")
|
||||||
connect_p.add_argument("username")
|
connect_p.add_argument("username")
|
||||||
connect_p.add_argument("password")
|
connect_p.add_argument("--password", "-p", default=None)
|
||||||
|
connect_p.add_argument(
|
||||||
|
"--insecure", "-k", action="store_true",
|
||||||
|
help="Skip TLS certificate verification (for self-signed certs)",
|
||||||
|
)
|
||||||
|
connect_p.add_argument(
|
||||||
|
"--no-tls", action="store_true",
|
||||||
|
help="Connect without TLS (insecure)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.command == "serve":
|
if args.command == "serve":
|
||||||
run_server(host=args.ip_address, port=int(args.port), password=args.password)
|
password = resolve_password(args.password)
|
||||||
|
run_server(
|
||||||
|
host=args.ip_address,
|
||||||
|
port=int(args.port),
|
||||||
|
password=password,
|
||||||
|
cert_path=args.cert,
|
||||||
|
key_path=args.key,
|
||||||
|
no_tls=args.no_tls,
|
||||||
|
)
|
||||||
elif args.command == "connect":
|
elif args.command == "connect":
|
||||||
|
password = resolve_password(args.password)
|
||||||
Client(
|
Client(
|
||||||
server=args.ip_address,
|
server=args.ip_address,
|
||||||
port=int(args.port),
|
port=int(args.port),
|
||||||
username=args.username,
|
username=args.username,
|
||||||
password=args.password,
|
password=password,
|
||||||
|
insecure=args.insecure,
|
||||||
|
no_tls=args.no_tls,
|
||||||
).run()
|
).run()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import ssl
|
||||||
import base64
|
import base64
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
@ -17,14 +18,22 @@ srp.rfc5054_enable()
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, server: str, port: int, username: str, password: Optional[str] = None
|
self,
|
||||||
|
server: str,
|
||||||
|
port: int,
|
||||||
|
username: str,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
insecure: bool = False,
|
||||||
|
no_tls: bool = False,
|
||||||
):
|
):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.port = port
|
self.port = port
|
||||||
self.username = username
|
self.username = username
|
||||||
self.password = (password or "").encode()
|
self.password = (password or "").encode()
|
||||||
|
self.insecure = insecure
|
||||||
|
self.no_tls = no_tls
|
||||||
self.user_id: Optional[str] = None
|
self.user_id: Optional[str] = None
|
||||||
self.fernet: Optional[Fernet] = None
|
self.ws_token: Optional[str] = None
|
||||||
self.room_fernet: Optional[Fernet] = None
|
self.room_fernet: Optional[Fernet] = None
|
||||||
|
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
|
|
@ -35,23 +44,42 @@ class Client:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_url(self) -> str:
|
def base_url(self) -> str:
|
||||||
return f"http://{self.server}:{self.port}"
|
scheme = "http" if self.no_tls else "https"
|
||||||
|
return f"{scheme}://{self.server}:{self.port}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ws_url(self) -> str:
|
def ws_url(self) -> str:
|
||||||
return f"ws://{self.server}:{self.port}"
|
scheme = "ws" if self.no_tls else "wss"
|
||||||
|
return f"{scheme}://{self.server}:{self.port}"
|
||||||
|
|
||||||
|
def _request_kwargs(self) -> dict:
|
||||||
|
kwargs: dict = {"timeout": 30}
|
||||||
|
if self.insecure and not self.no_tls:
|
||||||
|
kwargs["verify"] = False
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def _ws_ssl_context(self) -> ssl.SSLContext | None:
|
||||||
|
if self.no_tls:
|
||||||
|
return None
|
||||||
|
if self.insecure:
|
||||||
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||||
|
ctx.check_hostname = False
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE
|
||||||
|
return ctx
|
||||||
|
return True # default verification
|
||||||
|
|
||||||
def success(self, message: str) -> None:
|
def success(self, message: str) -> None:
|
||||||
self.console.print(f"[green]✓ {message}[/]")
|
self.console.print(f"[green]{message}[/]")
|
||||||
|
|
||||||
def error(self, message: str) -> None:
|
def error(self, message: str) -> None:
|
||||||
self.console.print(f"[red]✗ {message}[/]")
|
self.console.print(f"[red]{message}[/]")
|
||||||
|
|
||||||
def info(self, message: str) -> None:
|
def info(self, message: str) -> None:
|
||||||
self.console.print(f"[cyan]• {message}[/]")
|
self.console.print(f"[cyan]{message}[/]")
|
||||||
|
|
||||||
def srp_authenticate(self) -> None:
|
def srp_authenticate(self) -> None:
|
||||||
with self.console.status("[cyan]Starting SRP handshake...[/]", spinner="dots"):
|
with self.console.status("[cyan]Starting SRP handshake...[/]", spinner="dots"):
|
||||||
|
req_kwargs = self._request_kwargs()
|
||||||
|
|
||||||
usr = srp.User(b"chat", self.password, hash_alg=srp.SHA256)
|
usr = srp.User(b"chat", self.password, hash_alg=srp.SHA256)
|
||||||
_, A = usr.start_authentication()
|
_, A = usr.start_authentication()
|
||||||
|
|
@ -62,7 +90,7 @@ class Client:
|
||||||
"username": self.username,
|
"username": self.username,
|
||||||
"A": base64.b64encode(A).decode(),
|
"A": base64.b64encode(A).decode(),
|
||||||
},
|
},
|
||||||
timeout=30,
|
**req_kwargs,
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
init_data = resp.json()
|
init_data = resp.json()
|
||||||
|
|
@ -93,7 +121,7 @@ class Client:
|
||||||
"username": self.username,
|
"username": self.username,
|
||||||
"M": base64.b64encode(M).decode(),
|
"M": base64.b64encode(M).decode(),
|
||||||
},
|
},
|
||||||
timeout=30,
|
**req_kwargs,
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
verify_data = resp.json()
|
verify_data = resp.json()
|
||||||
|
|
@ -104,8 +132,7 @@ class Client:
|
||||||
if not usr.authenticated():
|
if not usr.authenticated():
|
||||||
raise ValueError("Server authentication failed")
|
raise ValueError("Server authentication failed")
|
||||||
|
|
||||||
session_key = base64.b64decode(verify_data["session_key"])
|
self.ws_token = verify_data["ws_token"]
|
||||||
self.fernet = Fernet(session_key)
|
|
||||||
|
|
||||||
self.success(f"SRP authenticated (session: {self.user_id[:8]}...)")
|
self.success(f"SRP authenticated (session: {self.user_id[:8]}...)")
|
||||||
|
|
||||||
|
|
@ -118,29 +145,36 @@ class Client:
|
||||||
msg["text"] = "[decrypt failed]"
|
msg["text"] = "[decrypt failed]"
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _safe_username(username: str) -> str:
|
||||||
|
return username.replace("[", "\\[")
|
||||||
|
|
||||||
def render_messages(self) -> None:
|
def render_messages(self) -> None:
|
||||||
self.console.clear()
|
self.console.clear()
|
||||||
|
|
||||||
users_online = ", ".join(u.get("username", "?") for u in self.users) or "none"
|
users_online = (
|
||||||
|
", ".join(self._safe_username(u.get("username", "?")) for u in self.users)
|
||||||
|
or "none"
|
||||||
|
)
|
||||||
self.console.print(f"[dim]Online: {users_online}[/]")
|
self.console.print(f"[dim]Online: {users_online}[/]")
|
||||||
self.console.print("─" * 60)
|
self.console.print("-" * 60)
|
||||||
|
|
||||||
display_messages = (
|
display_messages = (
|
||||||
self.messages[-15:] if len(self.messages) > 15 else self.messages
|
self.messages[-15:] if len(self.messages) > 15 else self.messages
|
||||||
)
|
)
|
||||||
|
|
||||||
for msg in display_messages:
|
for msg in display_messages:
|
||||||
username = msg.get("username", "unknown")
|
username = self._safe_username(msg.get("username", "unknown"))
|
||||||
text = msg.get("text", "")
|
text = msg.get("text", "")
|
||||||
timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ")
|
timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ")
|
||||||
|
|
||||||
style = "green" if username == self.username else "cyan"
|
style = "green" if msg.get("username") == self.username else "cyan"
|
||||||
self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}")
|
self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}")
|
||||||
|
|
||||||
if not display_messages:
|
if not display_messages:
|
||||||
self.console.print("[dim italic]No messages yet...[/]")
|
self.console.print("[dim italic]No messages yet...[/]")
|
||||||
|
|
||||||
self.console.print("─" * 60)
|
self.console.print("-" * 60)
|
||||||
self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]")
|
self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]")
|
||||||
|
|
||||||
async def receive_loop(self, ws) -> None:
|
async def receive_loop(self, ws) -> None:
|
||||||
|
|
@ -196,9 +230,10 @@ class Client:
|
||||||
self.srp_authenticate()
|
self.srp_authenticate()
|
||||||
|
|
||||||
self.info("Connecting to chat...")
|
self.info("Connecting to chat...")
|
||||||
url = f"{self.ws_url}/ws/chat?user_id={self.user_id}"
|
url = f"{self.ws_url}/ws/chat?user_id={self.user_id}&ws_token={self.ws_token}"
|
||||||
|
|
||||||
async with websockets.connect(url) as ws:
|
ws_ssl = self._ws_ssl_context()
|
||||||
|
async with websockets.connect(url, ssl=ws_ssl) as ws:
|
||||||
self.success("Connected to chat server")
|
self.success("Connected to chat server")
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import secrets
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
from sanic import Sanic
|
from sanic import Sanic
|
||||||
from sanic_ext import Extend
|
from sanic_ext import Extend
|
||||||
import os
|
import os
|
||||||
from .managers import ConnectionManager
|
from .managers import ConnectionManager
|
||||||
from .stores import MessageStore, UserSessionStore
|
from .stores import MessageStore, UserSessionStore
|
||||||
from .srp_auth import SRPAuthManager
|
from .srp_auth import SRPAuthManager
|
||||||
|
from .helpers import RateLimiter
|
||||||
|
|
||||||
from .routes import register_routes
|
from .routes import register_routes
|
||||||
|
|
||||||
|
|
@ -20,6 +21,9 @@ def create_app(password: str = "", name: str = "cmd-chat-server") -> Sanic:
|
||||||
app.ctx.connection_manager = ConnectionManager()
|
app.ctx.connection_manager = ConnectionManager()
|
||||||
app.ctx.srp_manager = SRPAuthManager(password)
|
app.ctx.srp_manager = SRPAuthManager(password)
|
||||||
app.ctx.room_salt = os.urandom(16)
|
app.ctx.room_salt = os.urandom(16)
|
||||||
|
app.ctx.ws_secret = os.urandom(32)
|
||||||
|
app.ctx.admin_token = secrets.token_hex(16)
|
||||||
|
app.ctx.rate_limiter = RateLimiter(max_requests=10, window_seconds=60)
|
||||||
app.ctx.cleanup_task = None
|
app.ctx.cleanup_task = None
|
||||||
|
|
||||||
register_lifecycle(app)
|
register_lifecycle(app)
|
||||||
|
|
|
||||||
|
|
@ -1,36 +1,21 @@
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
import json
|
import json
|
||||||
from sanic import Sanic, Request, response, Websocket
|
from sanic import Request, Sanic, Websocket
|
||||||
|
|
||||||
|
|
||||||
def utcnow() -> datetime:
|
def utcnow() -> datetime:
|
||||||
return datetime.now(timezone.utc)
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
def verify_password(password: Optional[str], expected: Optional[str]) -> bool:
|
|
||||||
if not expected:
|
|
||||||
return True
|
|
||||||
return password == expected
|
|
||||||
|
|
||||||
|
|
||||||
def get_client_ip(request: Request) -> str:
|
def get_client_ip(request: Request) -> str:
|
||||||
if forwarded := request.headers.get("x-forwarded-for"):
|
if forwarded := request.headers.get("x-forwarded-for"):
|
||||||
return forwarded.split(",")[0].strip()
|
return forwarded.split(",")[0].strip()
|
||||||
return request.ip
|
return request.ip
|
||||||
|
|
||||||
|
|
||||||
def get_param(request: Request, name: str) -> Optional[str]:
|
|
||||||
return request.args.get(name) or request.form.get(name)
|
|
||||||
|
|
||||||
|
|
||||||
def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]:
|
|
||||||
if not verify_password(get_param(request, "password"), app.ctx.admin_password):
|
|
||||||
return response.text("Unauthorized", status=401)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def send_state(ws: Websocket, app: Sanic) -> None:
|
async def send_state(ws: Websocket, app: Sanic) -> None:
|
||||||
messages = app.ctx.message_store.get_all()
|
messages = app.ctx.message_store.get_all()
|
||||||
users = app.ctx.session_store.get_all()
|
users = app.ctx.session_store.get_all()
|
||||||
|
|
@ -47,12 +32,17 @@ async def send_state(ws: Websocket, app: Sanic) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_pubkey(request: Request) -> Optional[bytes]:
|
class RateLimiter:
|
||||||
if files := request.files.get("pubkey"):
|
def __init__(self, max_requests: int = 10, window_seconds: int = 60):
|
||||||
file = files[0] if isinstance(files, list) else files
|
self.max_requests = max_requests
|
||||||
return file.body
|
self.window = window_seconds
|
||||||
if raw := request.form.get("pubkey"):
|
self._requests: dict[str, list[float]] = defaultdict(list)
|
||||||
return raw.encode() if isinstance(raw, str) else raw
|
|
||||||
if raw := request.args.get("pubkey"):
|
def is_allowed(self, key: str) -> bool:
|
||||||
return raw.encode() if isinstance(raw, str) else raw
|
now = time.monotonic()
|
||||||
return None
|
timestamps = self._requests[key]
|
||||||
|
timestamps[:] = [t for t in timestamps if now - t < self.window]
|
||||||
|
if len(timestamps) >= self.max_requests:
|
||||||
|
return False
|
||||||
|
timestamps.append(now)
|
||||||
|
return True
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ class Message:
|
||||||
timestamp: str = field(
|
timestamp: str = field(
|
||||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||||
)
|
)
|
||||||
user_ip: str = ""
|
|
||||||
username: str = ""
|
username: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,97 @@
|
||||||
|
import ipaddress
|
||||||
|
import ssl
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.x509.oid import NameOID
|
||||||
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ec
|
||||||
|
|
||||||
from .factory import create_app
|
from .factory import create_app
|
||||||
|
|
||||||
|
DEFAULT_CERT_DIR = Path.home() / ".cmd-chat" / "certs"
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_tls_certs(cert_dir: Path = DEFAULT_CERT_DIR) -> tuple[Path, Path]:
|
||||||
|
cert_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
cert_path = cert_dir / "server.pem"
|
||||||
|
key_path = cert_dir / "server-key.pem"
|
||||||
|
|
||||||
|
if cert_path.exists() and key_path.exists():
|
||||||
|
return cert_path, key_path
|
||||||
|
|
||||||
|
key = ec.generate_private_key(ec.SECP256R1())
|
||||||
|
|
||||||
|
subject = issuer = x509.Name([
|
||||||
|
x509.NameAttribute(NameOID.COMMON_NAME, "cmd-chat"),
|
||||||
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "cmd-chat-self-signed"),
|
||||||
|
])
|
||||||
|
|
||||||
|
cert = (
|
||||||
|
x509.CertificateBuilder()
|
||||||
|
.subject_name(subject)
|
||||||
|
.issuer_name(issuer)
|
||||||
|
.public_key(key.public_key())
|
||||||
|
.serial_number(x509.random_serial_number())
|
||||||
|
.not_valid_before(datetime.now(timezone.utc))
|
||||||
|
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
|
||||||
|
.add_extension(
|
||||||
|
x509.SubjectAlternativeName([
|
||||||
|
x509.DNSName("localhost"),
|
||||||
|
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
|
||||||
|
]),
|
||||||
|
critical=False,
|
||||||
|
)
|
||||||
|
.sign(key, hashes.SHA256())
|
||||||
|
)
|
||||||
|
|
||||||
|
key_path.write_bytes(
|
||||||
|
key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
|
||||||
|
|
||||||
|
return cert_path, key_path
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
def run_server(
|
||||||
host: str = "0.0.0.0",
|
host: str = "0.0.0.0",
|
||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
password: Optional[str] = None,
|
password: Optional[str] = None,
|
||||||
workers: int = 1,
|
cert_path: Optional[str] = None,
|
||||||
|
key_path: Optional[str] = None,
|
||||||
|
no_tls: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
app = create_app(password=password or "")
|
app = create_app(password=password or "")
|
||||||
|
|
||||||
|
ssl_ctx = None
|
||||||
|
if not no_tls:
|
||||||
|
if cert_path and key_path:
|
||||||
|
c, k = Path(cert_path), Path(key_path)
|
||||||
|
else:
|
||||||
|
c, k = ensure_tls_certs()
|
||||||
|
print(f"[TLS] Auto-generated self-signed cert: {c}")
|
||||||
|
|
||||||
|
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
ssl_ctx.load_cert_chain(str(c), str(k))
|
||||||
|
protocol = "https"
|
||||||
|
else:
|
||||||
|
print("[WARNING] TLS disabled — all traffic is unencrypted!")
|
||||||
|
protocol = "http"
|
||||||
|
|
||||||
|
print(f"[ADMIN] Clear token: {app.ctx.admin_token}")
|
||||||
|
print(f"[SERVER] Listening on {protocol}://{host}:{port}")
|
||||||
|
|
||||||
app.run(
|
app.run(
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
single_process=True,
|
single_process=True,
|
||||||
debug=False,
|
debug=False,
|
||||||
access_log=True,
|
access_log=False,
|
||||||
|
ssl=ssl_ctx,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,14 @@ from .models import Message, UserSession
|
||||||
|
|
||||||
|
|
||||||
class MessageStore:
|
class MessageStore:
|
||||||
def __init__(self):
|
def __init__(self, max_messages: int = 1000):
|
||||||
self._messages: list[Message] = []
|
self._messages: list[Message] = []
|
||||||
|
self._max = max_messages
|
||||||
|
|
||||||
def add(self, message: Message) -> None:
|
def add(self, message: Message) -> None:
|
||||||
self._messages.append(message)
|
self._messages.append(message)
|
||||||
|
if len(self._messages) > self._max:
|
||||||
|
self._messages = self._messages[-self._max:]
|
||||||
|
|
||||||
def get_all(self) -> list[Message]:
|
def get_all(self) -> list[Message]:
|
||||||
return self._messages.copy()
|
return self._messages.copy()
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,26 @@
|
||||||
from dataclasses import asdict
|
import hashlib
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
from sanic import Sanic, Request, response, Websocket
|
from sanic import Sanic, Request, response, Websocket
|
||||||
from sanic.response import HTTPResponse, json as json_response
|
from sanic.response import HTTPResponse, json as json_response
|
||||||
|
|
||||||
from .models import Message, UserSession
|
from .models import Message, UserSession
|
||||||
from .helpers import (
|
from .helpers import get_client_ip, send_state, utcnow
|
||||||
get_client_ip,
|
|
||||||
send_state,
|
|
||||||
utcnow,
|
def generate_ws_token(user_id: str, secret: bytes) -> str:
|
||||||
)
|
return hmac.new(secret, user_id.encode(), hashlib.sha256).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
|
async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
try:
|
try:
|
||||||
|
client_ip = get_client_ip(request)
|
||||||
|
if not app.ctx.rate_limiter.is_allowed(client_ip):
|
||||||
|
return response.json({"error": "Rate limited"}, status=429)
|
||||||
|
|
||||||
data = request.json or {}
|
data = request.json or {}
|
||||||
username = data.get("username", "unknown")
|
username = data.get("username", "unknown")
|
||||||
client_public_b64 = data.get("A")
|
client_public_b64 = data.get("A")
|
||||||
|
|
@ -45,6 +50,10 @@ async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
|
|
||||||
async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
|
async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
try:
|
try:
|
||||||
|
client_ip = get_client_ip(request)
|
||||||
|
if not app.ctx.rate_limiter.is_allowed(client_ip):
|
||||||
|
return response.json({"error": "Rate limited"}, status=429)
|
||||||
|
|
||||||
data = request.json or {}
|
data = request.json or {}
|
||||||
user_id = data.get("user_id")
|
user_id = data.get("user_id")
|
||||||
client_proof_b64 = data.get("M")
|
client_proof_b64 = data.get("M")
|
||||||
|
|
@ -67,10 +76,12 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
)
|
)
|
||||||
app.ctx.session_store.add(session)
|
app.ctx.session_store.add(session)
|
||||||
|
|
||||||
|
ws_token = generate_ws_token(user_id, app.ctx.ws_secret)
|
||||||
|
|
||||||
return response.json(
|
return response.json(
|
||||||
{
|
{
|
||||||
"H_AMK": base64.b64encode(H_AMK).decode(),
|
"H_AMK": base64.b64encode(H_AMK).decode(),
|
||||||
"session_key": base64.b64encode(fernet_key).decode(),
|
"ws_token": ws_token,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -82,9 +93,15 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
|
|
||||||
async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
|
async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
|
||||||
user_id = request.args.get("user_id")
|
user_id = request.args.get("user_id")
|
||||||
|
ws_token = request.args.get("ws_token")
|
||||||
|
|
||||||
if not user_id:
|
if not user_id or not ws_token:
|
||||||
await ws.close(code=4002, reason="user_id required")
|
await ws.close(code=4002, reason="user_id and ws_token required")
|
||||||
|
return
|
||||||
|
|
||||||
|
expected_token = generate_ws_token(user_id, app.ctx.ws_secret)
|
||||||
|
if not hmac.compare_digest(ws_token, expected_token):
|
||||||
|
await ws.close(code=4003, reason="Invalid token")
|
||||||
return
|
return
|
||||||
|
|
||||||
session = app.ctx.session_store.get(user_id)
|
session = app.ctx.session_store.get(user_id)
|
||||||
|
|
@ -106,7 +123,6 @@ async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
|
||||||
|
|
||||||
message = Message(
|
message = Message(
|
||||||
text=str(data),
|
text=str(data),
|
||||||
user_ip=session.ip,
|
|
||||||
username=session.username,
|
username=session.username,
|
||||||
)
|
)
|
||||||
app.ctx.message_store.add(message)
|
app.ctx.message_store.add(message)
|
||||||
|
|
@ -146,8 +162,12 @@ async def health(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
|
|
||||||
|
|
||||||
async def clear_messages(request: Request, app: Sanic) -> HTTPResponse:
|
async def clear_messages(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
user_id = request.args.get("user_id")
|
auth_header = request.headers.get("authorization", "")
|
||||||
if not user_id or not app.ctx.session_store.get(user_id):
|
if not auth_header.startswith("Bearer "):
|
||||||
|
return response.json({"error": "Unauthorized"}, status=401)
|
||||||
|
|
||||||
|
token = auth_header[7:]
|
||||||
|
if not hmac.compare_digest(token, app.ctx.admin_token):
|
||||||
return response.json({"error": "Unauthorized"}, status=401)
|
return response.json({"error": "Unauthorized"}, status=401)
|
||||||
|
|
||||||
app.ctx.message_store.clear()
|
app.ctx.message_store.clear()
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,10 @@ def app():
|
||||||
app.ctx.connection_manager = ConnectionManager()
|
app.ctx.connection_manager = ConnectionManager()
|
||||||
app.ctx.srp_manager = SRPAuthManager("testpassword")
|
app.ctx.srp_manager = SRPAuthManager("testpassword")
|
||||||
app.ctx.room_salt = os.urandom(16)
|
app.ctx.room_salt = os.urandom(16)
|
||||||
|
app.ctx.ws_secret = os.urandom(32)
|
||||||
|
app.ctx.admin_token = "test-admin-token"
|
||||||
|
from cmd_chat.server.helpers import RateLimiter
|
||||||
|
app.ctx.rate_limiter = RateLimiter(max_requests=100, window_seconds=60)
|
||||||
app.ctx.cleanup_task = None
|
app.ctx.cleanup_task = None
|
||||||
|
|
||||||
register_routes(app)
|
register_routes(app)
|
||||||
|
|
|
||||||
|
|
@ -52,14 +52,19 @@ class TestClientInit:
|
||||||
assert client.username == "testuser"
|
assert client.username == "testuser"
|
||||||
assert client.password == b"testpassword"
|
assert client.password == b"testpassword"
|
||||||
assert client.user_id is None
|
assert client.user_id is None
|
||||||
assert client.fernet is None
|
assert client.ws_token is None
|
||||||
assert client.room_fernet is None
|
assert client.room_fernet is None
|
||||||
assert client.connected is False
|
assert client.connected is False
|
||||||
assert client.running is False
|
assert client.running is False
|
||||||
|
|
||||||
def test_client_urls(self, client):
|
def test_client_urls(self, client):
|
||||||
assert client.base_url == "http://127.0.0.1:3000"
|
assert client.base_url == "https://127.0.0.1:3000"
|
||||||
assert client.ws_url == "ws://127.0.0.1:3000"
|
assert client.ws_url == "wss://127.0.0.1:3000"
|
||||||
|
|
||||||
|
def test_client_no_tls_urls(self):
|
||||||
|
c = Client("127.0.0.1", 3000, "user", "pass", no_tls=True)
|
||||||
|
assert c.base_url == "http://127.0.0.1:3000"
|
||||||
|
assert c.ws_url == "ws://127.0.0.1:3000"
|
||||||
|
|
||||||
def test_client_empty_password(self):
|
def test_client_empty_password(self):
|
||||||
client = Client("localhost", 8080, "user", None)
|
client = Client("localhost", 8080, "user", None)
|
||||||
|
|
|
||||||
|
|
@ -48,15 +48,20 @@ def room_fernet(room_salt):
|
||||||
class TestClientProperties:
|
class TestClientProperties:
|
||||||
def test_base_url_different_ports(self):
|
def test_base_url_different_ports(self):
|
||||||
client = Client("example.com", 8080, "user", "pass")
|
client = Client("example.com", 8080, "user", "pass")
|
||||||
assert client.base_url == "http://example.com:8080"
|
assert client.base_url == "https://example.com:8080"
|
||||||
|
|
||||||
def test_ws_url_different_ports(self):
|
def test_ws_url_different_ports(self):
|
||||||
client = Client("example.com", 8080, "user", "pass")
|
client = Client("example.com", 8080, "user", "pass")
|
||||||
assert client.ws_url == "ws://example.com:8080"
|
assert client.ws_url == "wss://example.com:8080"
|
||||||
|
|
||||||
def test_base_url_localhost(self):
|
def test_base_url_localhost(self):
|
||||||
client = Client("localhost", 443, "user", "pass")
|
client = Client("localhost", 443, "user", "pass")
|
||||||
assert client.base_url == "http://localhost:443"
|
assert client.base_url == "https://localhost:443"
|
||||||
|
|
||||||
|
def test_no_tls_urls(self):
|
||||||
|
client = Client("example.com", 8080, "user", "pass", no_tls=True)
|
||||||
|
assert client.base_url == "http://example.com:8080"
|
||||||
|
assert client.ws_url == "ws://example.com:8080"
|
||||||
|
|
||||||
def test_password_encoding_unicode(self):
|
def test_password_encoding_unicode(self):
|
||||||
client = Client("localhost", 3000, "user", "пароль123")
|
client = Client("localhost", 3000, "user", "пароль123")
|
||||||
|
|
@ -84,7 +89,7 @@ class TestSRPAuthentication:
|
||||||
verify_response = MagicMock()
|
verify_response = MagicMock()
|
||||||
verify_response.json.return_value = {
|
verify_response.json.return_value = {
|
||||||
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
|
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
|
||||||
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
|
"ws_token": "test-ws-token-hex",
|
||||||
}
|
}
|
||||||
verify_response.raise_for_status = MagicMock()
|
verify_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
|
@ -102,7 +107,7 @@ class TestSRPAuthentication:
|
||||||
|
|
||||||
assert client.user_id == "test-user-id-12345"
|
assert client.user_id == "test-user-id-12345"
|
||||||
assert client.room_fernet is not None
|
assert client.room_fernet is not None
|
||||||
assert client.fernet is not None
|
assert client.ws_token == "test-ws-token-hex"
|
||||||
|
|
||||||
@patch("cmd_chat.client.client.requests.post")
|
@patch("cmd_chat.client.client.requests.post")
|
||||||
def test_srp_authenticate_init_fails(self, mock_post, client):
|
def test_srp_authenticate_init_fails(self, mock_post, client):
|
||||||
|
|
@ -178,7 +183,7 @@ class TestSRPAuthentication:
|
||||||
verify_response = MagicMock()
|
verify_response = MagicMock()
|
||||||
verify_response.json.return_value = {
|
verify_response.json.return_value = {
|
||||||
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
|
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
|
||||||
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
|
"ws_token": "test-ws-token-hex",
|
||||||
}
|
}
|
||||||
verify_response.raise_for_status = MagicMock()
|
verify_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
|
@ -237,7 +242,6 @@ class TestDecryptMessage:
|
||||||
"username": "sender",
|
"username": "sender",
|
||||||
"timestamp": "2024-01-01T12:00:00",
|
"timestamp": "2024-01-01T12:00:00",
|
||||||
"id": "msg-123",
|
"id": "msg-123",
|
||||||
"user_ip": "192.168.1.1",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
decrypted = client.decrypt_message(msg)
|
decrypted = client.decrypt_message(msg)
|
||||||
|
|
@ -246,7 +250,6 @@ class TestDecryptMessage:
|
||||||
assert decrypted["username"] == "sender"
|
assert decrypted["username"] == "sender"
|
||||||
assert decrypted["timestamp"] == "2024-01-01T12:00:00"
|
assert decrypted["timestamp"] == "2024-01-01T12:00:00"
|
||||||
assert decrypted["id"] == "msg-123"
|
assert decrypted["id"] == "msg-123"
|
||||||
assert decrypted["user_ip"] == "192.168.1.1"
|
|
||||||
|
|
||||||
def test_decrypt_wrong_key_marks_failed(self, client):
|
def test_decrypt_wrong_key_marks_failed(self, client):
|
||||||
|
|
||||||
|
|
@ -553,6 +556,7 @@ class TestRunAsync:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_successful_connection_and_disconnect(self, client):
|
async def test_run_successful_connection_and_disconnect(self, client):
|
||||||
client.user_id = "test-id-123"
|
client.user_id = "test-id-123"
|
||||||
|
client.ws_token = "test-token"
|
||||||
|
|
||||||
with patch.object(client, "srp_authenticate"):
|
with patch.object(client, "srp_authenticate"):
|
||||||
with patch("cmd_chat.client.client.websockets.connect") as mock_connect:
|
with patch("cmd_chat.client.client.websockets.connect") as mock_connect:
|
||||||
|
|
@ -844,8 +848,8 @@ class TestEdgeCases:
|
||||||
def test_port_zero(self):
|
def test_port_zero(self):
|
||||||
client = Client("localhost", 0, "user", "pass")
|
client = Client("localhost", 0, "user", "pass")
|
||||||
assert client.port == 0
|
assert client.port == 0
|
||||||
assert client.base_url == "http://localhost:0"
|
assert client.base_url == "https://localhost:0"
|
||||||
|
|
||||||
def test_ipv6_server(self):
|
def test_ipv6_server(self):
|
||||||
client = Client("::1", 3000, "user", "pass")
|
client = Client("::1", 3000, "user", "pass")
|
||||||
assert client.base_url == "http://::1:3000"
|
assert client.base_url == "https://::1:3000"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user