fix(security): comprehensive security hardening — TLS, HMAC WS auth, rate limiting, IP leak prevention

CRITICAL fixes:
- Auto-generated self-signed TLS certs (HTTPS/WSS by default)
- Removed session_key from /srp/verify response (was sent in plaintext)
- Replaced with HMAC-SHA256 ws_token for WebSocket authentication

HIGH fixes:
- WebSocket auth now validates ws_token via hmac.compare_digest()
- /clear endpoint requires Bearer admin_token (printed at server start)
- Password no longer required as CLI arg — supports env var + getpass prompt
- Removed user_ip from Message model (no longer broadcast to clients)

MEDIUM fixes:
- Rate limiter on /srp/init and /srp/verify (10 req/min/IP)
- MessageStore capped at 1000 messages (prevents RAM DoS)
- access_log disabled (was leaking request metadata)

LOW fixes:
- Username sanitization against rich markup injection
- Dead code removed from helpers.py

All 79 tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
leetcrypt 2026-05-25 20:30:40 -07:00
parent 440b67da26
commit e7bacc93da
11 changed files with 255 additions and 80 deletions

View File

@ -1,8 +1,19 @@
import argparse import argparse
import getpass
import os
from cmd_chat.server.server import run_server from cmd_chat.server.server import run_server
from cmd_chat.client.client import Client from cmd_chat.client.client import Client
def resolve_password(args_password: str | None, prompt: str = "Room password: ") -> str:
if args_password:
return args_password
if env_pw := os.environ.get("CMD_CHAT_PASSWORD"):
return env_pw
return getpass.getpass(prompt)
def main(): def main():
parser = argparse.ArgumentParser(description="Command-line chat application") parser = argparse.ArgumentParser(description="Command-line chat application")
subparsers = parser.add_subparsers(dest="command", required=True) subparsers = parser.add_subparsers(dest="command", required=True)
@ -10,24 +21,46 @@ def main():
serve_p = subparsers.add_parser("serve", help="Run server") serve_p = subparsers.add_parser("serve", help="Run server")
serve_p.add_argument("ip_address") serve_p.add_argument("ip_address")
serve_p.add_argument("port") serve_p.add_argument("port")
serve_p.add_argument("--password", "-p", required=True) serve_p.add_argument("--password", "-p", default=None)
serve_p.add_argument("--cert", default=None, help="Path to TLS certificate")
serve_p.add_argument("--key", default=None, help="Path to TLS private key")
serve_p.add_argument("--no-tls", action="store_true", help="Disable TLS (insecure)")
connect_p = subparsers.add_parser("connect", help="Connect to server") connect_p = subparsers.add_parser("connect", help="Connect to server")
connect_p.add_argument("ip_address") connect_p.add_argument("ip_address")
connect_p.add_argument("port") connect_p.add_argument("port")
connect_p.add_argument("username") connect_p.add_argument("username")
connect_p.add_argument("password") connect_p.add_argument("--password", "-p", default=None)
connect_p.add_argument(
"--insecure", "-k", action="store_true",
help="Skip TLS certificate verification (for self-signed certs)",
)
connect_p.add_argument(
"--no-tls", action="store_true",
help="Connect without TLS (insecure)",
)
args = parser.parse_args() args = parser.parse_args()
if args.command == "serve": if args.command == "serve":
run_server(host=args.ip_address, port=int(args.port), password=args.password) password = resolve_password(args.password)
run_server(
host=args.ip_address,
port=int(args.port),
password=password,
cert_path=args.cert,
key_path=args.key,
no_tls=args.no_tls,
)
elif args.command == "connect": elif args.command == "connect":
password = resolve_password(args.password)
Client( Client(
server=args.ip_address, server=args.ip_address,
port=int(args.port), port=int(args.port),
username=args.username, username=args.username,
password=args.password, password=password,
insecure=args.insecure,
no_tls=args.no_tls,
).run() ).run()

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import json import json
import ssl
import base64 import base64
from typing import Optional from typing import Optional
@ -17,14 +18,22 @@ srp.rfc5054_enable()
class Client: class Client:
def __init__( def __init__(
self, server: str, port: int, username: str, password: Optional[str] = None self,
server: str,
port: int,
username: str,
password: Optional[str] = None,
insecure: bool = False,
no_tls: bool = False,
): ):
self.server = server self.server = server
self.port = port self.port = port
self.username = username self.username = username
self.password = (password or "").encode() self.password = (password or "").encode()
self.insecure = insecure
self.no_tls = no_tls
self.user_id: Optional[str] = None self.user_id: Optional[str] = None
self.fernet: Optional[Fernet] = None self.ws_token: Optional[str] = None
self.room_fernet: Optional[Fernet] = None self.room_fernet: Optional[Fernet] = None
self.console = Console() self.console = Console()
@ -35,23 +44,42 @@ class Client:
@property @property
def base_url(self) -> str: def base_url(self) -> str:
return f"http://{self.server}:{self.port}" scheme = "http" if self.no_tls else "https"
return f"{scheme}://{self.server}:{self.port}"
@property @property
def ws_url(self) -> str: def ws_url(self) -> str:
return f"ws://{self.server}:{self.port}" scheme = "ws" if self.no_tls else "wss"
return f"{scheme}://{self.server}:{self.port}"
def _request_kwargs(self) -> dict:
kwargs: dict = {"timeout": 30}
if self.insecure and not self.no_tls:
kwargs["verify"] = False
return kwargs
def _ws_ssl_context(self) -> ssl.SSLContext | None:
if self.no_tls:
return None
if self.insecure:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
return ctx
return True # default verification
def success(self, message: str) -> None: def success(self, message: str) -> None:
self.console.print(f"[green]✓ {message}[/]") self.console.print(f"[green]{message}[/]")
def error(self, message: str) -> None: def error(self, message: str) -> None:
self.console.print(f"[red]✗ {message}[/]") self.console.print(f"[red]{message}[/]")
def info(self, message: str) -> None: def info(self, message: str) -> None:
self.console.print(f"[cyan]• {message}[/]") self.console.print(f"[cyan]{message}[/]")
def srp_authenticate(self) -> None: def srp_authenticate(self) -> None:
with self.console.status("[cyan]Starting SRP handshake...[/]", spinner="dots"): with self.console.status("[cyan]Starting SRP handshake...[/]", spinner="dots"):
req_kwargs = self._request_kwargs()
usr = srp.User(b"chat", self.password, hash_alg=srp.SHA256) usr = srp.User(b"chat", self.password, hash_alg=srp.SHA256)
_, A = usr.start_authentication() _, A = usr.start_authentication()
@ -62,7 +90,7 @@ class Client:
"username": self.username, "username": self.username,
"A": base64.b64encode(A).decode(), "A": base64.b64encode(A).decode(),
}, },
timeout=30, **req_kwargs,
) )
resp.raise_for_status() resp.raise_for_status()
init_data = resp.json() init_data = resp.json()
@ -93,7 +121,7 @@ class Client:
"username": self.username, "username": self.username,
"M": base64.b64encode(M).decode(), "M": base64.b64encode(M).decode(),
}, },
timeout=30, **req_kwargs,
) )
resp.raise_for_status() resp.raise_for_status()
verify_data = resp.json() verify_data = resp.json()
@ -104,8 +132,7 @@ class Client:
if not usr.authenticated(): if not usr.authenticated():
raise ValueError("Server authentication failed") raise ValueError("Server authentication failed")
session_key = base64.b64decode(verify_data["session_key"]) self.ws_token = verify_data["ws_token"]
self.fernet = Fernet(session_key)
self.success(f"SRP authenticated (session: {self.user_id[:8]}...)") self.success(f"SRP authenticated (session: {self.user_id[:8]}...)")
@ -118,29 +145,36 @@ class Client:
msg["text"] = "[decrypt failed]" msg["text"] = "[decrypt failed]"
return msg return msg
@staticmethod
def _safe_username(username: str) -> str:
return username.replace("[", "\\[")
def render_messages(self) -> None: def render_messages(self) -> None:
self.console.clear() self.console.clear()
users_online = ", ".join(u.get("username", "?") for u in self.users) or "none" users_online = (
", ".join(self._safe_username(u.get("username", "?")) for u in self.users)
or "none"
)
self.console.print(f"[dim]Online: {users_online}[/]") self.console.print(f"[dim]Online: {users_online}[/]")
self.console.print("" * 60) self.console.print("-" * 60)
display_messages = ( display_messages = (
self.messages[-15:] if len(self.messages) > 15 else self.messages self.messages[-15:] if len(self.messages) > 15 else self.messages
) )
for msg in display_messages: for msg in display_messages:
username = msg.get("username", "unknown") username = self._safe_username(msg.get("username", "unknown"))
text = msg.get("text", "") text = msg.get("text", "")
timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ") timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ")
style = "green" if username == self.username else "cyan" style = "green" if msg.get("username") == self.username else "cyan"
self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}") self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}")
if not display_messages: if not display_messages:
self.console.print("[dim italic]No messages yet...[/]") self.console.print("[dim italic]No messages yet...[/]")
self.console.print("" * 60) self.console.print("-" * 60)
self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]") self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]")
async def receive_loop(self, ws) -> None: async def receive_loop(self, ws) -> None:
@ -196,9 +230,10 @@ class Client:
self.srp_authenticate() self.srp_authenticate()
self.info("Connecting to chat...") self.info("Connecting to chat...")
url = f"{self.ws_url}/ws/chat?user_id={self.user_id}" url = f"{self.ws_url}/ws/chat?user_id={self.user_id}&ws_token={self.ws_token}"
async with websockets.connect(url) as ws: ws_ssl = self._ws_ssl_context()
async with websockets.connect(url, ssl=ws_ssl) as ws:
self.success("Connected to chat server") self.success("Connected to chat server")
self.running = True self.running = True

