fix(security): comprehensive security hardening — TLS, HMAC WS auth, rate limiting, IP leak prevention

CRITICAL fixes:
- Auto-generated self-signed TLS certs (HTTPS/WSS by default)
- Removed session_key from /srp/verify response (was sent in plaintext)
- Replaced with HMAC-SHA256 ws_token for WebSocket authentication

HIGH fixes:
- WebSocket auth now validates ws_token via hmac.compare_digest()
- /clear endpoint requires Bearer admin_token (printed at server start)
- Password no longer required as CLI arg — supports env var + getpass prompt
- Removed user_ip from Message model (no longer broadcast to clients)

MEDIUM fixes:
- Rate limiter on /srp/init and /srp/verify (10 req/min/IP)
- MessageStore capped at 1000 messages (prevents RAM DoS)
- access_log disabled (was leaking request metadata)

LOW fixes:
- Username sanitization against rich markup injection
- Dead code removed from helpers.py

All 79 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
leetcrypt 2026-05-25 20:30:40 -07:00
parent 440b67da26
commit e7bacc93da
11 changed files with 255 additions and 80 deletions

View File

@ -1,8 +1,19 @@
import argparse
import getpass
import os
from cmd_chat.server.server import run_server
from cmd_chat.client.client import Client
def resolve_password(args_password: str | None, prompt: str = "Room password: ") -> str:
if args_password:
return args_password
if env_pw := os.environ.get("CMD_CHAT_PASSWORD"):
return env_pw
return getpass.getpass(prompt)
def main():
parser = argparse.ArgumentParser(description="Command-line chat application")
subparsers = parser.add_subparsers(dest="command", required=True)
@ -10,24 +21,46 @@ def main():
serve_p = subparsers.add_parser("serve", help="Run server")
serve_p.add_argument("ip_address")
serve_p.add_argument("port")
serve_p.add_argument("--password", "-p", required=True)
serve_p.add_argument("--password", "-p", default=None)
serve_p.add_argument("--cert", default=None, help="Path to TLS certificate")
serve_p.add_argument("--key", default=None, help="Path to TLS private key")
serve_p.add_argument("--no-tls", action="store_true", help="Disable TLS (insecure)")
connect_p = subparsers.add_parser("connect", help="Connect to server")
connect_p.add_argument("ip_address")
connect_p.add_argument("port")
connect_p.add_argument("username")
connect_p.add_argument("password")
connect_p.add_argument("--password", "-p", default=None)
connect_p.add_argument(
"--insecure", "-k", action="store_true",
help="Skip TLS certificate verification (for self-signed certs)",
)
connect_p.add_argument(
"--no-tls", action="store_true",
help="Connect without TLS (insecure)",
)
args = parser.parse_args()
if args.command == "serve":
run_server(host=args.ip_address, port=int(args.port), password=args.password)
password = resolve_password(args.password)
run_server(
host=args.ip_address,
port=int(args.port),
password=password,
cert_path=args.cert,
key_path=args.key,
no_tls=args.no_tls,
)
elif args.command == "connect":
password = resolve_password(args.password)
Client(
server=args.ip_address,
port=int(args.port),
username=args.username,
password=args.password,
password=password,
insecure=args.insecure,
no_tls=args.no_tls,
).run()

View File

