fix(security): comprehensive security hardening — TLS, HMAC WS auth, rate limiting, IP leak prevention
CRITICAL fixes: - Auto-generated self-signed TLS certs (HTTPS/WSS by default) - Removed session_key from /srp/verify response (was sent in plaintext) - Replaced with HMAC-SHA256 ws_token for WebSocket authentication HIGH fixes: - WebSocket auth now validates ws_token via hmac.compare_digest() - /clear endpoint requires Bearer admin_token (printed at server start) - Password no longer required as CLI arg — supports env var + getpass prompt - Removed user_ip from Message model (no longer broadcast to clients) MEDIUM fixes: - Rate limiter on /srp/init and /srp/verify (10 req/min/IP) - MessageStore capped at 1000 messages (prevents RAM DoS) - access_log disabled (was leaking request metadata) LOW fixes: - Username sanitization against rich markup injection - Dead code removed from helpers.py All 79 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
440b67da26
commit
e7bacc93da
|
|
@ -1,8 +1,19 @@
|
|||
import argparse
|
||||
import getpass
|
||||
import os
|
||||
|
||||
from cmd_chat.server.server import run_server
|
||||
from cmd_chat.client.client import Client
|
||||
|
||||
|
||||
def resolve_password(args_password: str | None, prompt: str = "Room password: ") -> str:
|
||||
if args_password:
|
||||
return args_password
|
||||
if env_pw := os.environ.get("CMD_CHAT_PASSWORD"):
|
||||
return env_pw
|
||||
return getpass.getpass(prompt)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Command-line chat application")
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
|
@ -10,24 +21,46 @@ def main():
|
|||
serve_p = subparsers.add_parser("serve", help="Run server")
|
||||
serve_p.add_argument("ip_address")
|
||||
serve_p.add_argument("port")
|
||||
serve_p.add_argument("--password", "-p", required=True)
|
||||
serve_p.add_argument("--password", "-p", default=None)
|
||||
serve_p.add_argument("--cert", default=None, help="Path to TLS certificate")
|
||||
serve_p.add_argument("--key", default=None, help="Path to TLS private key")
|
||||
serve_p.add_argument("--no-tls", action="store_true", help="Disable TLS (insecure)")
|
||||
|
||||
connect_p = subparsers.add_parser("connect", help="Connect to server")
|
||||
connect_p.add_argument("ip_address")
|
||||
connect_p.add_argument("port")
|
||||
connect_p.add_argument("username")
|
||||
connect_p.add_argument("password")
|
||||
connect_p.add_argument("--password", "-p", default=None)
|
||||
connect_p.add_argument(
|
||||
"--insecure", "-k", action="store_true",
|
||||
help="Skip TLS certificate verification (for self-signed certs)",
|
||||
)
|
||||
connect_p.add_argument(
|
||||
"--no-tls", action="store_true",
|
||||
help="Connect without TLS (insecure)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "serve":
|
||||
run_server(host=args.ip_address, port=int(args.port), password=args.password)
|
||||
password = resolve_password(args.password)
|
||||
run_server(
|
||||
host=args.ip_address,
|
||||
port=int(args.port),
|
||||
password=password,
|
||||
cert_path=args.cert,
|
||||
key_path=args.key,
|
||||
no_tls=args.no_tls,
|
||||
)
|
||||
elif args.command == "connect":
|
||||
password = resolve_password(args.password)
|
||||
Client(
|
||||
server=args.ip_address,
|
||||
port=int(args.port),
|
||||
username=args.username,
|
||||
password=args.password,
|
||||
password=password,
|
||||
insecure=args.insecure,
|
||||
no_tls=args.no_tls,
|
||||
).run()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -17,14 +18,22 @@ srp.rfc5054_enable()
|
|||
|
||||
class Client:
|
||||
def __init__(
|
||||
self, server: str, port: int, username: str, password: Optional[str] = None
|
||||
self,
|
||||
server: str,
|
||||
port: int,
|
||||
username: str,
|
||||
password: Optional[str] = None,
|
||||
insecure: bool = False,
|
||||
no_tls: bool = False,
|
||||
):
|
||||
self.server = server
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = (password or "").encode()
|
||||
self.insecure = insecure
|
||||
self.no_tls = no_tls
|
||||
self.user_id: Optional[str] = None
|
||||
self.fernet: Optional[Fernet] = None
|
||||
self.ws_token: Optional[str] = None
|
||||
self.room_fernet: Optional[Fernet] = None
|
||||
|
||||
self.console = Console()
|
||||
|
|
@ -35,23 +44,42 @@ class Client:
|
|||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
return f"http://{self.server}:{self.port}"
|
||||
scheme = "http" if self.no_tls else "https"
|
||||
return f"{scheme}://{self.server}:{self.port}"
|
||||
|
||||
@property
|
||||
def ws_url(self) -> str:
|
||||
return f"ws://{self.server}:{self.port}"
|
||||
scheme = "ws" if self.no_tls else "wss"
|
||||
return f"{scheme}://{self.server}:{self.port}"
|
||||
|
||||
def _request_kwargs(self) -> dict:
|
||||
kwargs: dict = {"timeout": 30}
|
||||
if self.insecure and not self.no_tls:
|
||||
kwargs["verify"] = False
|
||||
return kwargs
|
||||
|
||||
def _ws_ssl_context(self) -> ssl.SSLContext | None:
|
||||
if self.no_tls:
|
||||
return None
|
||||
if self.insecure:
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
return ctx
|
||||
return True # default verification
|
||||
|
||||
def success(self, message: str) -> None:
|
||||
self.console.print(f"[green]✓ {message}[/]")
|
||||
self.console.print(f"[green]{message}[/]")
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self.console.print(f"[red]✗ {message}[/]")
|
||||
self.console.print(f"[red]{message}[/]")
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self.console.print(f"[cyan]• {message}[/]")
|
||||
self.console.print(f"[cyan]{message}[/]")
|
||||
|
||||
def srp_authenticate(self) -> None:
|
||||
with self.console.status("[cyan]Starting SRP handshake...[/]", spinner="dots"):
|
||||
req_kwargs = self._request_kwargs()
|
||||
|
||||
usr = srp.User(b"chat", self.password, hash_alg=srp.SHA256)
|
||||
_, A = usr.start_authentication()
|
||||
|
|
@ -62,7 +90,7 @@ class Client:
|
|||
"username": self.username,
|
||||
"A": base64.b64encode(A).decode(),
|
||||
},
|
||||
timeout=30,
|
||||
**req_kwargs,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
init_data = resp.json()
|
||||
|
|
@ -93,7 +121,7 @@ class Client:
|
|||
"username": self.username,
|
||||
"M": base64.b64encode(M).decode(),
|
||||
},
|
||||
timeout=30,
|
||||
**req_kwargs,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
verify_data = resp.json()
|
||||
|
|
@ -104,8 +132,7 @@ class Client:
|
|||
if not usr.authenticated():
|
||||
raise ValueError("Server authentication failed")
|
||||
|
||||
session_key = base64.b64decode(verify_data["session_key"])
|
||||
self.fernet = Fernet(session_key)
|
||||
self.ws_token = verify_data["ws_token"]
|
||||
|
||||
self.success(f"SRP authenticated (session: {self.user_id[:8]}...)")
|
||||
|
||||
|
|
@ -118,29 +145,36 @@ class Client:
|
|||
msg["text"] = "[decrypt failed]"
|
||||
return msg
|
||||
|
||||
@staticmethod
|
||||
def _safe_username(username: str) -> str:
|
||||
return username.replace("[", "\\[")
|
||||
|
||||
def render_messages(self) -> None:
|
||||
self.console.clear()
|
||||
|
||||
users_online = ", ".join(u.get("username", "?") for u in self.users) or "none"
|
||||
users_online = (
|
||||
", ".join(self._safe_username(u.get("username", "?")) for u in self.users)
|
||||
or "none"
|
||||
)
|
||||
self.console.print(f"[dim]Online: {users_online}[/]")
|
||||
self.console.print("─" * 60)
|
||||
self.console.print("-" * 60)
|
||||
|
||||
display_messages = (
|
||||
self.messages[-15:] if len(self.messages) > 15 else self.messages
|
||||
)
|
||||
|
||||
for msg in display_messages:
|
||||
username = msg.get("username", "unknown")
|
||||
username = self._safe_username(msg.get("username", "unknown"))
|
||||
text = msg.get("text", "")
|
||||
timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ")
|
||||
|
||||
style = "green" if username == self.username else "cyan"
|
||||
style = "green" if msg.get("username") == self.username else "cyan"
|
||||
self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}")
|
||||
|
||||
if not display_messages:
|
||||
self.console.print("[dim italic]No messages yet...[/]")
|
||||
|
||||
self.console.print("─" * 60)
|
||||
self.console.print("-" * 60)
|
||||
self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]")
|
||||
|
||||
async def receive_loop(self, ws) -> None:
|
||||
|
|
@ -196,9 +230,10 @@ class Client:
|
|||
self.srp_authenticate()
|
||||
|
||||
self.info("Connecting to chat...")
|
||||
url = f"{self.ws_url}/ws/chat?user_id={self.user_id}"
|
||||
url = f"{self.ws_url}/ws/chat?user_id={self.user_id}&ws_token={self.ws_token}"
|
||||
|
||||
async with websockets.connect(url) as ws:
|
||||
ws_ssl = self._ws_ssl_context()
|
||||
async with websockets.connect(url, ssl=ws_ssl) as ws:
|
||||
self.success("Connected to chat server")
|
||||
self.running = True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
import asyncio
|
||||
import secrets
|
||||
from contextlib import suppress
|
||||
from cryptography.fernet import Fernet
|
||||
from sanic import Sanic
|
||||
from sanic_ext import Extend
|
||||
import os
|
||||
from .managers import ConnectionManager
|
||||
from .stores import MessageStore, UserSessionStore
|
||||
from .srp_auth import SRPAuthManager
|
||||
from .helpers import RateLimiter
|
||||
|
||||
from .routes import register_routes
|
||||
|
||||
|
|
@ -20,6 +21,9 @@ def create_app(password: str = "", name: str = "cmd-chat-server") -> Sanic:
|
|||
app.ctx.connection_manager = ConnectionManager()
|
||||
app.ctx.srp_manager = SRPAuthManager(password)
|
||||
app.ctx.room_salt = os.urandom(16)
|
||||
app.ctx.ws_secret = os.urandom(32)
|
||||
app.ctx.admin_token = secrets.token_hex(16)
|
||||
app.ctx.rate_limiter = RateLimiter(max_requests=10, window_seconds=60)
|
||||
app.ctx.cleanup_task = None
|
||||
|
||||
register_lifecycle(app)
|
||||
|
|
|
|||
|
|
@ -1,36 +1,21 @@
|
|||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from dataclasses import asdict
|
||||
import json
|
||||
from sanic import Sanic, Request, response, Websocket
|
||||
from sanic import Request, Sanic, Websocket
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def verify_password(password: Optional[str], expected: Optional[str]) -> bool:
|
||||
if not expected:
|
||||
return True
|
||||
return password == expected
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
if forwarded := request.headers.get("x-forwarded-for"):
|
||||
return forwarded.split(",")[0].strip()
|
||||
return request.ip
|
||||
|
||||
|
||||
def get_param(request: Request, name: str) -> Optional[str]:
|
||||
return request.args.get(name) or request.form.get(name)
|
||||
|
||||
|
||||
def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]:
|
||||
if not verify_password(get_param(request, "password"), app.ctx.admin_password):
|
||||
return response.text("Unauthorized", status=401)
|
||||
return None
|
||||
|
||||
|
||||
async def send_state(ws: Websocket, app: Sanic) -> None:
|
||||
messages = app.ctx.message_store.get_all()
|
||||
users = app.ctx.session_store.get_all()
|
||||
|
|
@ -47,12 +32,17 @@ async def send_state(ws: Websocket, app: Sanic) -> None:
|
|||
)
|
||||
|
||||
|
||||
def extract_pubkey(request: Request) -> Optional[bytes]:
|
||||
if files := request.files.get("pubkey"):
|
||||
file = files[0] if isinstance(files, list) else files
|
||||
return file.body
|
||||
if raw := request.form.get("pubkey"):
|
||||
return raw.encode() if isinstance(raw, str) else raw
|
||||
if raw := request.args.get("pubkey"):
|
||||
return raw.encode() if isinstance(raw, str) else raw
|
||||
return None
|
||||
class RateLimiter:
|
||||
def __init__(self, max_requests: int = 10, window_seconds: int = 60):
|
||||
self.max_requests = max_requests
|
||||
self.window = window_seconds
|
||||
self._requests: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def is_allowed(self, key: str) -> bool:
|
||||
now = time.monotonic()
|
||||
timestamps = self._requests[key]
|
||||
timestamps[:] = [t for t in timestamps if now - t < self.window]
|
||||
if len(timestamps) >= self.max_requests:
|
||||
return False
|
||||
timestamps.append(now)
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ class Message:
|
|||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
user_ip: str = ""
|
||||
username: str = ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,97 @@
|
|||
import ipaddress
|
||||
import ssl
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
|
||||
from .factory import create_app
|
||||
|
||||
DEFAULT_CERT_DIR = Path.home() / ".cmd-chat" / "certs"
|
||||
|
||||
|
||||
def ensure_tls_certs(cert_dir: Path = DEFAULT_CERT_DIR) -> tuple[Path, Path]:
|
||||
cert_dir.mkdir(parents=True, exist_ok=True)
|
||||
cert_path = cert_dir / "server.pem"
|
||||
key_path = cert_dir / "server-key.pem"
|
||||
|
||||
if cert_path.exists() and key_path.exists():
|
||||
return cert_path, key_path
|
||||
|
||||
key = ec.generate_private_key(ec.SECP256R1())
|
||||
|
||||
subject = issuer = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, "cmd-chat"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "cmd-chat-self-signed"),
|
||||
])
|
||||
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(datetime.now(timezone.utc))
|
||||
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.DNSName("localhost"),
|
||||
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(key, hashes.SHA256())
|
||||
)
|
||||
|
||||
key_path.write_bytes(
|
||||
key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
return cert_path, key_path
|
||||
|
||||
|
||||
def run_server(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
password: Optional[str] = None,
|
||||
workers: int = 1,
|
||||
cert_path: Optional[str] = None,
|
||||
key_path: Optional[str] = None,
|
||||
no_tls: bool = False,
|
||||
) -> None:
|
||||
app = create_app(password=password or "")
|
||||
|
||||
ssl_ctx = None
|
||||
if not no_tls:
|
||||
if cert_path and key_path:
|
||||
c, k = Path(cert_path), Path(key_path)
|
||||
else:
|
||||
c, k = ensure_tls_certs()
|
||||
print(f"[TLS] Auto-generated self-signed cert: {c}")
|
||||
|
||||
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ssl_ctx.load_cert_chain(str(c), str(k))
|
||||
protocol = "https"
|
||||
else:
|
||||
print("[WARNING] TLS disabled — all traffic is unencrypted!")
|
||||
protocol = "http"
|
||||
|
||||
print(f"[ADMIN] Clear token: {app.ctx.admin_token}")
|
||||
print(f"[SERVER] Listening on {protocol}://{host}:{port}")
|
||||
|
||||
app.run(
|
||||
host=host,
|
||||
port=port,
|
||||
single_process=True,
|
||||
debug=False,
|
||||
access_log=True,
|
||||
access_log=False,
|
||||
ssl=ssl_ctx,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,11 +3,14 @@ from .models import Message, UserSession
|
|||
|
||||
|
||||
class MessageStore:
|
||||
def __init__(self):
|
||||
def __init__(self, max_messages: int = 1000):
|
||||
self._messages: list[Message] = []
|
||||
self._max = max_messages
|
||||
|
||||
def add(self, message: Message) -> None:
|
||||
self._messages.append(message)
|
||||
if len(self._messages) > self._max:
|
||||
self._messages = self._messages[-self._max:]
|
||||
|
||||
def get_all(self) -> list[Message]:
|
||||
return self._messages.copy()
|
||||
|
|
|
|||
|
|
@ -1,21 +1,26 @@
|
|||
from dataclasses import asdict
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import base64
|
||||
from dataclasses import asdict
|
||||
|
||||
from sanic import Sanic, Request, response, Websocket
|
||||
from sanic.response import HTTPResponse, json as json_response
|
||||
|
||||
from .models import Message, UserSession
|
||||
from .helpers import (
|
||||
get_client_ip,
|
||||
send_state,
|
||||
utcnow,
|
||||
)
|
||||
from .helpers import get_client_ip, send_state, utcnow
|
||||
|
||||
|
||||
def generate_ws_token(user_id: str, secret: bytes) -> str:
|
||||
return hmac.new(secret, user_id.encode(), hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
|
||||
try:
|
||||
client_ip = get_client_ip(request)
|
||||
if not app.ctx.rate_limiter.is_allowed(client_ip):
|
||||
return response.json({"error": "Rate limited"}, status=429)
|
||||
|
||||
data = request.json or {}
|
||||
username = data.get("username", "unknown")
|
||||
client_public_b64 = data.get("A")
|
||||
|
|
@ -45,6 +50,10 @@ async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
|
|||
|
||||
async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
|
||||
try:
|
||||
client_ip = get_client_ip(request)
|
||||
if not app.ctx.rate_limiter.is_allowed(client_ip):
|
||||
return response.json({"error": "Rate limited"}, status=429)
|
||||
|
||||
data = request.json or {}
|
||||
user_id = data.get("user_id")
|
||||
client_proof_b64 = data.get("M")
|
||||
|
|
@ -67,10 +76,12 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
|
|||
)
|
||||
app.ctx.session_store.add(session)
|
||||
|
||||
ws_token = generate_ws_token(user_id, app.ctx.ws_secret)
|
||||
|
||||
return response.json(
|
||||
{
|
||||
"H_AMK": base64.b64encode(H_AMK).decode(),
|
||||
"session_key": base64.b64encode(fernet_key).decode(),
|
||||
"ws_token": ws_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -82,9 +93,15 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
|
|||
|
||||
async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
|
||||
user_id = request.args.get("user_id")
|
||||
ws_token = request.args.get("ws_token")
|
||||
|
||||
if not user_id:
|
||||
await ws.close(code=4002, reason="user_id required")
|
||||
if not user_id or not ws_token:
|
||||
await ws.close(code=4002, reason="user_id and ws_token required")
|
||||
return
|
||||
|
||||
expected_token = generate_ws_token(user_id, app.ctx.ws_secret)
|
||||
if not hmac.compare_digest(ws_token, expected_token):
|
||||
await ws.close(code=4003, reason="Invalid token")
|
||||
return
|
||||
|
||||
session = app.ctx.session_store.get(user_id)
|
||||
|
|
@ -106,7 +123,6 @@ async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
|
|||
|
||||
message = Message(
|
||||
text=str(data),
|
||||
user_ip=session.ip,
|
||||
username=session.username,
|
||||
)
|
||||
app.ctx.message_store.add(message)
|
||||
|
|
@ -146,8 +162,12 @@ async def health(request: Request, app: Sanic) -> HTTPResponse:
|
|||
|
||||
|
||||
async def clear_messages(request: Request, app: Sanic) -> HTTPResponse:
|
||||
user_id = request.args.get("user_id")
|
||||
if not user_id or not app.ctx.session_store.get(user_id):
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return response.json({"error": "Unauthorized"}, status=401)
|
||||
|
||||
token = auth_header[7:]
|
||||
if not hmac.compare_digest(token, app.ctx.admin_token):
|
||||
return response.json({"error": "Unauthorized"}, status=401)
|
||||
|
||||
app.ctx.message_store.clear()
|
||||
|
|
|
|||
|
|
@ -28,6 +28,10 @@ def app():
|
|||
app.ctx.connection_manager = ConnectionManager()
|
||||
app.ctx.srp_manager = SRPAuthManager("testpassword")
|
||||
app.ctx.room_salt = os.urandom(16)
|
||||
app.ctx.ws_secret = os.urandom(32)
|
||||
app.ctx.admin_token = "test-admin-token"
|
||||
from cmd_chat.server.helpers import RateLimiter
|
||||
app.ctx.rate_limiter = RateLimiter(max_requests=100, window_seconds=60)
|
||||
app.ctx.cleanup_task = None
|
||||
|
||||
register_routes(app)
|
||||
|
|
|
|||
|
|
@ -52,14 +52,19 @@ class TestClientInit:
|
|||
assert client.username == "testuser"
|
||||
assert client.password == b"testpassword"
|
||||
assert client.user_id is None
|
||||
assert client.fernet is None
|
||||
assert client.ws_token is None
|
||||
assert client.room_fernet is None
|
||||
assert client.connected is False
|
||||
assert client.running is False
|
||||
|
||||
def test_client_urls(self, client):
|
||||
assert client.base_url == "http://127.0.0.1:3000"
|
||||
assert client.ws_url == "ws://127.0.0.1:3000"
|
||||
assert client.base_url == "https://127.0.0.1:3000"
|
||||
assert client.ws_url == "wss://127.0.0.1:3000"
|
||||
|
||||
def test_client_no_tls_urls(self):
|
||||
c = Client("127.0.0.1", 3000, "user", "pass", no_tls=True)
|
||||
assert c.base_url == "http://127.0.0.1:3000"
|
||||
assert c.ws_url == "ws://127.0.0.1:3000"
|
||||
|
||||
def test_client_empty_password(self):
|
||||
client = Client("localhost", 8080, "user", None)
|
||||
|
|
|
|||
|
|
@ -48,15 +48,20 @@ def room_fernet(room_salt):
|
|||
class TestClientProperties:
|
||||
def test_base_url_different_ports(self):
|
||||
client = Client("example.com", 8080, "user", "pass")
|
||||
assert client.base_url == "http://example.com:8080"
|
||||
assert client.base_url == "https://example.com:8080"
|
||||
|
||||
def test_ws_url_different_ports(self):
|
||||
client = Client("example.com", 8080, "user", "pass")
|
||||
assert client.ws_url == "ws://example.com:8080"
|
||||
assert client.ws_url == "wss://example.com:8080"
|
||||
|
||||
def test_base_url_localhost(self):
|
||||
client = Client("localhost", 443, "user", "pass")
|
||||
assert client.base_url == "http://localhost:443"
|
||||
assert client.base_url == "https://localhost:443"
|
||||
|
||||
def test_no_tls_urls(self):
|
||||
client = Client("example.com", 8080, "user", "pass", no_tls=True)
|
||||
assert client.base_url == "http://example.com:8080"
|
||||
assert client.ws_url == "ws://example.com:8080"
|
||||
|
||||
def test_password_encoding_unicode(self):
|
||||
client = Client("localhost", 3000, "user", "пароль123")
|
||||
|
|
@ -84,7 +89,7 @@ class TestSRPAuthentication:
|
|||
verify_response = MagicMock()
|
||||
verify_response.json.return_value = {
|
||||
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
|
||||
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
|
||||
"ws_token": "test-ws-token-hex",
|
||||
}
|
||||
verify_response.raise_for_status = MagicMock()
|
||||
|
||||
|
|
@ -102,7 +107,7 @@ class TestSRPAuthentication:
|
|||
|
||||
assert client.user_id == "test-user-id-12345"
|
||||
assert client.room_fernet is not None
|
||||
assert client.fernet is not None
|
||||
assert client.ws_token == "test-ws-token-hex"
|
||||
|
||||
@patch("cmd_chat.client.client.requests.post")
|
||||
def test_srp_authenticate_init_fails(self, mock_post, client):
|
||||
|
|
@ -178,7 +183,7 @@ class TestSRPAuthentication:
|
|||
verify_response = MagicMock()
|
||||
verify_response.json.return_value = {
|
||||
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
|
||||
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
|
||||
"ws_token": "test-ws-token-hex",
|
||||
}
|
||||
verify_response.raise_for_status = MagicMock()
|
||||
|
||||
|
|
@ -237,7 +242,6 @@ class TestDecryptMessage:
|
|||
"username": "sender",
|
||||
"timestamp": "2024-01-01T12:00:00",
|
||||
"id": "msg-123",
|
||||
"user_ip": "192.168.1.1",
|
||||
}
|
||||
|
||||
decrypted = client.decrypt_message(msg)
|
||||
|
|
@ -246,7 +250,6 @@ class TestDecryptMessage:
|
|||
assert decrypted["username"] == "sender"
|
||||
assert decrypted["timestamp"] == "2024-01-01T12:00:00"
|
||||
assert decrypted["id"] == "msg-123"
|
||||
assert decrypted["user_ip"] == "192.168.1.1"
|
||||
|
||||
def test_decrypt_wrong_key_marks_failed(self, client):
|
||||
|
||||
|
|
@ -553,6 +556,7 @@ class TestRunAsync:
|
|||
@pytest.mark.asyncio
|
||||
async def test_run_successful_connection_and_disconnect(self, client):
|
||||
client.user_id = "test-id-123"
|
||||
client.ws_token = "test-token"
|
||||
|
||||
with patch.object(client, "srp_authenticate"):
|
||||
with patch("cmd_chat.client.client.websockets.connect") as mock_connect:
|
||||
|
|
@ -844,8 +848,8 @@ class TestEdgeCases:
|
|||
def test_port_zero(self):
|
||||
client = Client("localhost", 0, "user", "pass")
|
||||
assert client.port == 0
|
||||
assert client.base_url == "http://localhost:0"
|
||||
assert client.base_url == "https://localhost:0"
|
||||
|
||||
def test_ipv6_server(self):
|
||||
client = Client("::1", 3000, "user", "pass")
|
||||
assert client.base_url == "http://::1:3000"
|
||||
assert client.base_url == "https://::1:3000"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user