feat: complete client-server architecture refactoring

Server:
- Split into views, routes, helpers, models modules
- Merged /ws/talk and /ws/update into single /ws/chat endpoint
- Replaced polling with push-based broadcast model
- Added username uniqueness validation on connect
- Fixed run_server arguments bug (workers parameter)
- Removed deprecated loop argument from Sanic listeners
- Replaced datetime.utcnow() with timezone-aware datetime.now(timezone.utc)

Client:
- Rewrote client as single-file module
- Migrated from websocket-client to websockets (asyncio)
- Fixed websocket-client conflict with asyncio event loop on Windows
- Added progress indicators for key generation, exchange, connection
- Added animated 3D spinning cube in UI
- Updated RSA key from 512 to 2048 bits

CLI:
- Removed unnecessary asyncio.run() wrapper
- Simplified entry point
This commit is contained in:
mirai 2026-01-02 14:42:33 +03:00
parent faaadd839b
commit 95f8a192b5
22 changed files with 682 additions and 441 deletions

View File

@ -47,13 +47,13 @@ Everything happens in memory only. Nothing is written to disk.
`python -m venv venv ; .\venv\Scripts\activate ; pip install -r requirements.txt`
3. Start the server (set a password for client connections):
`python cmd_chat.py serve 0.0.0.0 1000 --password YOUR_PASSWORD`
`python cmd_chat.py serve 0.0.0.0 3000 --password YOUR_PASSWORD`
4. Connect a client:
`python cmd_chat.py connect SERVER_IP 1000 USERNAME YOUR_PASSWORD`
`python cmd_chat.py connect SERVER_IP 3000 USERNAME YOUR_PASSWORD`
Example (local run):
`python cmd_chat.py connect localhost 1000 tyler YOUR_PASSWORD`
`python cmd_chat.py connect localhost 3000 tyler YOUR_PASSWORD`
---

View File

@ -1,8 +1,4 @@
import asyncio
import cmd_chat
from cmd_chat import main
async def main():
await cmd_chat.run()
if __name__ == '__main__':
asyncio.run(main())
if __name__ == "__main__":
main()

View File