@ -1,5 +1,6 @@
import asyncio
import json
import ssl
import base64
from typing import Optional
@ -17,14 +18,22 @@ srp.rfc5054_enable()
class Client:
def __init__(
self, server: str, port: int, username: str, password: Optional[str] = None
self,
server: str,
port: int,
username: str,
password: Optional[str] = None,
insecure: bool = False,
no_tls: bool = False,
):
self.server = server
self.port = port
self.username = username
self.password = (password or "").encode()
self.insecure = insecure
self.no_tls = no_tls
self.user_id: Optional[str] = None
self.fernet: Optional[Fernet] = None
self.ws_token: Optional[str] = None
self.room_fernet: Optional[Fernet] = None
self.console = Console()
@ -35,23 +44,42 @@ class Client:
@property
def base_url(self) -> str:
return f"http://{self.server}:{self.port}"
scheme = "http" if self.no_tls else "https"
return f"{scheme}://{self.server}:{self.port}"
@property
def ws_url(self) -> str:
return f"ws://{self.server}:{self.port}"
scheme = "ws" if self.no_tls else "wss"
return f"{scheme}://{self.server}:{self.port}"
def _request_kwargs(self) -> dict:
kwargs: dict = {"timeout": 30}
if self.insecure and not self.no_tls:
kwargs["verify"] = False
return kwargs
def _ws_ssl_context(self) -> ssl.SSLContext | None:
if self.no_tls:
return None
if self.insecure:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx
return True # default verification
def success(self, message: str) -> None:
self.console.print(f"[green]✓ {message}[/]")
self.console.print(f"[green]{message}[/]")
def error(self, message: str) -> None:
self.console.print(f"[red]✗ {message}[/]")
self.console.print(f"[red]{message}[/]")
def info(self, message: str) -> None:
self.console.print(f"[cyan]• {message}[/]")
self.console.print(f"[cyan]{message}[/]")
def srp_authenticate(self) -> None:
with self.console.status("[cyan]Starting SRP handshake...[/]", spinner="dots"):
req_kwargs = self._request_kwargs()
usr = srp.User(b"chat", self.password, hash_alg=srp.SHA256)
_, A = usr.start_authentication()
@ -62,7 +90,7 @@ class Client:
"username": self.username,
"A": base64.b64encode(A).decode(),
},
timeout=30,
**req_kwargs,
)
resp.raise_for_status()
init_data = resp.json()
@ -93,7 +121,7 @@ class Client:
"username": self.username,
"M": base64.b64encode(M).decode(),
},
timeout=30,
**req_kwargs,
)
resp.raise_for_status()
verify_data = resp.json()
@ -104,8 +132,7 @@ class Client:
if not usr.authenticated():
raise ValueError("Server authentication failed")
session_key = base64.b64decode(verify_data["session_key"])
self.fernet = Fernet(session_key)
self.ws_token = verify_data["ws_token"]
self.success(f"SRP authenticated (session: {self.user_id[:8]}...)")
@ -118,29 +145,36 @@ class Client:
msg["text"] = "[decrypt failed]"
return msg
@staticmethod
def _safe_username(username: str) -> str:
return username.replace("[", "\\[")
def render_messages(self) -> None:
self.console.clear()
users_online = ", ".join(u.get("username", "?") for u in self.users) or "none"
users_online = (
", ".join(self._safe_username(u.get("username", "?")) for u in self.users)
or "none"
)
self.console.print(f"[dim]Online: {users_online}[/]")
self.console.print("" * 60)
self.console.print("-" * 60)
display_messages = (
self.messages[-15:] if len(self.messages) > 15 else self.messages
)
for msg in display_messages:
username = msg.get("username", "unknown")
username = self._safe_username(msg.get("username", "unknown"))
text = msg.get("text", "")
timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ")
style = "green" if username == self.username else "cyan"
style = "green" if msg.get("username") == self.username else "cyan"
self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}")
if not display_messages:
self.console.print("[dim italic]No messages yet...[/]")
self.console.print("" * 60)
self.console.print("-" * 60)
self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]")
async def receive_loop(self, ws) -> None:
@ -196,9 +230,10 @@ class Client:
self.srp_authenticate()
self.info("Connecting to chat...")
url = f"{self.ws_url}/ws/chat?user_id={self.user_id}"
url = f"{self.ws_url}/ws/chat?user_id={self.user_id}&ws_token={self.ws_token}"
async with websockets.connect(url) as ws:
ws_ssl = self._ws_ssl_context()
async with websockets.connect(url, ssl=ws_ssl) as ws:
self.success("Connected to chat server")
self.running = True

View File