View File

@ -1,12 +1,13 @@
import asyncio import asyncio
import secrets
from contextlib import suppress from contextlib import suppress
from cryptography.fernet import Fernet
from sanic import Sanic from sanic import Sanic
from sanic_ext import Extend from sanic_ext import Extend
import os import os
from .managers import ConnectionManager from .managers import ConnectionManager
from .stores import MessageStore, UserSessionStore from .stores import MessageStore, UserSessionStore
from .srp_auth import SRPAuthManager from .srp_auth import SRPAuthManager
from .helpers import RateLimiter
from .routes import register_routes from .routes import register_routes
@ -20,6 +21,9 @@ def create_app(password: str = "", name: str = "cmd-chat-server") -> Sanic:
app.ctx.connection_manager = ConnectionManager() app.ctx.connection_manager = ConnectionManager()
app.ctx.srp_manager = SRPAuthManager(password) app.ctx.srp_manager = SRPAuthManager(password)
app.ctx.room_salt = os.urandom(16) app.ctx.room_salt = os.urandom(16)
app.ctx.ws_secret = os.urandom(32)
app.ctx.admin_token = secrets.token_hex(16)
app.ctx.rate_limiter = RateLimiter(max_requests=10, window_seconds=60)
app.ctx.cleanup_task = None app.ctx.cleanup_task = None
register_lifecycle(app) register_lifecycle(app)