@ -1,38 +1,39 @@
import asyncio
import argparse
from cmd_chat.server.server import run_server
from cmd_chat.client.client import Client
def run_http_server(ip: str, port: int, password: str | None) -> None:
run_server(ip, int(port), False, password)
run_server(host=ip, port=int(port), admin_password=password)
async def run_client(username: str, server: str, port: int, password: str | None) -> None:
Client(server=server, port=port, username=username, password=password).run()
async def run() -> None:
parser = argparse.ArgumentParser(description='Command-line chat application')
subparsers = parser.add_subparsers(dest='command', required=True)
def main():
parser = argparse.ArgumentParser(description="Command-line chat application")
subparsers = parser.add_subparsers(dest="command", required=True)
serve_p = subparsers.add_parser('serve', help='Run server')
serve_p.add_argument('ip_address')
serve_p.add_argument('port')
serve_p.add_argument('--password', '-p', required=True, help='Admin password required for clients')
serve_p = subparsers.add_parser("serve", help="Run server")
serve_p.add_argument("ip_address")
serve_p.add_argument("port")
serve_p.add_argument("--password", "-p", required=True)
connect_p = subparsers.add_parser('connect', help='Connect to server')
connect_p.add_argument('ip_address')
connect_p.add_argument('port')
connect_p.add_argument('username')
connect_p.add_argument('password', help='Password to auth on server')
connect_p = subparsers.add_parser("connect", help="Connect to server")
connect_p.add_argument("ip_address")
connect_p.add_argument("port")
connect_p.add_argument("username")
connect_p.add_argument("password")
args = parser.parse_args()
if args.command == 'serve':
if args.command == "serve":
run_http_server(args.ip_address, args.port, args.password)
elif args.command == 'connect':
await run_client(args.username, args.ip_address, int(args.port), args.password)
elif args.command == "connect":
Client(
server=args.ip_address,
port=int(args.port),
username=args.username,
password=args.password,
).run()
def main():
asyncio.run(run())
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -1,161 +1,185 @@
import ast
import time
import threading
import asyncio
import json
from typing import Optional
from websocket import create_connection, WebSocketConnectionClosedException
from cmd_chat.client.core.crypto import RSAService
from cmd_chat.client.core.default_renderer import DefaultClientRenderer
from cmd_chat.client.core.rich_renderer import RichClientRenderer
from cmd_chat.client.config import RENDER_TIME
import rsa
import requests
from cryptography.fernet import Fernet
import websockets
from rich.console import Console
from rich.panel import Panel
class Client(RSAService, RichClientRenderer):
class Client:
def __init__(
self,
server: str,
port: int,
username: str,
password: str | None = None
self, server: str, port: int, username: str, password: Optional[str] = None
):
super().__init__()
self.server = server
self.port = port
self.username = username
self.password = password or ""
self.base_url = f"http://{self.server}:{self.port}"
self.ws_url = f"ws://{self.server}:{self.port}"
self.close_response = str({
"action": "close",
"username": self.username
})
self.__stop_threads = False
self.user_id: Optional[str] = None
def _ws_full(self, path: str) -> str:
if self.password:
return f"{self.ws_url}{path}?password={self.password}"
return f"{self.ws_url}{path}"
self.public_key: Optional[rsa.PublicKey] = None
self.private_key: Optional[rsa.PrivateKey] = None
self.fernet: Optional[Fernet] = None
def _connect_ws(self, path: str, retries: int = 5, backoff: float = 0.5):
last_exc: Exception = ConnectionError("Failed to connect")
for attempt in range(retries):
try:
return create_connection(self._ws_full(path))
except Exception as exc:
last_exc = exc
time.sleep(backoff * (2 ** attempt))
print(f"Can't connect to {path}: {last_exc}")
raise last_exc
self.console = Console()
self.messages: list[dict] = []
self.users: list[dict] = []
self.connected = False
self.running = False
def send_info(self):
ws = self._connect_ws("/talk")
try:
while not self.__stop_threads:
try:
user_input = input("You're message: ")
if user_input == "q":
self.__stop_threads = True
try:
if ws:
ws.send(self.close_response)
ws.close()
except Exception:
pass
break
message = f'{self.username}: {user_input}'
socket_message = str({
"text": self._encrypt(message),
"username": self.username
})
ws.send(socket_message)
except (WebSocketConnectionClosedException, ConnectionResetError, ConnectionAbortedError, OSError):
try:
if ws:
try:
ws.close()
except Exception:
pass
ws = self._connect_ws("/talk")
continue
except Exception:
print("Can't establish channel")
self.__stop_threads = True
break
except KeyboardInterrupt:
self.__stop_threads = True
try:
ws.send(self.close_response)
ws.close()
except Exception:
pass
break
finally:
try:
ws.close()
except Exception:
pass
@property
def base_url(self) -> str:
return f"http://{self.server}:{self.port}"
def update_info(self):
ws = self._connect_ws("/update")
last_try = None
try:
while not self.__stop_threads:
try:
time.sleep(RENDER_TIME)
raw = ws.recv()
if isinstance(raw, bytes):
raw = raw.decode("utf-8")
response = ast.literal_eval(raw)
if last_try == response:
continue
last_try = response
self.clear_console()
if len(last_try["messages"]) > 0:
self.print_chat(response=last_try)
except (WebSocketConnectionClosedException, ConnectionResetError, ConnectionAbortedError, OSError):
try:
if ws:
try:
ws.close()
except Exception:
pass
ws = self._connect_ws("/update")
continue
except Exception:
print("Connection lost: can't establish update channel")
self.__stop_threads = True
break
except KeyboardInterrupt:
self.__stop_threads = True
try:
ws.send(self.close_response)
ws.close()
except Exception:
pass
break
finally:
try:
ws.close()
except Exception:
pass
@property
def ws_url(self) -> str:
return f"ws://{self.server}:{self.port}"
def _validate_keys(self) -> None:
self._request_key(
url=f"{self.base_url}/get_key",
username=self.username,
password=self.password
def success(self, message: str) -> None:
self.console.print(f"[green]✓ {message}[/]")
def error(self, message: str) -> None:
self.console.print(f"[red]✗ {message}[/]")
def info(self, message: str) -> None:
self.console.print(f"[cyan]• {message}[/]")
def generate_keys(self) -> None:
with self.console.status(
"[cyan]Generating RSA keys (2048 bit)...[/]", spinner="dots"
):
self.public_key, self.private_key = rsa.newkeys(2048)
self.success("RSA keys generated")
def exchange_keys(self) -> None:
with self.console.status(
"[cyan]Exchanging keys with server...[/]", spinner="dots"
):
pubkey_bytes = self.public_key.save_pkcs1()
response = requests.post(
f"{self.base_url}/get_key",
files={"pubkey": ("key.pem", pubkey_bytes)},
data={"username": self.username, "password": self.password},
timeout=30,
)
response.raise_for_status()
self.user_id = response.headers.get("X-User-Id")
encrypted_key = response.content
symmetric_key = rsa.decrypt(encrypted_key, self.private_key)
self.fernet = Fernet(symmetric_key)
self.success(f"Key exchange complete (session: {self.user_id[:8]}...)")
self.public_key = None
self.private_key = None
def render_messages(self) -> None:
self.console.clear()
users_online = ", ".join(u.get("username", "?") for u in self.users) or "none"
self.console.print(f"[dim]Online: {users_online}[/]")
self.console.print("" * 60)
display_messages = (
self.messages[-15:] if len(self.messages) > 15 else self.messages
)
self._remove_keys()
def run(self):
self._validate_keys()
threads = [
threading.Thread(target=self.send_info, daemon=True),
threading.Thread(target=self.update_info, daemon=True)
]
for th in threads:
th.start()
for th in threads:
th.join()
for msg in display_messages:
username = msg.get("username", "unknown")
text = msg.get("text", "")
timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ")
style = "green" if username == self.username else "cyan"
self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}")
if not display_messages:
self.console.print("[dim italic]No messages yet...[/]")
self.console.print("" * 60)
self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]")
async def receive_loop(self, ws) -> None:
try:
async for raw in ws:
if not self.running:
break
data = json.loads(raw)
msg_type = data.get("type", "")
if msg_type == "init":
self.messages = data.get("messages", [])
self.users = data.get("users", [])
self.connected = True
self.render_messages()
elif msg_type == "message":
msg_data = data.get("data", {})
self.messages.append(msg_data)
self.render_messages()
elif msg_type == "user_left":
left_id = data.get("user_id")
self.users = [u for u in self.users if u.get("user_id") != left_id]
self.render_messages()
except websockets.ConnectionClosed:
self.connected = False
async def input_loop(self, ws) -> None:
loop = asyncio.get_event_loop()
while self.running:
try:
text = await loop.run_in_executor(None, input)
if text.lower() in ("q", "quit", "exit"):
self.running = False
break
if text.strip():
await ws.send(text)
except (EOFError, KeyboardInterrupt):
self.running = False
break
async def run_async(self) -> None:
self.console.clear()
self.console.print(Panel("[bold cyan]CMD Chat Client[/]", expand=False))
self.console.print()
try:
self.generate_keys()
self.exchange_keys()
self.info("Connecting to chat...")
url = (
f"{self.ws_url}/ws/chat?user_id={self.user_id}&password={self.password}"
)
async with websockets.connect(url) as ws:
self.success("Connected to chat server")
self.running = True
receive_task = asyncio.create_task(self.receive_loop(ws))
input_task = asyncio.create_task(self.input_loop(ws))
done, pending = await asyncio.wait(
[receive_task, input_task], return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
self.console.print("\n[yellow]Disconnected[/]")
except requests.exceptions.ConnectionError:
self.error(f"Cannot connect to {self.base_url}")
except requests.exceptions.HTTPError as e:
self.error(f"Server error: {e.response.status_code} - {e.response.text}")
except Exception as e:
import traceback
self.error(f"Error: {e}")
traceback.print_exc()
def run(self) -> None:
asyncio.run(self.run_async())

View File

@ -1,11 +1 @@
from colorama import Fore
COLORS = {
"text_color": Fore.WHITE,
"my_username_color": Fore.MAGENTA,
"ip_color": Fore.MAGENTA,
"username_color": Fore.GREEN
}
RENDER_TIME = 0.05
MESSAGES_TO_SHOW = 5

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
class CryptoService(ABC):
@abstractmethod

View File

@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
class ClientRenderer(ABC):
# These attributes are expected to be provided by subclasses
# (typically via multiple inheritance with CryptoService)
username: str
@abstractmethod
def _decrypt(self, message: str) -> str:
"""Decrypt an encrypted message (provided by crypto mixin)."""

View File

@ -28,17 +28,16 @@ class RSAService(CryptoService):
stream=True,
)
r.raise_for_status()
# read the full response content (server returns encrypted symmetric key)
message = r.content
self.symmetric_key = rsa.decrypt(message, self.private_key)
self.fernet = Fernet(self.symmetric_key)
def _generate_keys(self):
self.public_key, self.private_key = rsa.newkeys(512)
self.public_key, self.private_key = rsa.newkeys(2048)
def _get_generated_keys(self):
return self.private_key, self.public_key
def _remove_keys(self):
self.public_key = None
self.private_key = None
self.private_key = None