@ -1,12 +1,13 @@
import asyncio
import secrets
from contextlib import suppress
from cryptography.fernet import Fernet
from sanic import Sanic
from sanic_ext import Extend
import os
from .managers import ConnectionManager
from .stores import MessageStore, UserSessionStore
from .srp_auth import SRPAuthManager
from .helpers import RateLimiter
from .routes import register_routes
@ -20,6 +21,9 @@ def create_app(password: str = "", name: str = "cmd-chat-server") -> Sanic:
app.ctx.connection_manager = ConnectionManager()
app.ctx.srp_manager = SRPAuthManager(password)
app.ctx.room_salt = os.urandom(16)
app.ctx.ws_secret = os.urandom(32)
app.ctx.admin_token = secrets.token_hex(16)
app.ctx.rate_limiter = RateLimiter(max_requests=10, window_seconds=60)
app.ctx.cleanup_task = None
register_lifecycle(app)

View File

@ -1,36 +1,21 @@
import time
from collections import defaultdict
from datetime import datetime, timezone
from typing import Optional
from dataclasses import asdict
import json
from sanic import Sanic, Request, response, Websocket
from sanic import Request, Sanic, Websocket
def utcnow() -> datetime:
return datetime.now(timezone.utc)
def verify_password(password: Optional[str], expected: Optional[str]) -> bool:
if not expected:
return True
return password == expected
def get_client_ip(request: Request) -> str:
if forwarded := request.headers.get("x-forwarded-for"):
return forwarded.split(",")[0].strip()
return request.ip
def get_param(request: Request, name: str) -> Optional[str]:
return request.args.get(name) or request.form.get(name)
def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]:
if not verify_password(get_param(request, "password"), app.ctx.admin_password):
return response.text("Unauthorized", status=401)
return None
async def send_state(ws: Websocket, app: Sanic) -> None:
messages = app.ctx.message_store.get_all()
users = app.ctx.session_store.get_all()
@ -47,12 +32,17 @@ async def send_state(ws: Websocket, app: Sanic) -> None:
)
def extract_pubkey(request: Request) -> Optional[bytes]:
if files := request.files.get("pubkey"):
file = files[0] if isinstance(files, list) else files
return file.body
if raw := request.form.get("pubkey"):
return raw.encode() if isinstance(raw, str) else raw
if raw := request.args.get("pubkey"):
return raw.encode() if isinstance(raw, str) else raw
return None
class RateLimiter:
def __init__(self, max_requests: int = 10, window_seconds: int = 60):
self.max_requests = max_requests
self.window = window_seconds
self._requests: dict[str, list[float]] = defaultdict(list)
def is_allowed(self, key: str) -> bool:
now = time.monotonic()
timestamps = self._requests[key]
timestamps[:] = [t for t in timestamps if now - t < self.window]
if len(timestamps) >= self.max_requests:
return False
timestamps.append(now)
return True

View File

@ -11,7 +11,6 @@ class Message:
timestamp: str = field(
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
user_ip: str = ""
username: str = ""

View File

@ -1,19 +1,97 @@
import ipaddress
import ssl
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from .factory import create_app
DEFAULT_CERT_DIR = Path.home() / ".cmd-chat" / "certs"
def ensure_tls_certs(cert_dir: Path = DEFAULT_CERT_DIR) -> tuple[Path, Path]:
cert_dir.mkdir(parents=True, exist_ok=True)
cert_path = cert_dir / "server.pem"
key_path = cert_dir / "server-key.pem"
if cert_path.exists() and key_path.exists():
return cert_path, key_path
key = ec.generate_private_key(ec.SECP256R1())
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, "cmd-chat"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "cmd-chat-self-signed"),
])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(timezone.utc))
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName("localhost"),
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
]),
critical=False,
)
.sign(key, hashes.SHA256())
)
key_path.write_bytes(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
)
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
return cert_path, key_path
def run_server(
host: str = "0.0.0.0",
port: int = 8000,
password: Optional[str] = None,
workers: int = 1,
cert_path: Optional[str] = None,
key_path: Optional[str] = None,
no_tls: bool = False,
) -> None:
app = create_app(password=password or "")
ssl_ctx = None
if not no_tls:
if cert_path and key_path:
c, k = Path(cert_path), Path(key_path)
else:
c, k = ensure_tls_certs()
print(f"[TLS] Auto-generated self-signed cert: {c}")
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_ctx.load_cert_chain(str(c), str(k))
protocol = "https"
else:
print("[WARNING] TLS disabled — all traffic is unencrypted!")
protocol = "http"
print(f"[ADMIN] Clear token: {app.ctx.admin_token}")
print(f"[SERVER] Listening on {protocol}://{host}:{port}")
app.run(
host=host,
port=port,
single_process=True,
debug=False,
access_log=True,
access_log=False,
ssl=ssl_ctx,
)

