diff --git a/cmd_chat/__init__.py b/cmd_chat/__init__.py index c6ea4f6..ef9bc2d 100644 --- a/cmd_chat/__init__.py +++ b/cmd_chat/__init__.py @@ -1,8 +1,19 @@ import argparse +import getpass +import os + from cmd_chat.server.server import run_server from cmd_chat.client.client import Client +def resolve_password(args_password: str | None, prompt: str = "Room password: ") -> str: + if args_password: + return args_password + if env_pw := os.environ.get("CMD_CHAT_PASSWORD"): + return env_pw + return getpass.getpass(prompt) + + def main(): parser = argparse.ArgumentParser(description="Command-line chat application") subparsers = parser.add_subparsers(dest="command", required=True) @@ -10,24 +21,46 @@ def main(): serve_p = subparsers.add_parser("serve", help="Run server") serve_p.add_argument("ip_address") serve_p.add_argument("port") - serve_p.add_argument("--password", "-p", required=True) + serve_p.add_argument("--password", "-p", default=None) + serve_p.add_argument("--cert", default=None, help="Path to TLS certificate") + serve_p.add_argument("--key", default=None, help="Path to TLS private key") + serve_p.add_argument("--no-tls", action="store_true", help="Disable TLS (insecure)") connect_p = subparsers.add_parser("connect", help="Connect to server") connect_p.add_argument("ip_address") connect_p.add_argument("port") connect_p.add_argument("username") - connect_p.add_argument("password") + connect_p.add_argument("--password", "-p", default=None) + connect_p.add_argument( + "--insecure", "-k", action="store_true", + help="Skip TLS certificate verification (for self-signed certs)", + ) + connect_p.add_argument( + "--no-tls", action="store_true", + help="Connect without TLS (insecure)", + ) args = parser.parse_args() if args.command == "serve": - run_server(host=args.ip_address, port=int(args.port), password=args.password) + password = resolve_password(args.password) + run_server( + host=args.ip_address, + port=int(args.port), + password=password, + cert_path=args.cert, + key_path=args.key, + no_tls=args.no_tls, + ) elif args.command == "connect": + password = resolve_password(args.password) Client( server=args.ip_address, port=int(args.port), username=args.username, - password=args.password, + password=password, + insecure=args.insecure, + no_tls=args.no_tls, ).run() diff --git a/cmd_chat/client/client.py b/cmd_chat/client/client.py index ea3e67f..b1b9214 100644 --- a/cmd_chat/client/client.py +++ b/cmd_chat/client/client.py @@ -1,5 +1,6 @@ import asyncio import json +import ssl import base64 from typing import Optional @@ -17,14 +18,22 @@ srp.rfc5054_enable() class Client: def __init__( - self, server: str, port: int, username: str, password: Optional[str] = None + self, + server: str, + port: int, + username: str, + password: Optional[str] = None, + insecure: bool = False, + no_tls: bool = False, ): self.server = server self.port = port self.username = username self.password = (password or "").encode() + self.insecure = insecure + self.no_tls = no_tls self.user_id: Optional[str] = None - self.fernet: Optional[Fernet] = None + self.ws_token: Optional[str] = None self.room_fernet: Optional[Fernet] = None self.console = Console() @@ -35,23 +44,42 @@ class Client: @property def base_url(self) -> str: - return f"http://{self.server}:{self.port}" + scheme = "http" if self.no_tls else "https" + return f"{scheme}://{self.server}:{self.port}" @property def ws_url(self) -> str: - return f"ws://{self.server}:{self.port}" + scheme = "ws" if self.no_tls else "wss" + return f"{scheme}://{self.server}:{self.port}" + + def _request_kwargs(self) -> dict: + kwargs: dict = {"timeout": 30} + if self.insecure and not self.no_tls: + kwargs["verify"] = False + return kwargs + + def _ws_ssl_context(self) -> ssl.SSLContext | None: + if self.no_tls: + return None + if self.insecure: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx + return True # default verification def success(self, message: str) -> None: - self.console.print(f"[green]✓ {message}[/]") + self.console.print(f"[green]{message}[/]") def error(self, message: str) -> None: - self.console.print(f"[red]✗ {message}[/]") + self.console.print(f"[red]{message}[/]") def info(self, message: str) -> None: - self.console.print(f"[cyan]• {message}[/]") + self.console.print(f"[cyan]{message}[/]") def srp_authenticate(self) -> None: with self.console.status("[cyan]Starting SRP handshake...[/]", spinner="dots"): + req_kwargs = self._request_kwargs() usr = srp.User(b"chat", self.password, hash_alg=srp.SHA256) _, A = usr.start_authentication() @@ -62,7 +90,7 @@ class Client: "username": self.username, "A": base64.b64encode(A).decode(), }, - timeout=30, + **req_kwargs, ) resp.raise_for_status() init_data = resp.json() @@ -93,7 +121,7 @@ class Client: "username": self.username, "M": base64.b64encode(M).decode(), }, - timeout=30, + **req_kwargs, ) resp.raise_for_status() verify_data = resp.json() @@ -104,8 +132,7 @@ class Client: if not usr.authenticated(): raise ValueError("Server authentication failed") - session_key = base64.b64decode(verify_data["session_key"]) - self.fernet = Fernet(session_key) + self.ws_token = verify_data["ws_token"] self.success(f"SRP authenticated (session: {self.user_id[:8]}...)") @@ -118,29 +145,36 @@ class Client: msg["text"] = "[decrypt failed]" return msg + @staticmethod + def _safe_username(username: str) -> str: + return username.replace("[", "\\[") + def render_messages(self) -> None: self.console.clear() - users_online = ", ".join(u.get("username", "?") for u in self.users) or "none" + users_online = ( + ", ".join(self._safe_username(u.get("username", "?")) for u in self.users) + or "none" + ) self.console.print(f"[dim]Online: {users_online}[/]") - self.console.print("─" * 60) + self.console.print("-" * 60) display_messages = ( self.messages[-15:] if len(self.messages) > 15 else self.messages ) for msg in display_messages: - username = msg.get("username", "unknown") + username = self._safe_username(msg.get("username", "unknown")) text = msg.get("text", "") timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ") - style = "green" if username == self.username else "cyan" + style = "green" if msg.get("username") == self.username else "cyan" self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}") if not display_messages: self.console.print("[dim italic]No messages yet...[/]") - self.console.print("─" * 60) + self.console.print("-" * 60) self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]") async def receive_loop(self, ws) -> None: @@ -196,9 +230,10 @@ class Client: self.srp_authenticate() self.info("Connecting to chat...") - url = f"{self.ws_url}/ws/chat?user_id={self.user_id}" + url = f"{self.ws_url}/ws/chat?user_id={self.user_id}&ws_token={self.ws_token}" - async with websockets.connect(url) as ws: + ws_ssl = self._ws_ssl_context() + async with websockets.connect(url, ssl=ws_ssl) as ws: self.success("Connected to chat server") self.running = True diff --git a/cmd_chat/server/factory.py b/cmd_chat/server/factory.py index 894f08c..c40541e 100644 --- a/cmd_chat/server/factory.py +++ b/cmd_chat/server/factory.py @@ -1,12 +1,13 @@ import asyncio +import secrets from contextlib import suppress -from cryptography.fernet import Fernet from sanic import Sanic from sanic_ext import Extend import os from .managers import ConnectionManager from .stores import MessageStore, UserSessionStore from .srp_auth import SRPAuthManager +from .helpers import RateLimiter from .routes import register_routes @@ -20,6 +21,9 @@ def create_app(password: str = "", name: str = "cmd-chat-server") -> Sanic: app.ctx.connection_manager = ConnectionManager() app.ctx.srp_manager = SRPAuthManager(password) app.ctx.room_salt = os.urandom(16) + app.ctx.ws_secret = os.urandom(32) + app.ctx.admin_token = secrets.token_hex(16) + app.ctx.rate_limiter = RateLimiter(max_requests=10, window_seconds=60) app.ctx.cleanup_task = None register_lifecycle(app) diff --git a/cmd_chat/server/helpers.py b/cmd_chat/server/helpers.py index db47811..bfd6b7f 100644 --- a/cmd_chat/server/helpers.py +++ b/cmd_chat/server/helpers.py @@ -1,36 +1,21 @@ +import time +from collections import defaultdict from datetime import datetime, timezone -from typing import Optional from dataclasses import asdict import json -from sanic import Sanic, Request, response, Websocket +from sanic import Request, Sanic, Websocket def utcnow() -> datetime: return datetime.now(timezone.utc) -def verify_password(password: Optional[str], expected: Optional[str]) -> bool: - if not expected: - return True - return password == expected - - def get_client_ip(request: Request) -> str: if forwarded := request.headers.get("x-forwarded-for"): return forwarded.split(",")[0].strip() return request.ip -def get_param(request: Request, name: str) -> Optional[str]: - return request.args.get(name) or request.form.get(name) - - -def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]: - if not verify_password(get_param(request, "password"), app.ctx.admin_password): - return response.text("Unauthorized", status=401) - return None - - async def send_state(ws: Websocket, app: Sanic) -> None: messages = app.ctx.message_store.get_all() users = app.ctx.session_store.get_all() @@ -47,12 +32,17 @@ async def send_state(ws: Websocket, app: Sanic) -> None: ) -def extract_pubkey(request: Request) -> Optional[bytes]: - if files := request.files.get("pubkey"): - file = files[0] if isinstance(files, list) else files - return file.body - if raw := request.form.get("pubkey"): - return raw.encode() if isinstance(raw, str) else raw - if raw := request.args.get("pubkey"): - return raw.encode() if isinstance(raw, str) else raw - return None +class RateLimiter: + def __init__(self, max_requests: int = 10, window_seconds: int = 60): + self.max_requests = max_requests + self.window = window_seconds + self._requests: dict[str, list[float]] = defaultdict(list) + + def is_allowed(self, key: str) -> bool: + now = time.monotonic() + timestamps = self._requests[key] + timestamps[:] = [t for t in timestamps if now - t < self.window] + if len(timestamps) >= self.max_requests: + return False + timestamps.append(now) + return True diff --git a/cmd_chat/server/models.py b/cmd_chat/server/models.py index dce1601..c83b134 100644 --- a/cmd_chat/server/models.py +++ b/cmd_chat/server/models.py @@ -11,7 +11,6 @@ class Message: timestamp: str = field( default_factory=lambda: datetime.now(timezone.utc).isoformat() ) - user_ip: str = "" username: str = "" diff --git a/cmd_chat/server/server.py b/cmd_chat/server/server.py index 87e03c0..4985bf2 100644 --- a/cmd_chat/server/server.py +++ b/cmd_chat/server/server.py @@ -1,19 +1,97 @@ +import ipaddress +import ssl +from datetime import datetime, timedelta, timezone +from pathlib import Path from typing import Optional + +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec + from .factory import create_app +DEFAULT_CERT_DIR = Path.home() / ".cmd-chat" / "certs" + + +def ensure_tls_certs(cert_dir: Path = DEFAULT_CERT_DIR) -> tuple[Path, Path]: + cert_dir.mkdir(parents=True, exist_ok=True) + cert_path = cert_dir / "server.pem" + key_path = cert_dir / "server-key.pem" + + if cert_path.exists() and key_path.exists(): + return cert_path, key_path + + key = ec.generate_private_key(ec.SECP256R1()) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "cmd-chat"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "cmd-chat-self-signed"), + ]) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + key_path.write_bytes( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + + return cert_path, key_path + def run_server( host: str = "0.0.0.0", port: int = 8000, password: Optional[str] = None, - workers: int = 1, + cert_path: Optional[str] = None, + key_path: Optional[str] = None, + no_tls: bool = False, ) -> None: app = create_app(password=password or "") + ssl_ctx = None + if not no_tls: + if cert_path and key_path: + c, k = Path(cert_path), Path(key_path) + else: + c, k = ensure_tls_certs() + print(f"[TLS] Auto-generated self-signed cert: {c}") + + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_ctx.load_cert_chain(str(c), str(k)) + protocol = "https" + else: + print("[WARNING] TLS disabled — all traffic is unencrypted!") + protocol = "http" + + print(f"[ADMIN] Clear token: {app.ctx.admin_token}") + print(f"[SERVER] Listening on {protocol}://{host}:{port}") + app.run( host=host, port=port, single_process=True, debug=False, - access_log=True, + access_log=False, + ssl=ssl_ctx, ) diff --git a/cmd_chat/server/stores.py b/cmd_chat/server/stores.py index ff1ec1b..2b1d60a 100644 --- a/cmd_chat/server/stores.py +++ b/cmd_chat/server/stores.py @@ -3,11 +3,14 @@ from .models import Message, UserSession class MessageStore: - def __init__(self): + def __init__(self, max_messages: int = 1000): self._messages: list[Message] = [] + self._max = max_messages def add(self, message: Message) -> None: self._messages.append(message) + if len(self._messages) > self._max: + self._messages = self._messages[-self._max:] def get_all(self) -> list[Message]: return self._messages.copy() diff --git a/cmd_chat/server/views.py b/cmd_chat/server/views.py index e2ecd36..a24ee69 100644 --- a/cmd_chat/server/views.py +++ b/cmd_chat/server/views.py @@ -1,21 +1,26 @@ -from dataclasses import asdict - +import hashlib +import hmac import json import base64 +from dataclasses import asdict from sanic import Sanic, Request, response, Websocket from sanic.response import HTTPResponse, json as json_response from .models import Message, UserSession -from .helpers import ( - get_client_ip, - send_state, - utcnow, -) +from .helpers import get_client_ip, send_state, utcnow + + +def generate_ws_token(user_id: str, secret: bytes) -> str: + return hmac.new(secret, user_id.encode(), hashlib.sha256).hexdigest() async def srp_init(request: Request, app: Sanic) -> HTTPResponse: try: + client_ip = get_client_ip(request) + if not app.ctx.rate_limiter.is_allowed(client_ip): + return response.json({"error": "Rate limited"}, status=429) + data = request.json or {} username = data.get("username", "unknown") client_public_b64 = data.get("A") @@ -45,6 +50,10 @@ async def srp_init(request: Request, app: Sanic) -> HTTPResponse: async def srp_verify(request: Request, app: Sanic) -> HTTPResponse: try: + client_ip = get_client_ip(request) + if not app.ctx.rate_limiter.is_allowed(client_ip): + return response.json({"error": "Rate limited"}, status=429) + data = request.json or {} user_id = data.get("user_id") client_proof_b64 = data.get("M") @@ -67,10 +76,12 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse: ) app.ctx.session_store.add(session) + ws_token = generate_ws_token(user_id, app.ctx.ws_secret) + return response.json( { "H_AMK": base64.b64encode(H_AMK).decode(), - "session_key": base64.b64encode(fernet_key).decode(), + "ws_token": ws_token, } ) @@ -82,9 +93,15 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse: async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None: user_id = request.args.get("user_id") + ws_token = request.args.get("ws_token") - if not user_id: - await ws.close(code=4002, reason="user_id required") + if not user_id or not ws_token: + await ws.close(code=4002, reason="user_id and ws_token required") + return + + expected_token = generate_ws_token(user_id, app.ctx.ws_secret) + if not hmac.compare_digest(ws_token, expected_token): + await ws.close(code=4003, reason="Invalid token") return session = app.ctx.session_store.get(user_id) @@ -106,7 +123,6 @@ async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None: message = Message( text=str(data), - user_ip=session.ip, username=session.username, ) app.ctx.message_store.add(message) @@ -146,8 +162,12 @@ async def health(request: Request, app: Sanic) -> HTTPResponse: async def clear_messages(request: Request, app: Sanic) -> HTTPResponse: - user_id = request.args.get("user_id") - if not user_id or not app.ctx.session_store.get(user_id): + auth_header = request.headers.get("authorization", "") + if not auth_header.startswith("Bearer "): + return response.json({"error": "Unauthorized"}, status=401) + + token = auth_header[7:] + if not hmac.compare_digest(token, app.ctx.admin_token): return response.json({"error": "Unauthorized"}, status=401) app.ctx.message_store.clear() diff --git a/tests/conftest.py b/tests/conftest.py index d11ae13..dad0516 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,10 @@ def app(): app.ctx.connection_manager = ConnectionManager() app.ctx.srp_manager = SRPAuthManager("testpassword") app.ctx.room_salt = os.urandom(16) + app.ctx.ws_secret = os.urandom(32) + app.ctx.admin_token = "test-admin-token" + from cmd_chat.server.helpers import RateLimiter + app.ctx.rate_limiter = RateLimiter(max_requests=100, window_seconds=60) app.ctx.cleanup_task = None register_routes(app) diff --git a/tests/test_client.py b/tests/test_client.py index 54b338a..edafd9b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -52,14 +52,19 @@ class TestClientInit: assert client.username == "testuser" assert client.password == b"testpassword" assert client.user_id is None - assert client.fernet is None + assert client.ws_token is None assert client.room_fernet is None assert client.connected is False assert client.running is False def test_client_urls(self, client): - assert client.base_url == "http://127.0.0.1:3000" - assert client.ws_url == "ws://127.0.0.1:3000" + assert client.base_url == "https://127.0.0.1:3000" + assert client.ws_url == "wss://127.0.0.1:3000" + + def test_client_no_tls_urls(self): + c = Client("127.0.0.1", 3000, "user", "pass", no_tls=True) + assert c.base_url == "http://127.0.0.1:3000" + assert c.ws_url == "ws://127.0.0.1:3000" def test_client_empty_password(self): client = Client("localhost", 8080, "user", None) diff --git a/tests/test_client_extended.py b/tests/test_client_extended.py index e9f0d32..0bb95f3 100644 --- a/tests/test_client_extended.py +++ b/tests/test_client_extended.py @@ -48,15 +48,20 @@ def room_fernet(room_salt): class TestClientProperties: def test_base_url_different_ports(self): client = Client("example.com", 8080, "user", "pass") - assert client.base_url == "http://example.com:8080" + assert client.base_url == "https://example.com:8080" def test_ws_url_different_ports(self): client = Client("example.com", 8080, "user", "pass") - assert client.ws_url == "ws://example.com:8080" + assert client.ws_url == "wss://example.com:8080" def test_base_url_localhost(self): client = Client("localhost", 443, "user", "pass") - assert client.base_url == "http://localhost:443" + assert client.base_url == "https://localhost:443" + + def test_no_tls_urls(self): + client = Client("example.com", 8080, "user", "pass", no_tls=True) + assert client.base_url == "http://example.com:8080" + assert client.ws_url == "ws://example.com:8080" def test_password_encoding_unicode(self): client = Client("localhost", 3000, "user", "пароль123") @@ -84,7 +89,7 @@ class TestSRPAuthentication: verify_response = MagicMock() verify_response.json.return_value = { "H_AMK": base64.b64encode(os.urandom(32)).decode(), - "session_key": base64.b64encode(Fernet.generate_key()).decode(), + "ws_token": "test-ws-token-hex", } verify_response.raise_for_status = MagicMock() @@ -102,7 +107,7 @@ class TestSRPAuthentication: assert client.user_id == "test-user-id-12345" assert client.room_fernet is not None - assert client.fernet is not None + assert client.ws_token == "test-ws-token-hex" @patch("cmd_chat.client.client.requests.post") def test_srp_authenticate_init_fails(self, mock_post, client): @@ -178,7 +183,7 @@ class TestSRPAuthentication: verify_response = MagicMock() verify_response.json.return_value = { "H_AMK": base64.b64encode(os.urandom(32)).decode(), - "session_key": base64.b64encode(Fernet.generate_key()).decode(), + "ws_token": "test-ws-token-hex", } verify_response.raise_for_status = MagicMock() @@ -237,7 +242,6 @@ class TestDecryptMessage: "username": "sender", "timestamp": "2024-01-01T12:00:00", "id": "msg-123", - "user_ip": "192.168.1.1", } decrypted = client.decrypt_message(msg) @@ -246,7 +250,6 @@ class TestDecryptMessage: assert decrypted["username"] == "sender" assert decrypted["timestamp"] == "2024-01-01T12:00:00" assert decrypted["id"] == "msg-123" - assert decrypted["user_ip"] == "192.168.1.1" def test_decrypt_wrong_key_marks_failed(self, client): @@ -553,6 +556,7 @@ class TestRunAsync: @pytest.mark.asyncio async def test_run_successful_connection_and_disconnect(self, client): client.user_id = "test-id-123" + client.ws_token = "test-token" with patch.object(client, "srp_authenticate"): with patch("cmd_chat.client.client.websockets.connect") as mock_connect: @@ -844,8 +848,8 @@ class TestEdgeCases: def test_port_zero(self): client = Client("localhost", 0, "user", "pass") assert client.port == 0 - assert client.base_url == "http://localhost:0" + assert client.base_url == "https://localhost:0" def test_ipv6_server(self): client = Client("::1", 3000, "user", "pass") - assert client.base_url == "http://::1:3000" + assert client.base_url == "https://::1:3000"