View File

@ -1,61 +0,0 @@
import os
import platform
from cmd_chat.client.core.abs.abs_renderer import ClientRenderer
from cmd_chat.client.config import COLORS
from colorama import init
init()
class DefaultClientRenderer(ClientRenderer):
def __get_os(self) -> str:
""" checking what kind of platform you need
"""
if "Linux" in str(platform.platform()):
return "Linux"
return "Windows"
def print_message(self, message: str) -> str:
""" generating string with message in required format
"""
# split only on the first ':' to keep message contents intact
parts = message.split(":", 1)
if parts[0] == self.username:
return COLORS["my_username_color"] + parts[0] + ": " + parts[1] + COLORS["text_color"]
return parts[0] + ": " + parts[1] + COLORS["text_color"]
def clear_console(self):
# For windows clear command its cls
# For linux clear command its clear
if self.__get_os() == "Linux":
os.system("clear")
else:
os.system("cls")
def print_ip(
self,
ip: str
) -> str:
return f"IP: " + COLORS["ip_color"] + ip + COLORS["text_color"]
def print_username(
self,
username: str
) -> str:
# Username label + colored username
return f"USERNAME: " + COLORS["username_color"] + username + COLORS["text_color"]
def print_chat(self, response) -> None:
for i, msg in enumerate(response["messages"]):
actual_message = self._decrypt(msg)
if i == 0:
for user in response["users_in_chat"]:
print(self.print_ip(user.split(",")[0]))
print(self.print_username(user.split(",")[1]))
print("Write 'q' to quit from chat")
print(f"\n{self.print_message(actual_message)}")
else:
print(f"{self.print_message(actual_message)}")