View File

@ -3,11 +3,14 @@ from .models import Message, UserSession
class MessageStore:
def __init__(self):
def __init__(self, max_messages: int = 1000):
self._messages: list[Message] = []
self._max = max_messages
def add(self, message: Message) -> None:
self._messages.append(message)
if len(self._messages) > self._max:
self._messages = self._messages[-self._max:]
def get_all(self) -> list[Message]:
return self._messages.copy()

View File

@ -1,21 +1,26 @@
from dataclasses import asdict
import hashlib
import hmac
import json
import base64
from dataclasses import asdict
from sanic import Sanic, Request, response, Websocket
from sanic.response import HTTPResponse, json as json_response
from .models import Message, UserSession
from .helpers import (
get_client_ip,
send_state,
utcnow,
)
from .helpers import get_client_ip, send_state, utcnow
def generate_ws_token(user_id: str, secret: bytes) -> str:
return hmac.new(secret, user_id.encode(), hashlib.sha256).hexdigest()
async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
try:
client_ip = get_client_ip(request)
if not app.ctx.rate_limiter.is_allowed(client_ip):
return response.json({"error": "Rate limited"}, status=429)
data = request.json or {}
username = data.get("username", "unknown")
client_public_b64 = data.get("A")
@ -45,6 +50,10 @@ async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
try:
client_ip = get_client_ip(request)
if not app.ctx.rate_limiter.is_allowed(client_ip):
return response.json({"error": "Rate limited"}, status=429)
data = request.json or {}
user_id = data.get("user_id")
client_proof_b64 = data.get("M")
@ -67,10 +76,12 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
)
app.ctx.session_store.add(session)
ws_token = generate_ws_token(user_id, app.ctx.ws_secret)
return response.json(
{
"H_AMK": base64.b64encode(H_AMK).decode(),
"session_key": base64.b64encode(fernet_key).decode(),
"ws_token": ws_token,
}
)
@ -82,9 +93,15 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
user_id = request.args.get("user_id")
ws_token = request.args.get("ws_token")
if not user_id:
await ws.close(code=4002, reason="user_id required")
if not user_id or not ws_token:
await ws.close(code=4002, reason="user_id and ws_token required")
return
expected_token = generate_ws_token(user_id, app.ctx.ws_secret)
if not hmac.compare_digest(ws_token, expected_token):
await ws.close(code=4003, reason="Invalid token")
return
session = app.ctx.session_store.get(user_id)
@ -106,7 +123,6 @@ async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
message = Message(
text=str(data),
user_ip=session.ip,
username=session.username,
)
app.ctx.message_store.add(message)
@ -146,8 +162,12 @@ async def health(request: Request, app: Sanic) -> HTTPResponse:
async def clear_messages(request: Request, app: Sanic) -> HTTPResponse:
user_id = request.args.get("user_id")
if not user_id or not app.ctx.session_store.get(user_id):
auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "):
return response.json({"error": "Unauthorized"}, status=401)
token = auth_header[7:]
if not hmac.compare_digest(token, app.ctx.admin_token):
return response.json({"error": "Unauthorized"}, status=401)
app.ctx.message_store.clear()

View File

@ -28,6 +28,10 @@ def app():
app.ctx.connection_manager = ConnectionManager()
app.ctx.srp_manager = SRPAuthManager("testpassword")
app.ctx.room_salt = os.urandom(16)
app.ctx.ws_secret = os.urandom(32)
app.ctx.admin_token = "test-admin-token"
from cmd_chat.server.helpers import RateLimiter
app.ctx.rate_limiter = RateLimiter(max_requests=100, window_seconds=60)
app.ctx.cleanup_task = None
register_routes(app)

