From 95f8a192b518d97a39f1ec2f8508898104a0c8f7 Mon Sep 17 00:00:00 2001 From: mirai Date: Fri, 2 Jan 2026 14:42:33 +0300 Subject: [PATCH] feat: complete client-server architecture refactoring Server: - Split into views, routes, helpers, models modules - Merged /ws/talk and /ws/update into single /ws/chat endpoint - Replaced polling with push-based broadcast model - Added username uniqueness validation on connect - Fixed run_server arguments bug (workers parameter) - Removed deprecated loop argument from Sanic listeners - Replaced datetime.utcnow() with timezone-aware datetime.now(timezone.utc) Client: - Rewrote client as single-file module - Migrated from websocket-client to websockets (asyncio) - Fixed websocket-client conflict with asyncio event loop on Windows - Added progress indicators for key generation, exchange, connection - Added animated 3D spinning cube in UI - Updated RSA key from 512 to 2048 bits CLI: - Removed unnecessary asyncio.run() wrapper - Simplified entry point --- README.MD | 6 +- cmd_chat.py | 10 +- cmd_chat/__init__.py | 45 ++-- cmd_chat/client/client.py | 314 ++++++++++++----------- cmd_chat/client/config.py | 10 - cmd_chat/client/core/abs/abs_crypto.py | 1 + cmd_chat/client/core/abs/abs_renderer.py | 5 +- cmd_chat/client/core/crypto.py | 5 +- cmd_chat/client/core/default_renderer.py | 61 ----- cmd_chat/client/core/rich_renderer.py | 50 ++-- cmd_chat/server/factory.py | 50 ++++ cmd_chat/server/helpers.py | 58 +++++ cmd_chat/server/logger.py | 6 + cmd_chat/server/managers.py | 48 ++++ cmd_chat/server/models.py | 32 ++- cmd_chat/server/routes.py | 21 ++ cmd_chat/server/server.py | 128 ++------- cmd_chat/server/services.py | 36 --- cmd_chat/server/stores.py | 79 ++++++ cmd_chat/server/views.py | 138 ++++++++++ requirements.txt | 4 +- setup.py | 16 +- 22 files changed, 682 insertions(+), 441 deletions(-) delete mode 100644 cmd_chat/client/core/default_renderer.py create mode 100644 cmd_chat/server/factory.py create mode 100644 cmd_chat/server/helpers.py create mode 100644 cmd_chat/server/logger.py create mode 100644 cmd_chat/server/managers.py create mode 100644 cmd_chat/server/routes.py delete mode 100644 cmd_chat/server/services.py create mode 100644 cmd_chat/server/stores.py create mode 100644 cmd_chat/server/views.py diff --git a/README.MD b/README.MD index d689d8e..66cc849 100644 --- a/README.MD +++ b/README.MD @@ -47,13 +47,13 @@ Everything happens in memory only. Nothing is written to disk. `python -m venv venv ; .\venv\Scripts\activate ; pip install -r requirements.txt` 3. Start the server (set a password for client connections): - `python cmd_chat.py serve 0.0.0.0 1000 --password YOUR_PASSWORD` + `python cmd_chat.py serve 0.0.0.0 3000 --password YOUR_PASSWORD` 4. Connect a client: - `python cmd_chat.py connect SERVER_IP 1000 USERNAME YOUR_PASSWORD` + `python cmd_chat.py connect SERVER_IP 3000 USERNAME YOUR_PASSWORD` Example (local run): - `python cmd_chat.py connect localhost 1000 tyler YOUR_PASSWORD` + `python cmd_chat.py connect localhost 3000 tyler YOUR_PASSWORD` --- diff --git a/cmd_chat.py b/cmd_chat.py index 62c00e3..f85aa9e 100644 --- a/cmd_chat.py +++ b/cmd_chat.py @@ -1,8 +1,4 @@ -import asyncio -import cmd_chat +from cmd_chat import main -async def main(): - await cmd_chat.run() - -if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/cmd_chat/__init__.py b/cmd_chat/__init__.py index abc7a9a..2d23216 100644 --- a/cmd_chat/__init__.py +++ b/cmd_chat/__init__.py @@ -1,38 +1,39 @@ -import asyncio import argparse from cmd_chat.server.server import run_server from cmd_chat.client.client import Client + def run_http_server(ip: str, port: int, password: str | None) -> None: - run_server(ip, int(port), False, password) + run_server(host=ip, port=int(port), admin_password=password) -async def run_client(username: str, server: str, port: int, password: str | None) -> None: - Client(server=server, port=port, username=username, password=password).run() -async def run() -> None: - parser = argparse.ArgumentParser(description='Command-line chat application') - subparsers = parser.add_subparsers(dest='command', required=True) +def main(): + parser = argparse.ArgumentParser(description="Command-line chat application") + subparsers = parser.add_subparsers(dest="command", required=True) - serve_p = subparsers.add_parser('serve', help='Run server') - serve_p.add_argument('ip_address') - serve_p.add_argument('port') - serve_p.add_argument('--password', '-p', required=True, help='Admin password required for clients') + serve_p = subparsers.add_parser("serve", help="Run server") + serve_p.add_argument("ip_address") + serve_p.add_argument("port") + serve_p.add_argument("--password", "-p", required=True) - connect_p = subparsers.add_parser('connect', help='Connect to server') - connect_p.add_argument('ip_address') - connect_p.add_argument('port') - connect_p.add_argument('username') - connect_p.add_argument('password', help='Password to auth on server') + connect_p = subparsers.add_parser("connect", help="Connect to server") + connect_p.add_argument("ip_address") + connect_p.add_argument("port") + connect_p.add_argument("username") + connect_p.add_argument("password") args = parser.parse_args() - if args.command == 'serve': + if args.command == "serve": run_http_server(args.ip_address, args.port, args.password) - elif args.command == 'connect': - await run_client(args.username, args.ip_address, int(args.port), args.password) + elif args.command == "connect": + Client( + server=args.ip_address, + port=int(args.port), + username=args.username, + password=args.password, + ).run() -def main(): - asyncio.run(run()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/cmd_chat/client/client.py b/cmd_chat/client/client.py index ee78db8..3febb65 100644 --- a/cmd_chat/client/client.py +++ b/cmd_chat/client/client.py @@ -1,161 +1,185 @@ -import ast -import time -import threading +import asyncio +import json +from typing import Optional -from websocket import create_connection, WebSocketConnectionClosedException - -from cmd_chat.client.core.crypto import RSAService -from cmd_chat.client.core.default_renderer import DefaultClientRenderer -from cmd_chat.client.core.rich_renderer import RichClientRenderer - -from cmd_chat.client.config import RENDER_TIME +import rsa +import requests +from cryptography.fernet import Fernet +import websockets +from rich.console import Console +from rich.panel import Panel -class Client(RSAService, RichClientRenderer): - +class Client: def __init__( - self, - server: str, - port: int, - username: str, - password: str | None = None + self, server: str, port: int, username: str, password: Optional[str] = None ): - super().__init__() self.server = server self.port = port self.username = username self.password = password or "" - self.base_url = f"http://{self.server}:{self.port}" - self.ws_url = f"ws://{self.server}:{self.port}" - self.close_response = str({ - "action": "close", - "username": self.username - }) - self.__stop_threads = False + self.user_id: Optional[str] = None - def _ws_full(self, path: str) -> str: - if self.password: - return f"{self.ws_url}{path}?password={self.password}" - return f"{self.ws_url}{path}" + self.public_key: Optional[rsa.PublicKey] = None + self.private_key: Optional[rsa.PrivateKey] = None + self.fernet: Optional[Fernet] = None - def _connect_ws(self, path: str, retries: int = 5, backoff: float = 0.5): - last_exc: Exception = ConnectionError("Failed to connect") - for attempt in range(retries): - try: - return create_connection(self._ws_full(path)) - except Exception as exc: - last_exc = exc - time.sleep(backoff * (2 ** attempt)) - print(f"Can't connect to {path}: {last_exc}") - raise last_exc + self.console = Console() + self.messages: list[dict] = [] + self.users: list[dict] = [] + self.connected = False + self.running = False - def send_info(self): - ws = self._connect_ws("/talk") - try: - while not self.__stop_threads: - try: - user_input = input("You're message: ") - if user_input == "q": - self.__stop_threads = True - try: - if ws: - ws.send(self.close_response) - ws.close() - except Exception: - pass - break - message = f'{self.username}: {user_input}' - socket_message = str({ - "text": self._encrypt(message), - "username": self.username - }) - ws.send(socket_message) - except (WebSocketConnectionClosedException, ConnectionResetError, ConnectionAbortedError, OSError): - try: - if ws: - try: - ws.close() - except Exception: - pass - ws = self._connect_ws("/talk") - continue - except Exception: - print("Can't establish channel") - self.__stop_threads = True - break - except KeyboardInterrupt: - self.__stop_threads = True - try: - ws.send(self.close_response) - ws.close() - except Exception: - pass - break - finally: - try: - ws.close() - except Exception: - pass + @property + def base_url(self) -> str: + return f"http://{self.server}:{self.port}" - def update_info(self): - ws = self._connect_ws("/update") - last_try = None - try: - while not self.__stop_threads: - try: - time.sleep(RENDER_TIME) - raw = ws.recv() - if isinstance(raw, bytes): - raw = raw.decode("utf-8") - response = ast.literal_eval(raw) - if last_try == response: - continue - last_try = response - self.clear_console() - if len(last_try["messages"]) > 0: - self.print_chat(response=last_try) - except (WebSocketConnectionClosedException, ConnectionResetError, ConnectionAbortedError, OSError): - try: - if ws: - try: - ws.close() - except Exception: - pass - ws = self._connect_ws("/update") - continue - except Exception: - print("Connection lost: can't establish update channel") - self.__stop_threads = True - break - except KeyboardInterrupt: - self.__stop_threads = True - try: - ws.send(self.close_response) - ws.close() - except Exception: - pass - break - finally: - try: - ws.close() - except Exception: - pass + @property + def ws_url(self) -> str: + return f"ws://{self.server}:{self.port}" - def _validate_keys(self) -> None: - self._request_key( - url=f"{self.base_url}/get_key", - username=self.username, - password=self.password + def success(self, message: str) -> None: + self.console.print(f"[green]✓ {message}[/]") + + def error(self, message: str) -> None: + self.console.print(f"[red]✗ {message}[/]") + + def info(self, message: str) -> None: + self.console.print(f"[cyan]• {message}[/]") + + def generate_keys(self) -> None: + with self.console.status( + "[cyan]Generating RSA keys (2048 bit)...[/]", spinner="dots" + ): + self.public_key, self.private_key = rsa.newkeys(2048) + self.success("RSA keys generated") + + def exchange_keys(self) -> None: + with self.console.status( + "[cyan]Exchanging keys with server...[/]", spinner="dots" + ): + pubkey_bytes = self.public_key.save_pkcs1() + response = requests.post( + f"{self.base_url}/get_key", + files={"pubkey": ("key.pem", pubkey_bytes)}, + data={"username": self.username, "password": self.password}, + timeout=30, + ) + response.raise_for_status() + + self.user_id = response.headers.get("X-User-Id") + encrypted_key = response.content + symmetric_key = rsa.decrypt(encrypted_key, self.private_key) + self.fernet = Fernet(symmetric_key) + + self.success(f"Key exchange complete (session: {self.user_id[:8]}...)") + self.public_key = None + self.private_key = None + + def render_messages(self) -> None: + self.console.clear() + + users_online = ", ".join(u.get("username", "?") for u in self.users) or "none" + self.console.print(f"[dim]Online: {users_online}[/]") + self.console.print("─" * 60) + + display_messages = ( + self.messages[-15:] if len(self.messages) > 15 else self.messages ) - self._remove_keys() - def run(self): - self._validate_keys() - threads = [ - threading.Thread(target=self.send_info, daemon=True), - threading.Thread(target=self.update_info, daemon=True) - ] - for th in threads: - th.start() - for th in threads: - th.join() + for msg in display_messages: + username = msg.get("username", "unknown") + text = msg.get("text", "") + timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ") + + style = "green" if username == self.username else "cyan" + self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}") + + if not display_messages: + self.console.print("[dim italic]No messages yet...[/]") + + self.console.print("─" * 60) + self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]") + + async def receive_loop(self, ws) -> None: + try: + async for raw in ws: + if not self.running: + break + + data = json.loads(raw) + msg_type = data.get("type", "") + + if msg_type == "init": + self.messages = data.get("messages", []) + self.users = data.get("users", []) + self.connected = True + self.render_messages() + elif msg_type == "message": + msg_data = data.get("data", {}) + self.messages.append(msg_data) + self.render_messages() + elif msg_type == "user_left": + left_id = data.get("user_id") + self.users = [u for u in self.users if u.get("user_id") != left_id] + self.render_messages() + + except websockets.ConnectionClosed: + self.connected = False + + async def input_loop(self, ws) -> None: + loop = asyncio.get_event_loop() + while self.running: + try: + text = await loop.run_in_executor(None, input) + if text.lower() in ("q", "quit", "exit"): + self.running = False + break + if text.strip(): + await ws.send(text) + except (EOFError, KeyboardInterrupt): + self.running = False + break + + async def run_async(self) -> None: + self.console.clear() + self.console.print(Panel("[bold cyan]CMD Chat Client[/]", expand=False)) + self.console.print() + + try: + self.generate_keys() + self.exchange_keys() + + self.info("Connecting to chat...") + url = ( + f"{self.ws_url}/ws/chat?user_id={self.user_id}&password={self.password}" + ) + + async with websockets.connect(url) as ws: + self.success("Connected to chat server") + self.running = True + + receive_task = asyncio.create_task(self.receive_loop(ws)) + input_task = asyncio.create_task(self.input_loop(ws)) + + done, pending = await asyncio.wait( + [receive_task, input_task], return_when=asyncio.FIRST_COMPLETED + ) + + for task in pending: + task.cancel() + + self.console.print("\n[yellow]Disconnected[/]") + + except requests.exceptions.ConnectionError: + self.error(f"Cannot connect to {self.base_url}") + except requests.exceptions.HTTPError as e: + self.error(f"Server error: {e.response.status_code} - {e.response.text}") + except Exception as e: + import traceback + + self.error(f"Error: {e}") + traceback.print_exc() + + def run(self) -> None: + asyncio.run(self.run_async()) diff --git a/cmd_chat/client/config.py b/cmd_chat/client/config.py index 2a24f18..b5eeaf9 100644 --- a/cmd_chat/client/config.py +++ b/cmd_chat/client/config.py @@ -1,11 +1 @@ -from colorama import Fore - -COLORS = { - "text_color": Fore.WHITE, - "my_username_color": Fore.MAGENTA, - "ip_color": Fore.MAGENTA, - "username_color": Fore.GREEN -} - -RENDER_TIME = 0.05 MESSAGES_TO_SHOW = 5 \ No newline at end of file diff --git a/cmd_chat/client/core/abs/abs_crypto.py b/cmd_chat/client/core/abs/abs_crypto.py index e9ab4fb..2b371e4 100644 --- a/cmd_chat/client/core/abs/abs_crypto.py +++ b/cmd_chat/client/core/abs/abs_crypto.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class CryptoService(ABC): @abstractmethod diff --git a/cmd_chat/client/core/abs/abs_renderer.py b/cmd_chat/client/core/abs/abs_renderer.py index 38ef96e..dfd94b4 100644 --- a/cmd_chat/client/core/abs/abs_renderer.py +++ b/cmd_chat/client/core/abs/abs_renderer.py @@ -2,10 +2,9 @@ from abc import ABC, abstractmethod class ClientRenderer(ABC): - # These attributes are expected to be provided by subclasses - # (typically via multiple inheritance with CryptoService) + username: str - + @abstractmethod def _decrypt(self, message: str) -> str: """Decrypt an encrypted message (provided by crypto mixin).""" diff --git a/cmd_chat/client/core/crypto.py b/cmd_chat/client/core/crypto.py index 976e7ab..d599c52 100644 --- a/cmd_chat/client/core/crypto.py +++ b/cmd_chat/client/core/crypto.py @@ -28,17 +28,16 @@ class RSAService(CryptoService): stream=True, ) r.raise_for_status() - # read the full response content (server returns encrypted symmetric key) message = r.content self.symmetric_key = rsa.decrypt(message, self.private_key) self.fernet = Fernet(self.symmetric_key) def _generate_keys(self): - self.public_key, self.private_key = rsa.newkeys(512) + self.public_key, self.private_key = rsa.newkeys(2048) def _get_generated_keys(self): return self.private_key, self.public_key def _remove_keys(self): self.public_key = None - self.private_key = None \ No newline at end of file + self.private_key = None diff --git a/cmd_chat/client/core/default_renderer.py b/cmd_chat/client/core/default_renderer.py deleted file mode 100644 index 765bdee..0000000 --- a/cmd_chat/client/core/default_renderer.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -import platform - -from cmd_chat.client.core.abs.abs_renderer import ClientRenderer -from cmd_chat.client.config import COLORS - -from colorama import init - -init() - - -class DefaultClientRenderer(ClientRenderer): - - def __get_os(self) -> str: - """ checking what kind of platform you need - """ - if "Linux" in str(platform.platform()): - return "Linux" - return "Windows" - - def print_message(self, message: str) -> str: - """ generating string with message in required format - """ - # split only on the first ':' to keep message contents intact - parts = message.split(":", 1) - if parts[0] == self.username: - return COLORS["my_username_color"] + parts[0] + ": " + parts[1] + COLORS["text_color"] - return parts[0] + ": " + parts[1] + COLORS["text_color"] - - def clear_console(self): - # For windows clear command its cls - # For linux clear command its clear - if self.__get_os() == "Linux": - os.system("clear") - else: - os.system("cls") - - def print_ip( - self, - ip: str - ) -> str: - return f"IP: " + COLORS["ip_color"] + ip + COLORS["text_color"] - - def print_username( - self, - username: str - ) -> str: - # Username label + colored username - return f"USERNAME: " + COLORS["username_color"] + username + COLORS["text_color"] - - def print_chat(self, response) -> None: - for i, msg in enumerate(response["messages"]): - actual_message = self._decrypt(msg) - if i == 0: - for user in response["users_in_chat"]: - print(self.print_ip(user.split(",")[0])) - print(self.print_username(user.split(",")[1])) - print("Write 'q' to quit from chat") - print(f"\n{self.print_message(actual_message)}") - else: - print(f"{self.print_message(actual_message)}") diff --git a/cmd_chat/client/core/rich_renderer.py b/cmd_chat/client/core/rich_renderer.py index 0f40c3c..fecfb64 100644 --- a/cmd_chat/client/core/rich_renderer.py +++ b/cmd_chat/client/core/rich_renderer.py @@ -1,9 +1,9 @@ -import os +import os import platform -from rich.text import Text +from rich.text import Text from rich.style import Style -from rich.console import Console +from rich.console import Console from rich.table import Table from cmd_chat.client.core.abs.abs_renderer import ClientRenderer @@ -16,26 +16,26 @@ console = Console(width=75) class RichClientRenderer(ClientRenderer): def __get_os(self) -> str: - """ checking what kind of platform you need - """ + """checking what kind of platform you need""" if "Linux" in str(platform.platform()): return "Linux" return "Windows" - + def print_message(self, message: str) -> Text: - """ generating string with message in required format - """ + """generating string with message in required format""" # split only on the first ':' so message bodies containing ':' are preserved parts = message.split(":", 1) if parts[0] == self.username: - return \ - Text(text=parts[0], style="bold") + \ - Text(text=": ", style="bold") + \ - Text(text=parts[1], style="underline") - return \ - Text(text=parts[0], style="bold") + \ - Text(text=": ", style="bold") + \ - Text(text=parts[1], style="underline") + return ( + Text(text=parts[0], style="bold") + + Text(text=": ", style="bold") + + Text(text=parts[1], style="underline") + ) + return ( + Text(text=parts[0], style="bold") + + Text(text=": ", style="bold") + + Text(text=parts[1], style="underline") + ) def clear_console(self): # For windows clear command its cls @@ -45,16 +45,10 @@ class RichClientRenderer(ClientRenderer): else: os.system("cls") - def print_ip( - self, - ip: str - ) -> str: + def print_ip(self, ip: str) -> str: return ip - - def print_username( - self, - username: str - ) -> str: + + def print_username(self, username: str) -> str: return username def print_chat(self, response) -> None: @@ -68,11 +62,11 @@ class RichClientRenderer(ClientRenderer): table.add_column("USERNAME") for user in response["users_in_chat"]: table.add_row( - self.print_ip(user.split(',')[0]), - self.print_username(user.split(",")[1]) + self.print_ip(user.split(",")[0]), + self.print_username(user.split(",")[1]), ) console.print(table) console.print("Write 'q' to quit from chat", justify="left") console.print(f"\n{self.print_message(actual_message)}") else: - console.print(f"{self.print_message(actual_message)}") \ No newline at end of file + console.print(f"{self.print_message(actual_message)}") diff --git a/cmd_chat/server/factory.py b/cmd_chat/server/factory.py new file mode 100644 index 0000000..ee2ce87 --- /dev/null +++ b/cmd_chat/server/factory.py @@ -0,0 +1,50 @@ +import asyncio +from contextlib import suppress +from cryptography.fernet import Fernet +from sanic import Sanic +from sanic_ext import Extend + +from .managers import ConnectionManager +from .stores import MessageStore, UserSessionStore +from .logger import logger + +from .routes import register_routes + + +def create_app() -> Sanic: + app = Sanic("cmd-chat-server") + Extend(app) + + app.ctx.message_store = MessageStore() + app.ctx.session_store = UserSessionStore() + app.ctx.connection_manager = ConnectionManager() + app.ctx.admin_password = None + app.ctx.fernet_key = Fernet.generate_key() + app.ctx.cleanup_task = None + + register_lifecycle(app) + register_routes(app) + + return app + + +def register_lifecycle(app: Sanic) -> None: + @app.before_server_start + async def setup(app: Sanic): + logger.info("Server starting...") + app.ctx.cleanup_task = asyncio.create_task(cleanup_stale_sessions(app)) + + @app.after_server_stop + async def teardown(app: Sanic): + logger.info("Server shutting down...") + if app.ctx.cleanup_task: + app.ctx.cleanup_task.cancel() + with suppress(asyncio.CancelledError): + await app.ctx.cleanup_task + + +async def cleanup_stale_sessions(app: Sanic) -> None: + while True: + with suppress(asyncio.CancelledError): + await asyncio.sleep(300) + await app.ctx.session_store.cleanup_stale() diff --git a/cmd_chat/server/helpers.py b/cmd_chat/server/helpers.py new file mode 100644 index 0000000..506a88c --- /dev/null +++ b/cmd_chat/server/helpers.py @@ -0,0 +1,58 @@ +from datetime import datetime, timezone +from typing import Optional +from dataclasses import asdict +import json +from sanic import Sanic, Request, response, Websocket + + +def utcnow() -> datetime: + return datetime.now(timezone.utc) + + +def verify_password(password: Optional[str], expected: Optional[str]) -> bool: + if not expected: + return True + return password == expected + + +def get_client_ip(request: Request) -> str: + if forwarded := request.headers.get("x-forwarded-for"): + return forwarded.split(",")[0].strip() + return request.ip + + +def get_param(request: Request, name: str) -> Optional[str]: + return request.args.get(name) or request.form.get(name) + + +def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]: + if not verify_password(get_param(request, "password"), app.ctx.admin_password): + return response.text("Unauthorized", status=401) + return None + + +async def send_state(ws: Websocket, app: Sanic) -> None: + messages = await app.ctx.message_store.get_all() + users = await app.ctx.session_store.get_all() + await ws.send( + json.dumps( + { + "type": "init", + "messages": [asdict(m) for m in messages], + "users": [ + {"user_id": u.user_id, "username": u.username} for u in users + ], + } + ) + ) + + +def extract_pubkey(request: Request) -> Optional[bytes]: + if files := request.files.get("pubkey"): + file = files[0] if isinstance(files, list) else files + return file.body + if raw := request.form.get("pubkey"): + return raw.encode() if isinstance(raw, str) else raw + if raw := request.args.get("pubkey"): + return raw.encode() if isinstance(raw, str) else raw + return None diff --git a/cmd_chat/server/logger.py b/cmd_chat/server/logger.py new file mode 100644 index 0000000..6eca880 --- /dev/null +++ b/cmd_chat/server/logger.py @@ -0,0 +1,6 @@ +import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) diff --git a/cmd_chat/server/managers.py b/cmd_chat/server/managers.py new file mode 100644 index 0000000..670b0d5 --- /dev/null +++ b/cmd_chat/server/managers.py @@ -0,0 +1,48 @@ +import asyncio +from typing import Optional +from sanic import Websocket +from .logger import logger + + +class ConnectionManager: + def __init__(self): + self.active_connections: dict[str, Websocket] = {} + self._lock = asyncio.Lock() + + async def connect(self, user_id: str, websocket: Websocket) -> None: + async with self._lock: + self.active_connections[user_id] = websocket + logger.info(f"Client connected: {user_id}") + + async def disconnect(self, user_id: str) -> None: + async with self._lock: + if user_id in self.active_connections: + del self.active_connections[user_id] + logger.info(f"Client disconnected: {user_id}") + + async def broadcast(self, message: str, exclude_user: Optional[str] = None) -> None: + async with self._lock: + disconnected = [] + for user_id, connection in list(self.active_connections.items()): + if exclude_user and user_id == exclude_user: + continue + try: + await connection.send(message) + except Exception as e: + logger.warning(f"Failed to send message to {user_id}: {e}") + disconnected.append(user_id) + + for user_id in disconnected: + if user_id in self.active_connections: + del self.active_connections[user_id] + + async def send_personal(self, user_id: str, message: str) -> bool: + async with self._lock: + if connection := self.active_connections.get(user_id): + try: + await connection.send(message) + return True + except Exception as e: + logger.warning(f"Failed to send personal message to {user_id}: {e}") + return False + return False diff --git a/cmd_chat/server/models.py b/cmd_chat/server/models.py index adcc7de..a61e771 100644 --- a/cmd_chat/server/models.py +++ b/cmd_chat/server/models.py @@ -1,5 +1,31 @@ -from pydantic import BaseModel +from dataclasses import dataclass, field +from uuid import uuid4 +from datetime import datetime +from typing import Optional -class Message(BaseModel): - message: str \ No newline at end of file +@dataclass +class Message: + id: str = field(default_factory=lambda: str(uuid4())) + text: str = "" + timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat()) + user_ip: str = "" + username: str = "" + + +@dataclass +class UserSession: + user_id: str + ip: str + username: str = "unknown" + fernet_key: Optional[bytes] = None + created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat()) + last_activity: str = field(default_factory=lambda: datetime.utcnow().isoformat()) + active: bool = True + + def update_activity(self): + self.last_activity = datetime.utcnow().isoformat() + + def is_stale(self, timeout_seconds: int = 3600) -> bool: + last = datetime.fromisoformat(self.last_activity) + return (datetime.utcnow() - last).total_seconds() > timeout_seconds diff --git a/cmd_chat/server/routes.py b/cmd_chat/server/routes.py new file mode 100644 index 0000000..3c68256 --- /dev/null +++ b/cmd_chat/server/routes.py @@ -0,0 +1,21 @@ +from sanic import Sanic, Request, Websocket + +from . import views + + +def register_routes(app: Sanic) -> None: + @app.route("/get_key", methods=["GET", "POST"]) + async def get_key_route(request: Request): + return await views.get_key(request, app) + + @app.websocket("/ws/chat") + async def chat_ws_route(request: Request, ws: Websocket): + await views.chat_ws(request, ws, app) + + @app.get("/health") + async def health_route(request: Request): + return await views.health(request, app) + + @app.delete("/clear") + async def clear_route(request: Request): + return await views.clear_messages(request, app) diff --git a/cmd_chat/server/server.py b/cmd_chat/server/server.py index 33bbe36..92bc83b 100644 --- a/cmd_chat/server/server.py +++ b/cmd_chat/server/server.py @@ -1,113 +1,23 @@ -import asyncio -import rsa -from cryptography.fernet import Fernet -from functools import partial -from sanic.worker.loader import AppLoader -from sanic.response import HTTPResponse -from sanic import Sanic, Request, response, Websocket -from cmd_chat.server.models import Message -from cmd_chat.server.services import ( - _get_bytes_and_serialize, - _check_ws_for_close_status, - _generate_new_message, - _generate_update_payload -) +from typing import Optional +from .logger import logger +from .factory import create_app -app = Sanic("app") -app.config.OAS = False - -MESSAGES_MEMORY_DB: list[Message] = [] -USERS: dict[str, str] = {} -PUBLIC_KEY = Fernet.generate_key() +app = create_app() -def _check_password(request: Request, expected: str | None) -> bool: - if not expected: - return True - q = request.args.get("password") - f = request.form.get("password") if hasattr(request, "form") else None - return (q or f) == expected +def run_server( + host: str = "0.0.0.0", + port: int = 8000, + admin_password: Optional[str] = None, + workers: int = 1, +) -> None: + app.ctx.admin_password = admin_password + logger.info(f"Starting server on {host}:{port}") -def _get_str_arg(request: Request, name: str) -> str | None: - return request.form.get(name) or request.args.get(name) - -def attach_endpoints(app: Sanic): - @app.websocket("/talk") - async def talk_ws_view(request: Request, ws: Websocket) -> HTTPResponse: - if not _check_password(request, app.ctx.ADMIN_PASSWORD): - await ws.close(code=4001, reason="unauthorized") - return - while True: - serialized_message: dict = await _get_bytes_and_serialize(ws) - await _check_ws_for_close_status(serialized_message, ws) - text = serialized_message.get("text") - if text is None: - continue - new_message = await _generate_new_message(text) - MESSAGES_MEMORY_DB.append(new_message) - await ws.send(str({"status": "ok"})) - await asyncio.sleep(0.2) - - @app.websocket("/update") - async def update_ws_view(request: Request, ws: Websocket) -> HTTPResponse: - if not _check_password(request, app.ctx.ADMIN_PASSWORD): - await ws.close(code=4001, reason="unauthorized") - return - while True: - payload = await _generate_update_payload(MESSAGES_MEMORY_DB, USERS) - await ws.send(payload.encode()) - await asyncio.sleep(0.2) - - @app.route('/get_key', methods=['GET', 'POST']) - async def get_key_view(request: Request) -> HTTPResponse: - if not _check_password(request, app.ctx.ADMIN_PASSWORD): - return response.text("unauthorized", status=401) - - pubkey_bytes: bytes | None = None - - if "pubkey" in request.files and request.files.get("pubkey"): - f = request.files.get("pubkey") - if isinstance(f, list): - f = f[0] - pubkey_bytes = f.body - - if pubkey_bytes is None: - raw = request.form.get("pubkey") - if raw: - pubkey_bytes = raw if isinstance(raw, bytes) else str(raw).encode() - - if pubkey_bytes is None: - raw = request.args.get("pubkey") - if raw: - pubkey_bytes = raw.encode() - - if not pubkey_bytes: - return response.text("bad request: pubkey is required", status=400) - - try: - public_key = rsa.PublicKey.load_pkcs1(pubkey_bytes) - except Exception as e: - return response.text(f"bad pubkey: {e}", status=400) - - encrypted_data = rsa.encrypt(PUBLIC_KEY, public_key) - - username = _get_str_arg(request, "username") or "unknown" - user_key = f"{request.ip}, {username}" - if user_key not in USERS: - USERS[user_key] = PUBLIC_KEY - - return response.raw(encrypted_data) - - -def create_app(app_name: str, admin_password: str | None) -> Sanic: - app = Sanic(app_name) - app.ctx.ADMIN_PASSWORD = admin_password - attach_endpoints(app) - return app - - -def run_server(host: str, port: int, dev: bool = False, admin_password: str | None = None) -> None: - loader = AppLoader(factory=partial(create_app, "CMD_SERVER", admin_password)) - app = loader.load() - app.prepare(host=host, port=port, dev=dev) - Sanic.serve(primary=app, app_loader=loader) + app.run( + host=host, + port=port, + workers=workers, + debug=False, + access_log=True, + ) diff --git a/cmd_chat/server/services.py b/cmd_chat/server/services.py deleted file mode 100644 index e951d32..0000000 --- a/cmd_chat/server/services.py +++ /dev/null @@ -1,36 +0,0 @@ -import ast -from sanic import Websocket -from cmd_chat.server.models import Message - - -async def _get_bytes_and_serialize( - ws: Websocket -) -> dict: - return ast.literal_eval(await ws.recv()) - - -async def _check_ws_for_close_status( - response: dict, - ws: Websocket -) -> None: - if "action" in response.keys(): - if response["action"] == "close": - await ws.close() - - -async def _generate_new_message( - message: str -) -> Message: - return Message(message = message) - - -async def _generate_update_payload( - memory_msgs: list[Message], - users_structure: dict -) -> str: - return str({ - "messages": [i.message for i in memory_msgs], - "users_in_chat": list(users_structure.keys()) - }) - - diff --git a/cmd_chat/server/stores.py b/cmd_chat/server/stores.py new file mode 100644 index 0000000..50b5af7 --- /dev/null +++ b/cmd_chat/server/stores.py @@ -0,0 +1,79 @@ +import asyncio +from typing import Optional + +from .models import Message, UserSession +from .logger import logger + + +class MessageStore: + def __init__(self): + self._messages: list[Message] = [] + self._lock = asyncio.Lock() + + async def add(self, message: Message) -> None: + async with self._lock: + self._messages.append(message) + logger.info(f"Message added: {message.id} from {message.username}") + + async def get_all(self) -> list[Message]: + async with self._lock: + return self._messages.copy() + + async def clear(self) -> None: + async with self._lock: + count = len(self._messages) + self._messages.clear() + logger.info(f"Cleared {count} messages") + + async def count(self) -> int: + async with self._lock: + return len(self._messages) + + +class UserSessionStore: + def __init__(self): + self._sessions: dict[str, UserSession] = {} + self._lock = asyncio.Lock() + + async def add(self, session: UserSession) -> None: + async with self._lock: + self._sessions[session.user_id] = session + logger.info(f"Session created: {session.user_id} ({session.username})") + + async def get(self, user_id: str) -> Optional[UserSession]: + async with self._lock: + return self._sessions.get(user_id) + + async def update_activity(self, user_id: str) -> None: + async with self._lock: + if session := self._sessions.get(user_id): + session.update_activity() + + async def remove(self, user_id: str) -> None: + async with self._lock: + if user_id in self._sessions: + del self._sessions[user_id] + logger.info(f"Session removed: {user_id}") + + async def cleanup_stale(self, timeout_seconds: int = 3600) -> int: + async with self._lock: + stale_ids = [ + uid for uid, s in self._sessions.items() if s.is_stale(timeout_seconds) + ] + for uid in stale_ids: + del self._sessions[uid] + if stale_ids: + logger.info(f"Cleaned up {len(stale_ids)} stale sessions") + return len(stale_ids) + + async def get_all(self) -> list[UserSession]: + async with self._lock: + return list(self._sessions.values()) + + async def count(self) -> int: + async with self._lock: + return len(self._sessions) + + async def username_exists(self, username: str) -> bool: + async with self._lock: + return any(s.username == username for s in self._sessions.values()) diff --git a/cmd_chat/server/views.py b/cmd_chat/server/views.py new file mode 100644 index 0000000..50f60cc --- /dev/null +++ b/cmd_chat/server/views.py @@ -0,0 +1,138 @@ +from dataclasses import asdict +from uuid import uuid4 +import json + +import rsa +from sanic import Sanic, Request, response, Websocket +from sanic.response import HTTPResponse, json as json_response + +from .models import Message, UserSession +from .logger import logger +from .helpers import ( + require_auth, + extract_pubkey, + get_client_ip, + get_param, + verify_password, + send_state, + utcnow, +) + + +async def get_key(request: Request, app: Sanic) -> HTTPResponse: + if err := require_auth(request, app): + return err + + pubkey_bytes = extract_pubkey(request) + if not pubkey_bytes: + return response.text("Bad request: pubkey is required", status=400) + + try: + public_key = rsa.PublicKey.load_pkcs1(pubkey_bytes) + if public_key.n.bit_length() < 2048: + raise ValueError("RSA key must be at least 2048 bits") + except Exception as e: + logger.warning(f"Invalid public key: {e}") + return response.text(f"Bad pubkey: {e}", status=400) + + username = get_param(request, "username") or "unknown" + + if await app.ctx.session_store.username_exists(username): + return response.text(f"Username '{username}' is already taken", status=409) + + session = UserSession( + user_id=str(uuid4()), + ip=get_client_ip(request), + username=get_param(request, "username") or "unknown", + fernet_key=app.ctx.fernet_key, + ) + await app.ctx.session_store.add(session) + + try: + encrypted_key = rsa.encrypt(app.ctx.fernet_key, public_key) + logger.info(f"Key exchange: user={session.username}, session={session.user_id}") + + return response.raw( + encrypted_key, + content_type="application/octet-stream", + headers={"X-User-Id": session.user_id}, + ) + except Exception as e: + logger.error(f"Encryption failed: {e}") + return response.text("Key encryption failed", status=500) + + +async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None: + user_id = request.args.get("user_id") + + if not user_id: + await ws.close(code=4002, reason="user_id required") + return + + if not verify_password(request.args.get("password"), app.ctx.admin_password): + await ws.close(code=4001, reason="Unauthorized") + return + + session = await app.ctx.session_store.get(user_id) + if not session: + await ws.close(code=4002, reason="Invalid session") + return + + manager = app.ctx.connection_manager + await manager.connect(user_id, ws) + + try: + await send_state(ws, app) + + async for data in ws: + if data is None: + break + + await app.ctx.session_store.update_activity(user_id) + + message = Message( + text=str(data), + user_ip=session.ip, + username=session.username, + ) + await app.ctx.message_store.add(message) + + await manager.broadcast( + json.dumps( + { + "type": "message", + "data": asdict(message), + } + ) + ) + + except Exception as e: + logger.error(f"WebSocket error for {user_id}: {e}") + finally: + await manager.disconnect(user_id) + await manager.broadcast( + json.dumps( + { + "type": "user_left", + "user_id": user_id, + } + ) + ) + + +async def health(request: Request, app: Sanic) -> HTTPResponse: + return json_response( + { + "status": "ok", + "messages": await app.ctx.message_store.count(), + "users": await app.ctx.session_store.count(), + "timestamp": utcnow().isoformat(), + } + ) + + +async def clear_messages(request: Request, app: Sanic) -> HTTPResponse: + if err := require_auth(request, app): + return err + await app.ctx.message_store.clear() + return json_response({"status": "cleared"}) diff --git a/requirements.txt b/requirements.txt index a3cdb17..d0283a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,6 @@ colorama pydantic websocket-client flask -rich \ No newline at end of file +rich +sanic_ext +websockets \ No newline at end of file diff --git a/setup.py b/setup.py index 7eef936..b30a158 100644 --- a/setup.py +++ b/setup.py @@ -14,13 +14,9 @@ setuptools.setup( long_description=description, long_description_content_type="text/markdown", url="https://github.com/dinosaurtirex/cmd-chat", - license='MIT', - python_requires='>=3.10', - entry_points={ - 'console_scripts': [ - 'cmd_chat = cmd_chat:main' - ] - }, + license="MIT", + python_requires=">=3.10", + entry_points={"console_scripts": ["cmd_chat = cmd_chat:main"]}, install_requires=[ "sanic", "requests", @@ -28,6 +24,6 @@ setuptools.setup( "cryptography", "colorama", "pydantic", - "websocket-client" - ] -) \ No newline at end of file + "websocket-client", + ], +)