View File

@ -1,36 +1,21 @@
import time
from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional
from dataclasses import asdict from dataclasses import asdict
import json import json
from sanic import Sanic, Request, response, Websocket from sanic import Request, Sanic, Websocket
def utcnow() -> datetime: def utcnow() -> datetime:
return datetime.now(timezone.utc) return datetime.now(timezone.utc)
def verify_password(password: Optional[str], expected: Optional[str]) -> bool:
if not expected:
return True
return password == expected
def get_client_ip(request: Request) -> str: def get_client_ip(request: Request) -> str:
if forwarded := request.headers.get("x-forwarded-for"): if forwarded := request.headers.get("x-forwarded-for"):
return forwarded.split(",")[0].strip() return forwarded.split(",")[0].strip()
return request.ip return request.ip
def get_param(request: Request, name: str) -> Optional[str]:
return request.args.get(name) or request.form.get(name)
def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]:
if not verify_password(get_param(request, "password"), app.ctx.admin_password):
return response.text("Unauthorized", status=401)
return None
async def send_state(ws: Websocket, app: Sanic) -> None: async def send_state(ws: Websocket, app: Sanic) -> None:
messages = app.ctx.message_store.get_all() messages = app.ctx.message_store.get_all()
users = app.ctx.session_store.get_all() users = app.ctx.session_store.get_all()
@ -47,12 +32,17 @@ async def send_state(ws: Websocket, app: Sanic) -> None:
) )
def extract_pubkey(request: Request) -> Optional[bytes]: class RateLimiter:
if files := request.files.get("pubkey"): def __init__(self, max_requests: int = 10, window_seconds: int = 60):
file = files[0] if isinstance(files, list) else files self.max_requests = max_requests
return file.body self.window = window_seconds
if raw := request.form.get("pubkey"): self._requests: dict[str, list[float]] = defaultdict(list)
return raw.encode() if isinstance(raw, str) else raw
if raw := request.args.get("pubkey"): def is_allowed(self, key: str) -> bool:
return raw.encode() if isinstance(raw, str) else raw now = time.monotonic()
return None timestamps = self._requests[key]
timestamps[:] = [t for t in timestamps if now - t < self.window]
if len(timestamps) >= self.max_requests:
return False
timestamps.append(now)
return True