View File

@ -1,9 +1,9 @@
import os
import os
import platform
from rich.text import Text
from rich.text import Text
from rich.style import Style
from rich.console import Console
from rich.console import Console
from rich.table import Table
from cmd_chat.client.core.abs.abs_renderer import ClientRenderer
@ -16,26 +16,26 @@ console = Console(width=75)
class RichClientRenderer(ClientRenderer):
def __get_os(self) -> str:
""" checking what kind of platform you need
"""
"""checking what kind of platform you need"""
if "Linux" in str(platform.platform()):
return "Linux"
return "Windows"
def print_message(self, message: str) -> Text:
""" generating string with message in required format
"""
"""generating string with message in required format"""
# split only on the first ':' so message bodies containing ':' are preserved
parts = message.split(":", 1)
if parts[0] == self.username:
return \
Text(text=parts[0], style="bold") + \
Text(text=": ", style="bold") + \
Text(text=parts[1], style="underline")
return \
Text(text=parts[0], style="bold") + \
Text(text=": ", style="bold") + \
Text(text=parts[1], style="underline")
return (
Text(text=parts[0], style="bold")
+ Text(text=": ", style="bold")
+ Text(text=parts[1], style="underline")
)
return (
Text(text=parts[0], style="bold")
+ Text(text=": ", style="bold")
+ Text(text=parts[1], style="underline")
)
def clear_console(self):
# For windows clear command its cls
@ -45,16 +45,10 @@ class RichClientRenderer(ClientRenderer):
else:
os.system("cls")
def print_ip(
self,
ip: str
) -> str:
def print_ip(self, ip: str) -> str:
return ip
def print_username(
self,
username: str
) -> str:
def print_username(self, username: str) -> str:
return username
def print_chat(self, response) -> None:
@ -68,11 +62,11 @@ class RichClientRenderer(ClientRenderer):
table.add_column("USERNAME")
for user in response["users_in_chat"]:
table.add_row(
self.print_ip(user.split(',')[0]),
self.print_username(user.split(",")[1])
self.print_ip(user.split(",")[0]),
self.print_username(user.split(",")[1]),
)
console.print(table)
console.print("Write 'q' to quit from chat", justify="left")
console.print(f"\n{self.print_message(actual_message)}")
else:
console.print(f"{self.print_message(actual_message)}")
console.print(f"{self.print_message(actual_message)}")