View File

@ -52,14 +52,19 @@ class TestClientInit:
assert client.username == "testuser"
assert client.password == b"testpassword"
assert client.user_id is None
assert client.fernet is None
assert client.ws_token is None
assert client.room_fernet is None
assert client.connected is False
assert client.running is False
def test_client_urls(self, client):
assert client.base_url == "http://127.0.0.1:3000"
assert client.ws_url == "ws://127.0.0.1:3000"
assert client.base_url == "https://127.0.0.1:3000"
assert client.ws_url == "wss://127.0.0.1:3000"
def test_client_no_tls_urls(self):
c = Client("127.0.0.1", 3000, "user", "pass", no_tls=True)
assert c.base_url == "http://127.0.0.1:3000"
assert c.ws_url == "ws://127.0.0.1:3000"
def test_client_empty_password(self):
client = Client("localhost", 8080, "user", None)

View File

@ -48,15 +48,20 @@ def room_fernet(room_salt):
class TestClientProperties:
def test_base_url_different_ports(self):
client = Client("example.com", 8080, "user", "pass")
assert client.base_url == "http://example.com:8080"
assert client.base_url == "https://example.com:8080"
def test_ws_url_different_ports(self):
client = Client("example.com", 8080, "user", "pass")
assert client.ws_url == "ws://example.com:8080"
assert client.ws_url == "wss://example.com:8080"
def test_base_url_localhost(self):
client = Client("localhost", 443, "user", "pass")
assert client.base_url == "http://localhost:443"
assert client.base_url == "https://localhost:443"
def test_no_tls_urls(self):
client = Client("example.com", 8080, "user", "pass", no_tls=True)
assert client.base_url == "http://example.com:8080"
assert client.ws_url == "ws://example.com:8080"
def test_password_encoding_unicode(self):
client = Client("localhost", 3000, "user", "пароль123")
@ -84,7 +89,7 @@ class TestSRPAuthentication:
verify_response = MagicMock()
verify_response.json.return_value = {
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
"ws_token": "test-ws-token-hex",
}
verify_response.raise_for_status = MagicMock()
@ -102,7 +107,7 @@ class TestSRPAuthentication:
assert client.user_id == "test-user-id-12345"
assert client.room_fernet is not None
assert client.fernet is not None
assert client.ws_token == "test-ws-token-hex"
@patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_init_fails(self, mock_post, client):
@ -178,7 +183,7 @@ class TestSRPAuthentication:
verify_response = MagicMock()
verify_response.json.return_value = {
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
"ws_token": "test-ws-token-hex",
}
verify_response.raise_for_status = MagicMock()
@ -237,7 +242,6 @@ class TestDecryptMessage:
"username": "sender",
"timestamp": "2024-01-01T12:00:00",
"id": "msg-123",
"user_ip": "192.168.1.1",
}
decrypted = client.decrypt_message(msg)
@ -246,7 +250,6 @@ class TestDecryptMessage:
assert decrypted["username"] == "sender"
assert decrypted["timestamp"] == "2024-01-01T12:00:00"
assert decrypted["id"] == "msg-123"
assert decrypted["user_ip"] == "192.168.1.1"
def test_decrypt_wrong_key_marks_failed(self, client):
@ -553,6 +556,7 @@ class TestRunAsync:
@pytest.mark.asyncio
async def test_run_successful_connection_and_disconnect(self, client):
client.user_id = "test-id-123"
client.ws_token = "test-token"
with patch.object(client, "srp_authenticate"):
with patch("cmd_chat.client.client.websockets.connect") as mock_connect:
@ -844,8 +848,8 @@ class TestEdgeCases:
def test_port_zero(self):
client = Client("localhost", 0, "user", "pass")
assert client.port == 0
assert client.base_url == "http://localhost:0"
assert client.base_url == "https://localhost:0"
def test_ipv6_server(self):
client = Client("::1", 3000, "user", "pass")
assert client.base_url == "http://::1:3000"
assert client.base_url == "https://::1:3000"