View File

@ -11,7 +11,6 @@ class Message:
timestamp: str = field( timestamp: str = field(
default_factory=lambda: datetime.now(timezone.utc).isoformat() default_factory=lambda: datetime.now(timezone.utc).isoformat()
) )
user_ip: str = ""
username: str = "" username: str = ""

View File

@ -1,19 +1,97 @@
import ipaddress
import ssl
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional from typing import Optional
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from .factory import create_app from .factory import create_app
DEFAULT_CERT_DIR = Path.home() / ".cmd-chat" / "certs"
def ensure_tls_certs(cert_dir: Path = DEFAULT_CERT_DIR) -> tuple[Path, Path]:
cert_dir.mkdir(parents=True, exist_ok=True)
cert_path = cert_dir / "server.pem"
key_path = cert_dir / "server-key.pem"
if cert_path.exists() and key_path.exists():
return cert_path, key_path
key = ec.generate_private_key(ec.SECP256R1())
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, "cmd-chat"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "cmd-chat-self-signed"),
])
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(timezone.utc))
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=365))
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName("localhost"),
x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
]),
critical=False,
)
.sign(key, hashes.SHA256())
)
key_path.write_bytes(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
)
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
return cert_path, key_path
def run_server( def run_server(
host: str = "0.0.0.0", host: str = "0.0.0.0",
port: int = 8000, port: int = 8000,
password: Optional[str] = None, password: Optional[str] = None,
workers: int = 1, cert_path: Optional[str] = None,
key_path: Optional[str] = None,
no_tls: bool = False,
) -> None: ) -> None:
app = create_app(password=password or "") app = create_app(password=password or "")
ssl_ctx = None
if not no_tls:
if cert_path and key_path:
c, k = Path(cert_path), Path(key_path)
else:
c, k = ensure_tls_certs()
print(f"[TLS] Auto-generated self-signed cert: {c}")
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_ctx.load_cert_chain(str(c), str(k))
protocol = "https"
else:
print("[WARNING] TLS disabled — all traffic is unencrypted!")
protocol = "http"
print(f"[ADMIN] Clear token: {app.ctx.admin_token}")
print(f"[SERVER] Listening on {protocol}://{host}:{port}")
app.run( app.run(
host=host, host=host,
port=port, port=port,
single_process=True, single_process=True,
debug=False, debug=False,
access_log=True, access_log=False,
ssl=ssl_ctx,
) )