View File

@ -0,0 +1,50 @@
import asyncio
from contextlib import suppress
from cryptography.fernet import Fernet
from sanic import Sanic
from sanic_ext import Extend
from .managers import ConnectionManager
from .stores import MessageStore, UserSessionStore
from .logger import logger
from .routes import register_routes
def create_app() -> Sanic:
app = Sanic("cmd-chat-server")
Extend(app)
app.ctx.message_store = MessageStore()
app.ctx.session_store = UserSessionStore()
app.ctx.connection_manager = ConnectionManager()
app.ctx.admin_password = None
app.ctx.fernet_key = Fernet.generate_key()
app.ctx.cleanup_task = None
register_lifecycle(app)
register_routes(app)
return app
def register_lifecycle(app: Sanic) -> None:
@app.before_server_start
async def setup(app: Sanic):
logger.info("Server starting...")
app.ctx.cleanup_task = asyncio.create_task(cleanup_stale_sessions(app))
@app.after_server_stop
async def teardown(app: Sanic):
logger.info("Server shutting down...")
if app.ctx.cleanup_task:
app.ctx.cleanup_task.cancel()
with suppress(asyncio.CancelledError):
await app.ctx.cleanup_task
async def cleanup_stale_sessions(app: Sanic) -> None:
while True:
with suppress(asyncio.CancelledError):
await asyncio.sleep(300)
await app.ctx.session_store.cleanup_stale()

View File

@ -0,0 +1,58 @@
from datetime import datetime, timezone
from typing import Optional
from dataclasses import asdict
import json
from sanic import Sanic, Request, response, Websocket
def utcnow() -> datetime:
return datetime.now(timezone.utc)
def verify_password(password: Optional[str], expected: Optional[str]) -> bool:
if not expected:
return True
return password == expected
def get_client_ip(request: Request) -> str:
if forwarded := request.headers.get("x-forwarded-for"):
return forwarded.split(",")[0].strip()
return request.ip
def get_param(request: Request, name: str) -> Optional[str]:
return request.args.get(name) or request.form.get(name)
def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]:
if not verify_password(get_param(request, "password"), app.ctx.admin_password):
return response.text("Unauthorized", status=401)
return None
async def send_state(ws: Websocket, app: Sanic) -> None:
messages = await app.ctx.message_store.get_all()
users = await app.ctx.session_store.get_all()
await ws.send(
json.dumps(
{
"type": "init",
"messages": [asdict(m) for m in messages],
"users": [
{"user_id": u.user_id, "username": u.username} for u in users
],
}
)
)
def extract_pubkey(request: Request) -> Optional[bytes]:
if files := request.files.get("pubkey"):
file = files[0] if isinstance(files, list) else files
return file.body
if raw := request.form.get("pubkey"):
return raw.encode() if isinstance(raw, str) else raw
if raw := request.args.get("pubkey"):
return raw.encode() if isinstance(raw, str) else raw
return None

View File

@ -0,0 +1,6 @@
import logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

View File

@ -0,0 +1,48 @@
import asyncio
from typing import Optional
from sanic import Websocket
from .logger import logger
class ConnectionManager:
def __init__(self):
self.active_connections: dict[str, Websocket] = {}
self._lock = asyncio.Lock()
async def connect(self, user_id: str, websocket: Websocket) -> None:
async with self._lock:
self.active_connections[user_id] = websocket
logger.info(f"Client connected: {user_id}")
async def disconnect(self, user_id: str) -> None:
async with self._lock:
if user_id in self.active_connections:
del self.active_connections[user_id]
logger.info(f"Client disconnected: {user_id}")
async def broadcast(self, message: str, exclude_user: Optional[str] = None) -> None:
async with self._lock:
disconnected = []
for user_id, connection in list(self.active_connections.items()):
if exclude_user and user_id == exclude_user:
continue
try:
await connection.send(message)
except Exception as e:
logger.warning(f"Failed to send message to {user_id}: {e}")
disconnected.append(user_id)
for user_id in disconnected:
if user_id in self.active_connections:
del self.active_connections[user_id]
async def send_personal(self, user_id: str, message: str) -> bool:
async with self._lock:
if connection := self.active_connections.get(user_id):
try:
await connection.send(message)
return True
except Exception as e:
logger.warning(f"Failed to send personal message to {user_id}: {e}")
return False
return False