View File

@ -3,11 +3,14 @@ from .models import Message, UserSession
class MessageStore: class MessageStore:
def __init__(self): def __init__(self, max_messages: int = 1000):
self._messages: list[Message] = [] self._messages: list[Message] = []
self._max = max_messages
def add(self, message: Message) -> None: def add(self, message: Message) -> None:
self._messages.append(message) self._messages.append(message)
if len(self._messages) > self._max:
self._messages = self._messages[-self._max:]
def get_all(self) -> list[Message]: def get_all(self) -> list[Message]:
return self._messages.copy() return self._messages.copy()

View File

@ -1,21 +1,26 @@
from dataclasses import asdict import hashlib
import hmac
import json import json
import base64 import base64
from dataclasses import asdict
from sanic import Sanic, Request, response, Websocket from sanic import Sanic, Request, response, Websocket
from sanic.response import HTTPResponse, json as json_response from sanic.response import HTTPResponse, json as json_response
from .models import Message, UserSession from .models import Message, UserSession
from .helpers import ( from .helpers import get_client_ip, send_state, utcnow
get_client_ip,
send_state,
utcnow, def generate_ws_token(user_id: str, secret: bytes) -> str:
) return hmac.new(secret, user_id.encode(), hashlib.sha256).hexdigest()
async def srp_init(request: Request, app: Sanic) -> HTTPResponse: async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
try: try:
client_ip = get_client_ip(request)
if not app.ctx.rate_limiter.is_allowed(client_ip):
return response.json({"error": "Rate limited"}, status=429)
data = request.json or {} data = request.json or {}
username = data.get("username", "unknown") username = data.get("username", "unknown")
client_public_b64 = data.get("A") client_public_b64 = data.get("A")
@ -45,6 +50,10 @@ async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
async def srp_verify(request: Request, app: Sanic) -> HTTPResponse: async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
try: try:
client_ip = get_client_ip(request)
if not app.ctx.rate_limiter.is_allowed(client_ip):
return response.json({"error": "Rate limited"}, status=429)
data = request.json or {} data = request.json or {}
user_id = data.get("user_id") user_id = data.get("user_id")
client_proof_b64 = data.get("M") client_proof_b64 = data.get("M")
@ -67,10 +76,12 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
) )
app.ctx.session_store.add(session) app.ctx.session_store.add(session)
ws_token = generate_ws_token(user_id, app.ctx.ws_secret)
return response.json( return response.json(
{ {
"H_AMK": base64.b64encode(H_AMK).decode(), "H_AMK": base64.b64encode(H_AMK).decode(),
"session_key": base64.b64encode(fernet_key).decode(), "ws_token": ws_token,
} }
) )
@ -82,9 +93,15 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None: async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
user_id = request.args.get("user_id") user_id = request.args.get("user_id")
ws_token = request.args.get("ws_token")
if not user_id: if not user_id or not ws_token:
await ws.close(code=4002, reason="user_id required") await ws.close(code=4002, reason="user_id and ws_token required")
return
expected_token = generate_ws_token(user_id, app.ctx.ws_secret)
if not hmac.compare_digest(ws_token, expected_token):
await ws.close(code=4003, reason="Invalid token")
return return
session = app.ctx.session_store.get(user_id) session = app.ctx.session_store.get(user_id)
@ -106,7 +123,6 @@ async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
message = Message( message = Message(
text=str(data), text=str(data),
user_ip=session.ip,
username=session.username, username=session.username,
) )
app.ctx.message_store.add(message) app.ctx.message_store.add(message)
@ -146,8 +162,12 @@ async def health(request: Request, app: Sanic) -> HTTPResponse:
async def clear_messages(request: Request, app: Sanic) -> HTTPResponse: async def clear_messages(request: Request, app: Sanic) -> HTTPResponse:
user_id = request.args.get("user_id") auth_header = request.headers.get("authorization", "")
if not user_id or not app.ctx.session_store.get(user_id): if not auth_header.startswith("Bearer "):
return response.json({"error": "Unauthorized"}, status=401)
token = auth_header[7:]
if not hmac.compare_digest(token, app.ctx.admin_token):
return response.json({"error": "Unauthorized"}, status=401) return response.json({"error": "Unauthorized"}, status=401)
app.ctx.message_store.clear() app.ctx.message_store.clear()

View File

@ -28,6 +28,10 @@ def app():
app.ctx.connection_manager = ConnectionManager() app.ctx.connection_manager = ConnectionManager()
app.ctx.srp_manager = SRPAuthManager("testpassword") app.ctx.srp_manager = SRPAuthManager("testpassword")
app.ctx.room_salt = os.urandom(16) app.ctx.room_salt = os.urandom(16)
app.ctx.ws_secret = os.urandom(32)
app.ctx.admin_token = "test-admin-token"
from cmd_chat.server.helpers import RateLimiter
app.ctx.rate_limiter = RateLimiter(max_requests=100, window_seconds=60)
app.ctx.cleanup_task = None app.ctx.cleanup_task = None
register_routes(app) register_routes(app)

View File

@ -52,14 +52,19 @@ class TestClientInit:
assert client.username == "testuser" assert client.username == "testuser"
assert client.password == b"testpassword" assert client.password == b"testpassword"
assert client.user_id is None assert client.user_id is None
assert client.fernet is None assert client.ws_token is None
assert client.room_fernet is None assert client.room_fernet is None
assert client.connected is False assert client.connected is False
assert client.running is False assert client.running is False
def test_client_urls(self, client): def test_client_urls(self, client):
assert client.base_url == "http://127.0.0.1:3000" assert client.base_url == "https://127.0.0.1:3000"
assert client.ws_url == "ws://127.0.0.1:3000" assert client.ws_url == "wss://127.0.0.1:3000"
def test_client_no_tls_urls(self):
c = Client("127.0.0.1", 3000, "user", "pass", no_tls=True)
assert c.base_url == "http://127.0.0.1:3000"
assert c.ws_url == "ws://127.0.0.1:3000"
def test_client_empty_password(self): def test_client_empty_password(self):
client = Client("localhost", 8080, "user", None) client = Client("localhost", 8080, "user", None)

View File