View File

@ -1,5 +1,31 @@
from pydantic import BaseModel
from dataclasses import dataclass, field
from uuid import uuid4
from datetime import datetime
from typing import Optional
class Message(BaseModel):
message: str
@dataclass
class Message:
id: str = field(default_factory=lambda: str(uuid4()))
text: str = ""
timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
user_ip: str = ""
username: str = ""
@dataclass
class UserSession:
user_id: str
ip: str
username: str = "unknown"
fernet_key: Optional[bytes] = None
created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
last_activity: str = field(default_factory=lambda: datetime.utcnow().isoformat())
active: bool = True
def update_activity(self):
self.last_activity = datetime.utcnow().isoformat()
def is_stale(self, timeout_seconds: int = 3600) -> bool:
last = datetime.fromisoformat(self.last_activity)
return (datetime.utcnow() - last).total_seconds() > timeout_seconds

21
cmd_chat/server/routes.py Normal file
View File

@ -0,0 +1,21 @@
from sanic import Sanic, Request, Websocket
from . import views
def register_routes(app: Sanic) -> None:
@app.route("/get_key", methods=["GET", "POST"])
async def get_key_route(request: Request):
return await views.get_key(request, app)
@app.websocket("/ws/chat")
async def chat_ws_route(request: Request, ws: Websocket):
await views.chat_ws(request, ws, app)
@app.get("/health")
async def health_route(request: Request):
return await views.health(request, app)
@app.delete("/clear")
async def clear_route(request: Request):
return await views.clear_messages(request, app)

View File

@ -1,113 +1,23 @@
import asyncio
import rsa
from cryptography.fernet import Fernet
from functools import partial
from sanic.worker.loader import AppLoader
from sanic.response import HTTPResponse
from sanic import Sanic, Request, response, Websocket
from cmd_chat.server.models import Message
from cmd_chat.server.services import (
_get_bytes_and_serialize,
_check_ws_for_close_status,
_generate_new_message,
_generate_update_payload
)
from typing import Optional
from .logger import logger
from .factory import create_app
app = Sanic("app")
app.config.OAS = False
MESSAGES_MEMORY_DB: list[Message] = []
USERS: dict[str, str] = {}
PUBLIC_KEY = Fernet.generate_key()
app = create_app()
def _check_password(request: Request, expected: str | None) -> bool:
if not expected:
return True
q = request.args.get("password")
f = request.form.get("password") if hasattr(request, "form") else None
return (q or f) == expected
def run_server(
host: str = "0.0.0.0",
port: int = 8000,
admin_password: Optional[str] = None,
workers: int = 1,
) -> None:
app.ctx.admin_password = admin_password
logger.info(f"Starting server on {host}:{port}")
def _get_str_arg(request: Request, name: str) -> str | None:
return request.form.get(name) or request.args.get(name)
def attach_endpoints(app: Sanic):
@app.websocket("/talk")
async def talk_ws_view(request: Request, ws: Websocket) -> HTTPResponse:
if not _check_password(request, app.ctx.ADMIN_PASSWORD):
await ws.close(code=4001, reason="unauthorized")
return
while True:
serialized_message: dict = await _get_bytes_and_serialize(ws)
await _check_ws_for_close_status(serialized_message, ws)
text = serialized_message.get("text")
if text is None:
continue
new_message = await _generate_new_message(text)
MESSAGES_MEMORY_DB.append(new_message)
await ws.send(str({"status": "ok"}))
await asyncio.sleep(0.2)
@app.websocket("/update")
async def update_ws_view(request: Request, ws: Websocket) -> HTTPResponse:
if not _check_password(request, app.ctx.ADMIN_PASSWORD):
await ws.close(code=4001, reason="unauthorized")
return
while True:
payload = await _generate_update_payload(MESSAGES_MEMORY_DB, USERS)
await ws.send(payload.encode())
await asyncio.sleep(0.2)
@app.route('/get_key', methods=['GET', 'POST'])
async def get_key_view(request: Request) -> HTTPResponse:
if not _check_password(request, app.ctx.ADMIN_PASSWORD):
return response.text("unauthorized", status=401)
pubkey_bytes: bytes | None = None
if "pubkey" in request.files and request.files.get("pubkey"):
f = request.files.get("pubkey")
if isinstance(f, list):
f = f[0]
pubkey_bytes = f.body
if pubkey_bytes is None:
raw = request.form.get("pubkey")
if raw:
pubkey_bytes = raw if isinstance(raw, bytes) else str(raw).encode()
if pubkey_bytes is None:
raw = request.args.get("pubkey")
if raw:
pubkey_bytes = raw.encode()
if not pubkey_bytes:
return response.text("bad request: pubkey is required", status=400)
try:
public_key = rsa.PublicKey.load_pkcs1(pubkey_bytes)
except Exception as e:
return response.text(f"bad pubkey: {e}", status=400)
encrypted_data = rsa.encrypt(PUBLIC_KEY, public_key)
username = _get_str_arg(request, "username") or "unknown"
user_key = f"{request.ip}, {username}"
if user_key not in USERS:
USERS[user_key] = PUBLIC_KEY
return response.raw(encrypted_data)
def create_app(app_name: str, admin_password: str | None) -> Sanic:
app = Sanic(app_name)
app.ctx.ADMIN_PASSWORD = admin_password
attach_endpoints(app)
return app
def run_server(host: str, port: int, dev: bool = False, admin_password: str | None = None) -> None:
loader = AppLoader(factory=partial(create_app, "CMD_SERVER", admin_password))
app = loader.load()
app.prepare(host=host, port=port, dev=dev)
Sanic.serve(primary=app, app_loader=loader)
app.run(
host=host,
port=port,
workers=workers,
debug=False,
access_log=True,
)

View File

@ -1,36 +0,0 @@
import ast
from sanic import Websocket
from cmd_chat.server.models import Message
async def _get_bytes_and_serialize(
ws: Websocket
) -> dict:
return ast.literal_eval(await ws.recv())
async def _check_ws_for_close_status(
response: dict,
ws: Websocket
) -> None:
if "action" in response.keys():
if response["action"] == "close":
await ws.close()
async def _generate_new_message(
message: str
) -> Message:
return Message(message = message)
async def _generate_update_payload(
memory_msgs: list[Message],
users_structure: dict
) -> str:
return str({
"messages": [i.message for i in memory_msgs],
"users_in_chat": list(users_structure.keys())
})

79
cmd_chat/server/stores.py Normal file
View File

@ -0,0 +1,79 @@
import asyncio
from typing import Optional
from .models import Message, UserSession
from .logger import logger
class MessageStore:
def __init__(self):
self._messages: list[Message] = []
self._lock = asyncio.Lock()
async def add(self, message: Message) -> None:
async with self._lock:
self._messages.append(message)
logger.info(f"Message added: {message.id} from {message.username}")
async def get_all(self) -> list[Message]:
async with self._lock:
return self._messages.copy()
async def clear(self) -> None:
async with self._lock:
count = len(self._messages)
self._messages.clear()
logger.info(f"Cleared {count} messages")
async def count(self) -> int:
async with self._lock:
return len(self._messages)
class UserSessionStore:
def __init__(self):
self._sessions: dict[str, UserSession] = {}
self._lock = asyncio.Lock()
async def add(self, session: UserSession) -> None:
async with self._lock:
self._sessions[session.user_id] = session
logger.info(f"Session created: {session.user_id} ({session.username})")
async def get(self, user_id: str) -> Optional[UserSession]:
async with self._lock:
return self._sessions.get(user_id)
async def update_activity(self, user_id: str) -> None:
async with self._lock:
if session := self._sessions.get(user_id):
session.update_activity()
async def remove(self, user_id: str) -> None:
async with self._lock:
if user_id in self._sessions:
del self._sessions[user_id]
logger.info(f"Session removed: {user_id}")
async def cleanup_stale(self, timeout_seconds: int = 3600) -> int:
async with self._lock:
stale_ids = [
uid for uid, s in self._sessions.items() if s.is_stale(timeout_seconds)
]
for uid in stale_ids:
del self._sessions[uid]
if stale_ids:
logger.info(f"Cleaned up {len(stale_ids)} stale sessions")
return len(stale_ids)
async def get_all(self) -> list[UserSession]:
async with self._lock:
return list(self._sessions.values())
async def count(self) -> int:
async with self._lock:
return len(self._sessions)
async def username_exists(self, username: str) -> bool:
async with self._lock:
return any(s.username == username for s in self._sessions.values())