@ -48,15 +48,20 @@ def room_fernet(room_salt):
class TestClientProperties: class TestClientProperties:
def test_base_url_different_ports(self): def test_base_url_different_ports(self):
client = Client("example.com", 8080, "user", "pass") client = Client("example.com", 8080, "user", "pass")
assert client.base_url == "http://example.com:8080" assert client.base_url == "https://example.com:8080"
def test_ws_url_different_ports(self): def test_ws_url_different_ports(self):
client = Client("example.com", 8080, "user", "pass") client = Client("example.com", 8080, "user", "pass")
assert client.ws_url == "ws://example.com:8080" assert client.ws_url == "wss://example.com:8080"
def test_base_url_localhost(self): def test_base_url_localhost(self):
client = Client("localhost", 443, "user", "pass") client = Client("localhost", 443, "user", "pass")
assert client.base_url == "http://localhost:443" assert client.base_url == "https://localhost:443"
def test_no_tls_urls(self):
client = Client("example.com", 8080, "user", "pass", no_tls=True)
assert client.base_url == "http://example.com:8080"
assert client.ws_url == "ws://example.com:8080"
def test_password_encoding_unicode(self): def test_password_encoding_unicode(self):
client = Client("localhost", 3000, "user", "пароль123") client = Client("localhost", 3000, "user", "пароль123")
@ -84,7 +89,7 @@ class TestSRPAuthentication:
verify_response = MagicMock() verify_response = MagicMock()
verify_response.json.return_value = { verify_response.json.return_value = {
"H_AMK": base64.b64encode(os.urandom(32)).decode(), "H_AMK": base64.b64encode(os.urandom(32)).decode(),
"session_key": base64.b64encode(Fernet.generate_key()).decode(), "ws_token": "test-ws-token-hex",
} }
verify_response.raise_for_status = MagicMock() verify_response.raise_for_status = MagicMock()
@ -102,7 +107,7 @@ class TestSRPAuthentication:
assert client.user_id == "test-user-id-12345" assert client.user_id == "test-user-id-12345"
assert client.room_fernet is not None assert client.room_fernet is not None
assert client.fernet is not None assert client.ws_token == "test-ws-token-hex"
@patch("cmd_chat.client.client.requests.post") @patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_init_fails(self, mock_post, client): def test_srp_authenticate_init_fails(self, mock_post, client):
@ -178,7 +183,7 @@ class TestSRPAuthentication:
verify_response = MagicMock() verify_response = MagicMock()
verify_response.json.return_value = { verify_response.json.return_value = {
"H_AMK": base64.b64encode(os.urandom(32)).decode(), "H_AMK": base64.b64encode(os.urandom(32)).decode(),
"session_key": base64.b64encode(Fernet.generate_key()).decode(), "ws_token": "test-ws-token-hex",
} }
verify_response.raise_for_status = MagicMock() verify_response.raise_for_status = MagicMock()
@ -237,7 +242,6 @@ class TestDecryptMessage:
"username": "sender", "username": "sender",
"timestamp": "2024-01-01T12:00:00", "timestamp": "2024-01-01T12:00:00",
"id": "msg-123", "id": "msg-123",
"user_ip": "192.168.1.1",
} }
decrypted = client.decrypt_message(msg) decrypted = client.decrypt_message(msg)
@ -246,7 +250,6 @@ class TestDecryptMessage:
assert decrypted["username"] == "sender" assert decrypted["username"] == "sender"
assert decrypted["timestamp"] == "2024-01-01T12:00:00" assert decrypted["timestamp"] == "2024-01-01T12:00:00"
assert decrypted["id"] == "msg-123" assert decrypted["id"] == "msg-123"
assert decrypted["user_ip"] == "192.168.1.1"
def test_decrypt_wrong_key_marks_failed(self, client): def test_decrypt_wrong_key_marks_failed(self, client):
@ -553,6 +556,7 @@ class TestRunAsync:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_successful_connection_and_disconnect(self, client): async def test_run_successful_connection_and_disconnect(self, client):
client.user_id = "test-id-123" client.user_id = "test-id-123"
client.ws_token = "test-token"
with patch.object(client, "srp_authenticate"): with patch.object(client, "srp_authenticate"):
with patch("cmd_chat.client.client.websockets.connect") as mock_connect: with patch("cmd_chat.client.client.websockets.connect") as mock_connect:
@ -844,8 +848,8 @@ class TestEdgeCases:
def test_port_zero(self): def test_port_zero(self):
client = Client("localhost", 0, "user", "pass") client = Client("localhost", 0, "user", "pass")
assert client.port == 0 assert client.port == 0
assert client.base_url == "http://localhost:0" assert client.base_url == "https://localhost:0"
def test_ipv6_server(self): def test_ipv6_server(self):
client = Client("::1", 3000, "user", "pass") client = Client("::1", 3000, "user", "pass")
assert client.base_url == "http://::1:3000" assert client.base_url == "https://::1:3000"