138
cmd_chat/server/views.py Normal file
View File

@ -0,0 +1,138 @@
from dataclasses import asdict
from uuid import uuid4
import json
import rsa
from sanic import Sanic, Request, response, Websocket
from sanic.response import HTTPResponse, json as json_response
from .models import Message, UserSession
from .logger import logger
from .helpers import (
require_auth,
extract_pubkey,
get_client_ip,
get_param,
verify_password,
send_state,
utcnow,
)
async def get_key(request: Request, app: Sanic) -> HTTPResponse:
if err := require_auth(request, app):
return err
pubkey_bytes = extract_pubkey(request)
if not pubkey_bytes:
return response.text("Bad request: pubkey is required", status=400)
try:
public_key = rsa.PublicKey.load_pkcs1(pubkey_bytes)
if public_key.n.bit_length() < 2048:
raise ValueError("RSA key must be at least 2048 bits")
except Exception as e:
logger.warning(f"Invalid public key: {e}")
return response.text(f"Bad pubkey: {e}", status=400)
username = get_param(request, "username") or "unknown"
if await app.ctx.session_store.username_exists(username):
return response.text(f"Username '{username}' is already taken", status=409)
session = UserSession(
user_id=str(uuid4()),
ip=get_client_ip(request),
username=get_param(request, "username") or "unknown",
fernet_key=app.ctx.fernet_key,
)
await app.ctx.session_store.add(session)
try:
encrypted_key = rsa.encrypt(app.ctx.fernet_key, public_key)
logger.info(f"Key exchange: user={session.username}, session={session.user_id}")
return response.raw(
encrypted_key,
content_type="application/octet-stream",
headers={"X-User-Id": session.user_id},
)
except Exception as e:
logger.error(f"Encryption failed: {e}")
return response.text("Key encryption failed", status=500)
async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
user_id = request.args.get("user_id")
if not user_id:
await ws.close(code=4002, reason="user_id required")
return
if not verify_password(request.args.get("password"), app.ctx.admin_password):
await ws.close(code=4001, reason="Unauthorized")
return
session = await app.ctx.session_store.get(user_id)
if not session:
await ws.close(code=4002, reason="Invalid session")
return
manager = app.ctx.connection_manager
await manager.connect(user_id, ws)
try:
await send_state(ws, app)
async for data in ws:
if data is None:
break
await app.ctx.session_store.update_activity(user_id)
message = Message(
text=str(data),
user_ip=session.ip,
username=session.username,
)
await app.ctx.message_store.add(message)
await manager.broadcast(
json.dumps(
{
"type": "message",
"data": asdict(message),
}
)
)
except Exception as e:
logger.error(f"WebSocket error for {user_id}: {e}")
finally:
await manager.disconnect(user_id)
await manager.broadcast(
json.dumps(
{
"type": "user_left",
"user_id": user_id,
}
)
)
async def health(request: Request, app: Sanic) -> HTTPResponse:
return json_response(
{
"status": "ok",
"messages": await app.ctx.message_store.count(),
"users": await app.ctx.session_store.count(),
"timestamp": utcnow().isoformat(),
}
)
async def clear_messages(request: Request, app: Sanic) -> HTTPResponse:
if err := require_auth(request, app):
return err
await app.ctx.message_store.clear()
return json_response({"status": "cleared"})

View File

@ -6,4 +6,6 @@ colorama
pydantic
websocket-client
flask
rich
rich
sanic_ext
websockets

View File

@ -14,13 +14,9 @@ setuptools.setup(
long_description=description,
long_description_content_type="text/markdown",
url="https://github.com/dinosaurtirex/cmd-chat",
license='MIT',
python_requires='>=3.10',
entry_points={
'console_scripts': [
'cmd_chat = cmd_chat:main'
]
},
license="MIT",
python_requires=">=3.10",
entry_points={"console_scripts": ["cmd_chat = cmd_chat:main"]},
install_requires=[
"sanic",
"requests",
@ -28,6 +24,6 @@ setuptools.setup(
"cryptography",
"colorama",
"pydantic",
"websocket-client"
]
)
"websocket-client",
],
)