feat: complete client-server architecture refactoring
Server: - Split into views, routes, helpers, models modules - Merged /ws/talk and /ws/update into single /ws/chat endpoint - Replaced polling with push-based broadcast model - Added username uniqueness validation on connect - Fixed run_server arguments bug (workers parameter) - Removed deprecated loop argument from Sanic listeners - Replaced datetime.utcnow() with timezone-aware datetime.now(timezone.utc) Client: - Rewrote client as single-file module - Migrated from websocket-client to websockets (asyncio) - Fixed websocket-client conflict with asyncio event loop on Windows - Added progress indicators for key generation, exchange, connection - Added animated 3D spinning cube in UI - Updated RSA key from 512 to 2048 bits CLI: - Removed unnecessary asyncio.run() wrapper - Simplified entry point
This commit is contained in:
parent
faaadd839b
commit
95f8a192b5
|
|
@ -47,13 +47,13 @@ Everything happens in memory only. Nothing is written to disk.
|
||||||
`python -m venv venv ; .\venv\Scripts\activate ; pip install -r requirements.txt`
|
`python -m venv venv ; .\venv\Scripts\activate ; pip install -r requirements.txt`
|
||||||
|
|
||||||
3. Start the server (set a password for client connections):
|
3. Start the server (set a password for client connections):
|
||||||
`python cmd_chat.py serve 0.0.0.0 1000 --password YOUR_PASSWORD`
|
`python cmd_chat.py serve 0.0.0.0 3000 --password YOUR_PASSWORD`
|
||||||
|
|
||||||
4. Connect a client:
|
4. Connect a client:
|
||||||
`python cmd_chat.py connect SERVER_IP 1000 USERNAME YOUR_PASSWORD`
|
`python cmd_chat.py connect SERVER_IP 3000 USERNAME YOUR_PASSWORD`
|
||||||
|
|
||||||
Example (local run):
|
Example (local run):
|
||||||
`python cmd_chat.py connect localhost 1000 tyler YOUR_PASSWORD`
|
`python cmd_chat.py connect localhost 3000 tyler YOUR_PASSWORD`
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
|
||||||
10
cmd_chat.py
10
cmd_chat.py
|
|
@ -1,8 +1,4 @@
|
||||||
import asyncio
|
from cmd_chat import main
|
||||||
import cmd_chat
|
|
||||||
|
|
||||||
async def main():
|
if __name__ == "__main__":
|
||||||
await cmd_chat.run()
|
main()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
|
||||||
|
|
@ -1,38 +1,39 @@
|
||||||
import asyncio
|
|
||||||
import argparse
|
import argparse
|
||||||
from cmd_chat.server.server import run_server
|
from cmd_chat.server.server import run_server
|
||||||
from cmd_chat.client.client import Client
|
from cmd_chat.client.client import Client
|
||||||
|
|
||||||
|
|
||||||
def run_http_server(ip: str, port: int, password: str | None) -> None:
|
def run_http_server(ip: str, port: int, password: str | None) -> None:
|
||||||
run_server(ip, int(port), False, password)
|
run_server(host=ip, port=int(port), admin_password=password)
|
||||||
|
|
||||||
async def run_client(username: str, server: str, port: int, password: str | None) -> None:
|
|
||||||
Client(server=server, port=port, username=username, password=password).run()
|
|
||||||
|
|
||||||
async def run() -> None:
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='Command-line chat application')
|
parser = argparse.ArgumentParser(description="Command-line chat application")
|
||||||
subparsers = parser.add_subparsers(dest='command', required=True)
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
serve_p = subparsers.add_parser('serve', help='Run server')
|
serve_p = subparsers.add_parser("serve", help="Run server")
|
||||||
serve_p.add_argument('ip_address')
|
serve_p.add_argument("ip_address")
|
||||||
serve_p.add_argument('port')
|
serve_p.add_argument("port")
|
||||||
serve_p.add_argument('--password', '-p', required=True, help='Admin password required for clients')
|
serve_p.add_argument("--password", "-p", required=True)
|
||||||
|
|
||||||
connect_p = subparsers.add_parser('connect', help='Connect to server')
|
connect_p = subparsers.add_parser("connect", help="Connect to server")
|
||||||
connect_p.add_argument('ip_address')
|
connect_p.add_argument("ip_address")
|
||||||
connect_p.add_argument('port')
|
connect_p.add_argument("port")
|
||||||
connect_p.add_argument('username')
|
connect_p.add_argument("username")
|
||||||
connect_p.add_argument('password', help='Password to auth on server')
|
connect_p.add_argument("password")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.command == 'serve':
|
if args.command == "serve":
|
||||||
run_http_server(args.ip_address, args.port, args.password)
|
run_http_server(args.ip_address, args.port, args.password)
|
||||||
elif args.command == 'connect':
|
elif args.command == "connect":
|
||||||
await run_client(args.username, args.ip_address, int(args.port), args.password)
|
Client(
|
||||||
|
server=args.ip_address,
|
||||||
|
port=int(args.port),
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
).run()
|
||||||
|
|
||||||
def main():
|
|
||||||
asyncio.run(run())
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -1,161 +1,185 @@
|
||||||
import ast
|
import asyncio
|
||||||
import time
|
import json
|
||||||
import threading
|
from typing import Optional
|
||||||
|
|
||||||
from websocket import create_connection, WebSocketConnectionClosedException
|
import rsa
|
||||||
|
import requests
|
||||||
from cmd_chat.client.core.crypto import RSAService
|
from cryptography.fernet import Fernet
|
||||||
from cmd_chat.client.core.default_renderer import DefaultClientRenderer
|
import websockets
|
||||||
from cmd_chat.client.core.rich_renderer import RichClientRenderer
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
from cmd_chat.client.config import RENDER_TIME
|
|
||||||
|
|
||||||
|
|
||||||
class Client(RSAService, RichClientRenderer):
|
class Client:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, server: str, port: int, username: str, password: Optional[str] = None
|
||||||
server: str,
|
|
||||||
port: int,
|
|
||||||
username: str,
|
|
||||||
password: str | None = None
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
|
||||||
self.server = server
|
self.server = server
|
||||||
self.port = port
|
self.port = port
|
||||||
self.username = username
|
self.username = username
|
||||||
self.password = password or ""
|
self.password = password or ""
|
||||||
self.base_url = f"http://{self.server}:{self.port}"
|
self.user_id: Optional[str] = None
|
||||||
self.ws_url = f"ws://{self.server}:{self.port}"
|
|
||||||
self.close_response = str({
|
|
||||||
"action": "close",
|
|
||||||
"username": self.username
|
|
||||||
})
|
|
||||||
self.__stop_threads = False
|
|
||||||
|
|
||||||
def _ws_full(self, path: str) -> str:
|
self.public_key: Optional[rsa.PublicKey] = None
|
||||||
if self.password:
|
self.private_key: Optional[rsa.PrivateKey] = None
|
||||||
return f"{self.ws_url}{path}?password={self.password}"
|
self.fernet: Optional[Fernet] = None
|
||||||
return f"{self.ws_url}{path}"
|
|
||||||
|
|
||||||
def _connect_ws(self, path: str, retries: int = 5, backoff: float = 0.5):
|
self.console = Console()
|
||||||
last_exc: Exception = ConnectionError("Failed to connect")
|
self.messages: list[dict] = []
|
||||||
for attempt in range(retries):
|
self.users: list[dict] = []
|
||||||
try:
|
self.connected = False
|
||||||
return create_connection(self._ws_full(path))
|
self.running = False
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
time.sleep(backoff * (2 ** attempt))
|
|
||||||
print(f"Can't connect to {path}: {last_exc}")
|
|
||||||
raise last_exc
|
|
||||||
|
|
||||||
def send_info(self):
|
@property
|
||||||
ws = self._connect_ws("/talk")
|
def base_url(self) -> str:
|
||||||
try:
|
return f"http://{self.server}:{self.port}"
|
||||||
while not self.__stop_threads:
|
|
||||||
try:
|
|
||||||
user_input = input("You're message: ")
|
|
||||||
if user_input == "q":
|
|
||||||
self.__stop_threads = True
|
|
||||||
try:
|
|
||||||
if ws:
|
|
||||||
ws.send(self.close_response)
|
|
||||||
ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
break
|
|
||||||
message = f'{self.username}: {user_input}'
|
|
||||||
socket_message = str({
|
|
||||||
"text": self._encrypt(message),
|
|
||||||
"username": self.username
|
|
||||||
})
|
|
||||||
ws.send(socket_message)
|
|
||||||
except (WebSocketConnectionClosedException, ConnectionResetError, ConnectionAbortedError, OSError):
|
|
||||||
try:
|
|
||||||
if ws:
|
|
||||||
try:
|
|
||||||
ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
ws = self._connect_ws("/talk")
|
|
||||||
continue
|
|
||||||
except Exception:
|
|
||||||
print("Can't establish channel")
|
|
||||||
self.__stop_threads = True
|
|
||||||
break
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self.__stop_threads = True
|
|
||||||
try:
|
|
||||||
ws.send(self.close_response)
|
|
||||||
ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def update_info(self):
|
@property
|
||||||
ws = self._connect_ws("/update")
|
def ws_url(self) -> str:
|
||||||
last_try = None
|
return f"ws://{self.server}:{self.port}"
|
||||||
try:
|
|
||||||
while not self.__stop_threads:
|
|
||||||
try:
|
|
||||||
time.sleep(RENDER_TIME)
|
|
||||||
raw = ws.recv()
|
|
||||||
if isinstance(raw, bytes):
|
|
||||||
raw = raw.decode("utf-8")
|
|
||||||
response = ast.literal_eval(raw)
|
|
||||||
if last_try == response:
|
|
||||||
continue
|
|
||||||
last_try = response
|
|
||||||
self.clear_console()
|
|
||||||
if len(last_try["messages"]) > 0:
|
|
||||||
self.print_chat(response=last_try)
|
|
||||||
except (WebSocketConnectionClosedException, ConnectionResetError, ConnectionAbortedError, OSError):
|
|
||||||
try:
|
|
||||||
if ws:
|
|
||||||
try:
|
|
||||||
ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
ws = self._connect_ws("/update")
|
|
||||||
continue
|
|
||||||
except Exception:
|
|
||||||
print("Connection lost: can't establish update channel")
|
|
||||||
self.__stop_threads = True
|
|
||||||
break
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self.__stop_threads = True
|
|
||||||
try:
|
|
||||||
ws.send(self.close_response)
|
|
||||||
ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _validate_keys(self) -> None:
|
def success(self, message: str) -> None:
|
||||||
self._request_key(
|
self.console.print(f"[green]✓ {message}[/]")
|
||||||
url=f"{self.base_url}/get_key",
|
|
||||||
username=self.username,
|
def error(self, message: str) -> None:
|
||||||
password=self.password
|
self.console.print(f"[red]✗ {message}[/]")
|
||||||
|
|
||||||
|
def info(self, message: str) -> None:
|
||||||
|
self.console.print(f"[cyan]• {message}[/]")
|
||||||
|
|
||||||
|
def generate_keys(self) -> None:
|
||||||
|
with self.console.status(
|
||||||
|
"[cyan]Generating RSA keys (2048 bit)...[/]", spinner="dots"
|
||||||
|
):
|
||||||
|
self.public_key, self.private_key = rsa.newkeys(2048)
|
||||||
|
self.success("RSA keys generated")
|
||||||
|
|
||||||
|
def exchange_keys(self) -> None:
|
||||||
|
with self.console.status(
|
||||||
|
"[cyan]Exchanging keys with server...[/]", spinner="dots"
|
||||||
|
):
|
||||||
|
pubkey_bytes = self.public_key.save_pkcs1()
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/get_key",
|
||||||
|
files={"pubkey": ("key.pem", pubkey_bytes)},
|
||||||
|
data={"username": self.username, "password": self.password},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
self.user_id = response.headers.get("X-User-Id")
|
||||||
|
encrypted_key = response.content
|
||||||
|
symmetric_key = rsa.decrypt(encrypted_key, self.private_key)
|
||||||
|
self.fernet = Fernet(symmetric_key)
|
||||||
|
|
||||||
|
self.success(f"Key exchange complete (session: {self.user_id[:8]}...)")
|
||||||
|
self.public_key = None
|
||||||
|
self.private_key = None
|
||||||
|
|
||||||
|
def render_messages(self) -> None:
|
||||||
|
self.console.clear()
|
||||||
|
|
||||||
|
users_online = ", ".join(u.get("username", "?") for u in self.users) or "none"
|
||||||
|
self.console.print(f"[dim]Online: {users_online}[/]")
|
||||||
|
self.console.print("─" * 60)
|
||||||
|
|
||||||
|
display_messages = (
|
||||||
|
self.messages[-15:] if len(self.messages) > 15 else self.messages
|
||||||
)
|
)
|
||||||
self._remove_keys()
|
|
||||||
|
|
||||||
def run(self):
|
for msg in display_messages:
|
||||||
self._validate_keys()
|
username = msg.get("username", "unknown")
|
||||||
threads = [
|
text = msg.get("text", "")
|
||||||
threading.Thread(target=self.send_info, daemon=True),
|
timestamp = str(msg.get("timestamp", ""))[:19].replace("T", " ")
|
||||||
threading.Thread(target=self.update_info, daemon=True)
|
|
||||||
]
|
style = "green" if username == self.username else "cyan"
|
||||||
for th in threads:
|
self.console.print(f"[dim]{timestamp}[/] [{style}]{username}[/]: {text}")
|
||||||
th.start()
|
|
||||||
for th in threads:
|
if not display_messages:
|
||||||
th.join()
|
self.console.print("[dim italic]No messages yet...[/]")
|
||||||
|
|
||||||
|
self.console.print("─" * 60)
|
||||||
|
self.console.print("[dim]Type message and press Enter. 'q' to quit.[/]")
|
||||||
|
|
||||||
|
async def receive_loop(self, ws) -> None:
|
||||||
|
try:
|
||||||
|
async for raw in ws:
|
||||||
|
if not self.running:
|
||||||
|
break
|
||||||
|
|
||||||
|
data = json.loads(raw)
|
||||||
|
msg_type = data.get("type", "")
|
||||||
|
|
||||||
|
if msg_type == "init":
|
||||||
|
self.messages = data.get("messages", [])
|
||||||
|
self.users = data.get("users", [])
|
||||||
|
self.connected = True
|
||||||
|
self.render_messages()
|
||||||
|
elif msg_type == "message":
|
||||||
|
msg_data = data.get("data", {})
|
||||||
|
self.messages.append(msg_data)
|
||||||
|
self.render_messages()
|
||||||
|
elif msg_type == "user_left":
|
||||||
|
left_id = data.get("user_id")
|
||||||
|
self.users = [u for u in self.users if u.get("user_id") != left_id]
|
||||||
|
self.render_messages()
|
||||||
|
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
async def input_loop(self, ws) -> None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
text = await loop.run_in_executor(None, input)
|
||||||
|
if text.lower() in ("q", "quit", "exit"):
|
||||||
|
self.running = False
|
||||||
|
break
|
||||||
|
if text.strip():
|
||||||
|
await ws.send(text)
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
self.running = False
|
||||||
|
break
|
||||||
|
|
||||||
|
async def run_async(self) -> None:
|
||||||
|
self.console.clear()
|
||||||
|
self.console.print(Panel("[bold cyan]CMD Chat Client[/]", expand=False))
|
||||||
|
self.console.print()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.generate_keys()
|
||||||
|
self.exchange_keys()
|
||||||
|
|
||||||
|
self.info("Connecting to chat...")
|
||||||
|
url = (
|
||||||
|
f"{self.ws_url}/ws/chat?user_id={self.user_id}&password={self.password}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with websockets.connect(url) as ws:
|
||||||
|
self.success("Connected to chat server")
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
receive_task = asyncio.create_task(self.receive_loop(ws))
|
||||||
|
input_task = asyncio.create_task(self.input_loop(ws))
|
||||||
|
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
[receive_task, input_task], return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
self.console.print("\n[yellow]Disconnected[/]")
|
||||||
|
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
self.error(f"Cannot connect to {self.base_url}")
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
self.error(f"Server error: {e.response.status_code} - {e.response.text}")
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
self.error(f"Error: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
asyncio.run(self.run_async())
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1 @@
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
COLORS = {
|
|
||||||
"text_color": Fore.WHITE,
|
|
||||||
"my_username_color": Fore.MAGENTA,
|
|
||||||
"ip_color": Fore.MAGENTA,
|
|
||||||
"username_color": Fore.GREEN
|
|
||||||
}
|
|
||||||
|
|
||||||
RENDER_TIME = 0.05
|
|
||||||
MESSAGES_TO_SHOW = 5
|
MESSAGES_TO_SHOW = 5
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
class CryptoService(ABC):
|
class CryptoService(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
class ClientRenderer(ABC):
|
class ClientRenderer(ABC):
|
||||||
# These attributes are expected to be provided by subclasses
|
|
||||||
# (typically via multiple inheritance with CryptoService)
|
|
||||||
username: str
|
username: str
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _decrypt(self, message: str) -> str:
|
def _decrypt(self, message: str) -> str:
|
||||||
"""Decrypt an encrypted message (provided by crypto mixin)."""
|
"""Decrypt an encrypted message (provided by crypto mixin)."""
|
||||||
|
|
|
||||||
|
|
@ -28,17 +28,16 @@ class RSAService(CryptoService):
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
# read the full response content (server returns encrypted symmetric key)
|
|
||||||
message = r.content
|
message = r.content
|
||||||
self.symmetric_key = rsa.decrypt(message, self.private_key)
|
self.symmetric_key = rsa.decrypt(message, self.private_key)
|
||||||
self.fernet = Fernet(self.symmetric_key)
|
self.fernet = Fernet(self.symmetric_key)
|
||||||
|
|
||||||
def _generate_keys(self):
|
def _generate_keys(self):
|
||||||
self.public_key, self.private_key = rsa.newkeys(512)
|
self.public_key, self.private_key = rsa.newkeys(2048)
|
||||||
|
|
||||||
def _get_generated_keys(self):
|
def _get_generated_keys(self):
|
||||||
return self.private_key, self.public_key
|
return self.private_key, self.public_key
|
||||||
|
|
||||||
def _remove_keys(self):
|
def _remove_keys(self):
|
||||||
self.public_key = None
|
self.public_key = None
|
||||||
self.private_key = None
|
self.private_key = None
|
||||||
|
|
|
||||||
|
|
@ -1,61 +0,0 @@
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
|
|
||||||
from cmd_chat.client.core.abs.abs_renderer import ClientRenderer
|
|
||||||
from cmd_chat.client.config import COLORS
|
|
||||||
|
|
||||||
from colorama import init
|
|
||||||
|
|
||||||
init()
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultClientRenderer(ClientRenderer):
|
|
||||||
|
|
||||||
def __get_os(self) -> str:
|
|
||||||
""" checking what kind of platform you need
|
|
||||||
"""
|
|
||||||
if "Linux" in str(platform.platform()):
|
|
||||||
return "Linux"
|
|
||||||
return "Windows"
|
|
||||||
|
|
||||||
def print_message(self, message: str) -> str:
|
|
||||||
""" generating string with message in required format
|
|
||||||
"""
|
|
||||||
# split only on the first ':' to keep message contents intact
|
|
||||||
parts = message.split(":", 1)
|
|
||||||
if parts[0] == self.username:
|
|
||||||
return COLORS["my_username_color"] + parts[0] + ": " + parts[1] + COLORS["text_color"]
|
|
||||||
return parts[0] + ": " + parts[1] + COLORS["text_color"]
|
|
||||||
|
|
||||||
def clear_console(self):
|
|
||||||
# For windows clear command its cls
|
|
||||||
# For linux clear command its clear
|
|
||||||
if self.__get_os() == "Linux":
|
|
||||||
os.system("clear")
|
|
||||||
else:
|
|
||||||
os.system("cls")
|
|
||||||
|
|
||||||
def print_ip(
|
|
||||||
self,
|
|
||||||
ip: str
|
|
||||||
) -> str:
|
|
||||||
return f"IP: " + COLORS["ip_color"] + ip + COLORS["text_color"]
|
|
||||||
|
|
||||||
def print_username(
|
|
||||||
self,
|
|
||||||
username: str
|
|
||||||
) -> str:
|
|
||||||
# Username label + colored username
|
|
||||||
return f"USERNAME: " + COLORS["username_color"] + username + COLORS["text_color"]
|
|
||||||
|
|
||||||
def print_chat(self, response) -> None:
|
|
||||||
for i, msg in enumerate(response["messages"]):
|
|
||||||
actual_message = self._decrypt(msg)
|
|
||||||
if i == 0:
|
|
||||||
for user in response["users_in_chat"]:
|
|
||||||
print(self.print_ip(user.split(",")[0]))
|
|
||||||
print(self.print_username(user.split(",")[1]))
|
|
||||||
print("Write 'q' to quit from chat")
|
|
||||||
print(f"\n{self.print_message(actual_message)}")
|
|
||||||
else:
|
|
||||||
print(f"{self.print_message(actual_message)}")
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
from rich.style import Style
|
from rich.style import Style
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from cmd_chat.client.core.abs.abs_renderer import ClientRenderer
|
from cmd_chat.client.core.abs.abs_renderer import ClientRenderer
|
||||||
|
|
@ -16,26 +16,26 @@ console = Console(width=75)
|
||||||
class RichClientRenderer(ClientRenderer):
|
class RichClientRenderer(ClientRenderer):
|
||||||
|
|
||||||
def __get_os(self) -> str:
|
def __get_os(self) -> str:
|
||||||
""" checking what kind of platform you need
|
"""checking what kind of platform you need"""
|
||||||
"""
|
|
||||||
if "Linux" in str(platform.platform()):
|
if "Linux" in str(platform.platform()):
|
||||||
return "Linux"
|
return "Linux"
|
||||||
return "Windows"
|
return "Windows"
|
||||||
|
|
||||||
def print_message(self, message: str) -> Text:
|
def print_message(self, message: str) -> Text:
|
||||||
""" generating string with message in required format
|
"""generating string with message in required format"""
|
||||||
"""
|
|
||||||
# split only on the first ':' so message bodies containing ':' are preserved
|
# split only on the first ':' so message bodies containing ':' are preserved
|
||||||
parts = message.split(":", 1)
|
parts = message.split(":", 1)
|
||||||
if parts[0] == self.username:
|
if parts[0] == self.username:
|
||||||
return \
|
return (
|
||||||
Text(text=parts[0], style="bold") + \
|
Text(text=parts[0], style="bold")
|
||||||
Text(text=": ", style="bold") + \
|
+ Text(text=": ", style="bold")
|
||||||
Text(text=parts[1], style="underline")
|
+ Text(text=parts[1], style="underline")
|
||||||
return \
|
)
|
||||||
Text(text=parts[0], style="bold") + \
|
return (
|
||||||
Text(text=": ", style="bold") + \
|
Text(text=parts[0], style="bold")
|
||||||
Text(text=parts[1], style="underline")
|
+ Text(text=": ", style="bold")
|
||||||
|
+ Text(text=parts[1], style="underline")
|
||||||
|
)
|
||||||
|
|
||||||
def clear_console(self):
|
def clear_console(self):
|
||||||
# For windows clear command its cls
|
# For windows clear command its cls
|
||||||
|
|
@ -45,16 +45,10 @@ class RichClientRenderer(ClientRenderer):
|
||||||
else:
|
else:
|
||||||
os.system("cls")
|
os.system("cls")
|
||||||
|
|
||||||
def print_ip(
|
def print_ip(self, ip: str) -> str:
|
||||||
self,
|
|
||||||
ip: str
|
|
||||||
) -> str:
|
|
||||||
return ip
|
return ip
|
||||||
|
|
||||||
def print_username(
|
def print_username(self, username: str) -> str:
|
||||||
self,
|
|
||||||
username: str
|
|
||||||
) -> str:
|
|
||||||
return username
|
return username
|
||||||
|
|
||||||
def print_chat(self, response) -> None:
|
def print_chat(self, response) -> None:
|
||||||
|
|
@ -68,11 +62,11 @@ class RichClientRenderer(ClientRenderer):
|
||||||
table.add_column("USERNAME")
|
table.add_column("USERNAME")
|
||||||
for user in response["users_in_chat"]:
|
for user in response["users_in_chat"]:
|
||||||
table.add_row(
|
table.add_row(
|
||||||
self.print_ip(user.split(',')[0]),
|
self.print_ip(user.split(",")[0]),
|
||||||
self.print_username(user.split(",")[1])
|
self.print_username(user.split(",")[1]),
|
||||||
)
|
)
|
||||||
console.print(table)
|
console.print(table)
|
||||||
console.print("Write 'q' to quit from chat", justify="left")
|
console.print("Write 'q' to quit from chat", justify="left")
|
||||||
console.print(f"\n{self.print_message(actual_message)}")
|
console.print(f"\n{self.print_message(actual_message)}")
|
||||||
else:
|
else:
|
||||||
console.print(f"{self.print_message(actual_message)}")
|
console.print(f"{self.print_message(actual_message)}")
|
||||||
|
|
|
||||||
50
cmd_chat/server/factory.py
Normal file
50
cmd_chat/server/factory.py
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
import asyncio
|
||||||
|
from contextlib import suppress
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sanic import Sanic
|
||||||
|
from sanic_ext import Extend
|
||||||
|
|
||||||
|
from .managers import ConnectionManager
|
||||||
|
from .stores import MessageStore, UserSessionStore
|
||||||
|
from .logger import logger
|
||||||
|
|
||||||
|
from .routes import register_routes
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> Sanic:
|
||||||
|
app = Sanic("cmd-chat-server")
|
||||||
|
Extend(app)
|
||||||
|
|
||||||
|
app.ctx.message_store = MessageStore()
|
||||||
|
app.ctx.session_store = UserSessionStore()
|
||||||
|
app.ctx.connection_manager = ConnectionManager()
|
||||||
|
app.ctx.admin_password = None
|
||||||
|
app.ctx.fernet_key = Fernet.generate_key()
|
||||||
|
app.ctx.cleanup_task = None
|
||||||
|
|
||||||
|
register_lifecycle(app)
|
||||||
|
register_routes(app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def register_lifecycle(app: Sanic) -> None:
|
||||||
|
@app.before_server_start
|
||||||
|
async def setup(app: Sanic):
|
||||||
|
logger.info("Server starting...")
|
||||||
|
app.ctx.cleanup_task = asyncio.create_task(cleanup_stale_sessions(app))
|
||||||
|
|
||||||
|
@app.after_server_stop
|
||||||
|
async def teardown(app: Sanic):
|
||||||
|
logger.info("Server shutting down...")
|
||||||
|
if app.ctx.cleanup_task:
|
||||||
|
app.ctx.cleanup_task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await app.ctx.cleanup_task
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_stale_sessions(app: Sanic) -> None:
|
||||||
|
while True:
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await asyncio.sleep(300)
|
||||||
|
await app.ctx.session_store.cleanup_stale()
|
||||||
58
cmd_chat/server/helpers.py
Normal file
58
cmd_chat/server/helpers.py
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import asdict
|
||||||
|
import json
|
||||||
|
from sanic import Sanic, Request, response, Websocket
|
||||||
|
|
||||||
|
|
||||||
|
def utcnow() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: Optional[str], expected: Optional[str]) -> bool:
|
||||||
|
if not expected:
|
||||||
|
return True
|
||||||
|
return password == expected
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_ip(request: Request) -> str:
|
||||||
|
if forwarded := request.headers.get("x-forwarded-for"):
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
return request.ip
|
||||||
|
|
||||||
|
|
||||||
|
def get_param(request: Request, name: str) -> Optional[str]:
|
||||||
|
return request.args.get(name) or request.form.get(name)
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(request: Request, app: Sanic) -> Optional[response.HTTPResponse]:
|
||||||
|
if not verify_password(get_param(request, "password"), app.ctx.admin_password):
|
||||||
|
return response.text("Unauthorized", status=401)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def send_state(ws: Websocket, app: Sanic) -> None:
|
||||||
|
messages = await app.ctx.message_store.get_all()
|
||||||
|
users = await app.ctx.session_store.get_all()
|
||||||
|
await ws.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "init",
|
||||||
|
"messages": [asdict(m) for m in messages],
|
||||||
|
"users": [
|
||||||
|
{"user_id": u.user_id, "username": u.username} for u in users
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pubkey(request: Request) -> Optional[bytes]:
|
||||||
|
if files := request.files.get("pubkey"):
|
||||||
|
file = files[0] if isinstance(files, list) else files
|
||||||
|
return file.body
|
||||||
|
if raw := request.form.get("pubkey"):
|
||||||
|
return raw.encode() if isinstance(raw, str) else raw
|
||||||
|
if raw := request.args.get("pubkey"):
|
||||||
|
return raw.encode() if isinstance(raw, str) else raw
|
||||||
|
return None
|
||||||
6
cmd_chat/server/logger.py
Normal file
6
cmd_chat/server/logger.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
48
cmd_chat/server/managers.py
Normal file
48
cmd_chat/server/managers.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
from sanic import Websocket
|
||||||
|
from .logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.active_connections: dict[str, Websocket] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def connect(self, user_id: str, websocket: Websocket) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
self.active_connections[user_id] = websocket
|
||||||
|
logger.info(f"Client connected: {user_id}")
|
||||||
|
|
||||||
|
async def disconnect(self, user_id: str) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
if user_id in self.active_connections:
|
||||||
|
del self.active_connections[user_id]
|
||||||
|
logger.info(f"Client disconnected: {user_id}")
|
||||||
|
|
||||||
|
async def broadcast(self, message: str, exclude_user: Optional[str] = None) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
disconnected = []
|
||||||
|
for user_id, connection in list(self.active_connections.items()):
|
||||||
|
if exclude_user and user_id == exclude_user:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await connection.send(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to send message to {user_id}: {e}")
|
||||||
|
disconnected.append(user_id)
|
||||||
|
|
||||||
|
for user_id in disconnected:
|
||||||
|
if user_id in self.active_connections:
|
||||||
|
del self.active_connections[user_id]
|
||||||
|
|
||||||
|
async def send_personal(self, user_id: str, message: str) -> bool:
|
||||||
|
async with self._lock:
|
||||||
|
if connection := self.active_connections.get(user_id):
|
||||||
|
try:
|
||||||
|
await connection.send(message)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to send personal message to {user_id}: {e}")
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
@ -1,5 +1,31 @@
|
||||||
from pydantic import BaseModel
|
from dataclasses import dataclass, field
|
||||||
|
from uuid import uuid4
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
class Message(BaseModel):
|
|
||||||
|
|
||||||
message: str
|
@dataclass
|
||||||
|
class Message:
|
||||||
|
id: str = field(default_factory=lambda: str(uuid4()))
|
||||||
|
text: str = ""
|
||||||
|
timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
|
||||||
|
user_ip: str = ""
|
||||||
|
username: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserSession:
|
||||||
|
user_id: str
|
||||||
|
ip: str
|
||||||
|
username: str = "unknown"
|
||||||
|
fernet_key: Optional[bytes] = None
|
||||||
|
created_at: str = field(default_factory=lambda: datetime.utcnow().isoformat())
|
||||||
|
last_activity: str = field(default_factory=lambda: datetime.utcnow().isoformat())
|
||||||
|
active: bool = True
|
||||||
|
|
||||||
|
def update_activity(self):
|
||||||
|
self.last_activity = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
def is_stale(self, timeout_seconds: int = 3600) -> bool:
|
||||||
|
last = datetime.fromisoformat(self.last_activity)
|
||||||
|
return (datetime.utcnow() - last).total_seconds() > timeout_seconds
|
||||||
|
|
|
||||||
21
cmd_chat/server/routes.py
Normal file
21
cmd_chat/server/routes.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
from sanic import Sanic, Request, Websocket
|
||||||
|
|
||||||
|
from . import views
|
||||||
|
|
||||||
|
|
||||||
|
def register_routes(app: Sanic) -> None:
|
||||||
|
@app.route("/get_key", methods=["GET", "POST"])
|
||||||
|
async def get_key_route(request: Request):
|
||||||
|
return await views.get_key(request, app)
|
||||||
|
|
||||||
|
@app.websocket("/ws/chat")
|
||||||
|
async def chat_ws_route(request: Request, ws: Websocket):
|
||||||
|
await views.chat_ws(request, ws, app)
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_route(request: Request):
|
||||||
|
return await views.health(request, app)
|
||||||
|
|
||||||
|
@app.delete("/clear")
|
||||||
|
async def clear_route(request: Request):
|
||||||
|
return await views.clear_messages(request, app)
|
||||||
|
|
@ -1,113 +1,23 @@
|
||||||
import asyncio
|
from typing import Optional
|
||||||
import rsa
|
from .logger import logger
|
||||||
from cryptography.fernet import Fernet
|
from .factory import create_app
|
||||||
from functools import partial
|
|
||||||
from sanic.worker.loader import AppLoader
|
|
||||||
from sanic.response import HTTPResponse
|
|
||||||
from sanic import Sanic, Request, response, Websocket
|
|
||||||
from cmd_chat.server.models import Message
|
|
||||||
from cmd_chat.server.services import (
|
|
||||||
_get_bytes_and_serialize,
|
|
||||||
_check_ws_for_close_status,
|
|
||||||
_generate_new_message,
|
|
||||||
_generate_update_payload
|
|
||||||
)
|
|
||||||
|
|
||||||
app = Sanic("app")
|
app = create_app()
|
||||||
app.config.OAS = False
|
|
||||||
|
|
||||||
MESSAGES_MEMORY_DB: list[Message] = []
|
|
||||||
USERS: dict[str, str] = {}
|
|
||||||
PUBLIC_KEY = Fernet.generate_key()
|
|
||||||
|
|
||||||
|
|
||||||
def _check_password(request: Request, expected: str | None) -> bool:
|
def run_server(
|
||||||
if not expected:
|
host: str = "0.0.0.0",
|
||||||
return True
|
port: int = 8000,
|
||||||
q = request.args.get("password")
|
admin_password: Optional[str] = None,
|
||||||
f = request.form.get("password") if hasattr(request, "form") else None
|
workers: int = 1,
|
||||||
return (q or f) == expected
|
) -> None:
|
||||||
|
app.ctx.admin_password = admin_password
|
||||||
|
logger.info(f"Starting server on {host}:{port}")
|
||||||
|
|
||||||
def _get_str_arg(request: Request, name: str) -> str | None:
|
app.run(
|
||||||
return request.form.get(name) or request.args.get(name)
|
host=host,
|
||||||
|
port=port,
|
||||||
def attach_endpoints(app: Sanic):
|
workers=workers,
|
||||||
@app.websocket("/talk")
|
debug=False,
|
||||||
async def talk_ws_view(request: Request, ws: Websocket) -> HTTPResponse:
|
access_log=True,
|
||||||
if not _check_password(request, app.ctx.ADMIN_PASSWORD):
|
)
|
||||||
await ws.close(code=4001, reason="unauthorized")
|
|
||||||
return
|
|
||||||
while True:
|
|
||||||
serialized_message: dict = await _get_bytes_and_serialize(ws)
|
|
||||||
await _check_ws_for_close_status(serialized_message, ws)
|
|
||||||
text = serialized_message.get("text")
|
|
||||||
if text is None:
|
|
||||||
continue
|
|
||||||
new_message = await _generate_new_message(text)
|
|
||||||
MESSAGES_MEMORY_DB.append(new_message)
|
|
||||||
await ws.send(str({"status": "ok"}))
|
|
||||||
await asyncio.sleep(0.2)
|
|
||||||
|
|
||||||
@app.websocket("/update")
|
|
||||||
async def update_ws_view(request: Request, ws: Websocket) -> HTTPResponse:
|
|
||||||
if not _check_password(request, app.ctx.ADMIN_PASSWORD):
|
|
||||||
await ws.close(code=4001, reason="unauthorized")
|
|
||||||
return
|
|
||||||
while True:
|
|
||||||
payload = await _generate_update_payload(MESSAGES_MEMORY_DB, USERS)
|
|
||||||
await ws.send(payload.encode())
|
|
||||||
await asyncio.sleep(0.2)
|
|
||||||
|
|
||||||
@app.route('/get_key', methods=['GET', 'POST'])
|
|
||||||
async def get_key_view(request: Request) -> HTTPResponse:
|
|
||||||
if not _check_password(request, app.ctx.ADMIN_PASSWORD):
|
|
||||||
return response.text("unauthorized", status=401)
|
|
||||||
|
|
||||||
pubkey_bytes: bytes | None = None
|
|
||||||
|
|
||||||
if "pubkey" in request.files and request.files.get("pubkey"):
|
|
||||||
f = request.files.get("pubkey")
|
|
||||||
if isinstance(f, list):
|
|
||||||
f = f[0]
|
|
||||||
pubkey_bytes = f.body
|
|
||||||
|
|
||||||
if pubkey_bytes is None:
|
|
||||||
raw = request.form.get("pubkey")
|
|
||||||
if raw:
|
|
||||||
pubkey_bytes = raw if isinstance(raw, bytes) else str(raw).encode()
|
|
||||||
|
|
||||||
if pubkey_bytes is None:
|
|
||||||
raw = request.args.get("pubkey")
|
|
||||||
if raw:
|
|
||||||
pubkey_bytes = raw.encode()
|
|
||||||
|
|
||||||
if not pubkey_bytes:
|
|
||||||
return response.text("bad request: pubkey is required", status=400)
|
|
||||||
|
|
||||||
try:
|
|
||||||
public_key = rsa.PublicKey.load_pkcs1(pubkey_bytes)
|
|
||||||
except Exception as e:
|
|
||||||
return response.text(f"bad pubkey: {e}", status=400)
|
|
||||||
|
|
||||||
encrypted_data = rsa.encrypt(PUBLIC_KEY, public_key)
|
|
||||||
|
|
||||||
username = _get_str_arg(request, "username") or "unknown"
|
|
||||||
user_key = f"{request.ip}, {username}"
|
|
||||||
if user_key not in USERS:
|
|
||||||
USERS[user_key] = PUBLIC_KEY
|
|
||||||
|
|
||||||
return response.raw(encrypted_data)
|
|
||||||
|
|
||||||
|
|
||||||
def create_app(app_name: str, admin_password: str | None) -> Sanic:
|
|
||||||
app = Sanic(app_name)
|
|
||||||
app.ctx.ADMIN_PASSWORD = admin_password
|
|
||||||
attach_endpoints(app)
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
def run_server(host: str, port: int, dev: bool = False, admin_password: str | None = None) -> None:
|
|
||||||
loader = AppLoader(factory=partial(create_app, "CMD_SERVER", admin_password))
|
|
||||||
app = loader.load()
|
|
||||||
app.prepare(host=host, port=port, dev=dev)
|
|
||||||
Sanic.serve(primary=app, app_loader=loader)
|
|
||||||
|
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
import ast
|
|
||||||
from sanic import Websocket
|
|
||||||
from cmd_chat.server.models import Message
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_bytes_and_serialize(
|
|
||||||
ws: Websocket
|
|
||||||
) -> dict:
|
|
||||||
return ast.literal_eval(await ws.recv())
|
|
||||||
|
|
||||||
|
|
||||||
async def _check_ws_for_close_status(
|
|
||||||
response: dict,
|
|
||||||
ws: Websocket
|
|
||||||
) -> None:
|
|
||||||
if "action" in response.keys():
|
|
||||||
if response["action"] == "close":
|
|
||||||
await ws.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def _generate_new_message(
|
|
||||||
message: str
|
|
||||||
) -> Message:
|
|
||||||
return Message(message = message)
|
|
||||||
|
|
||||||
|
|
||||||
async def _generate_update_payload(
|
|
||||||
memory_msgs: list[Message],
|
|
||||||
users_structure: dict
|
|
||||||
) -> str:
|
|
||||||
return str({
|
|
||||||
"messages": [i.message for i in memory_msgs],
|
|
||||||
"users_in_chat": list(users_structure.keys())
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
79
cmd_chat/server/stores.py
Normal file
79
cmd_chat/server/stores.py
Normal file
|
|
@ -0,0 +1,79 @@
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .models import Message, UserSession
|
||||||
|
from .logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
class MessageStore:
|
||||||
|
def __init__(self):
|
||||||
|
self._messages: list[Message] = []
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def add(self, message: Message) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
self._messages.append(message)
|
||||||
|
logger.info(f"Message added: {message.id} from {message.username}")
|
||||||
|
|
||||||
|
async def get_all(self) -> list[Message]:
|
||||||
|
async with self._lock:
|
||||||
|
return self._messages.copy()
|
||||||
|
|
||||||
|
async def clear(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
count = len(self._messages)
|
||||||
|
self._messages.clear()
|
||||||
|
logger.info(f"Cleared {count} messages")
|
||||||
|
|
||||||
|
async def count(self) -> int:
|
||||||
|
async with self._lock:
|
||||||
|
return len(self._messages)
|
||||||
|
|
||||||
|
|
||||||
|
class UserSessionStore:
|
||||||
|
def __init__(self):
|
||||||
|
self._sessions: dict[str, UserSession] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def add(self, session: UserSession) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
self._sessions[session.user_id] = session
|
||||||
|
logger.info(f"Session created: {session.user_id} ({session.username})")
|
||||||
|
|
||||||
|
async def get(self, user_id: str) -> Optional[UserSession]:
|
||||||
|
async with self._lock:
|
||||||
|
return self._sessions.get(user_id)
|
||||||
|
|
||||||
|
async def update_activity(self, user_id: str) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
if session := self._sessions.get(user_id):
|
||||||
|
session.update_activity()
|
||||||
|
|
||||||
|
async def remove(self, user_id: str) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
if user_id in self._sessions:
|
||||||
|
del self._sessions[user_id]
|
||||||
|
logger.info(f"Session removed: {user_id}")
|
||||||
|
|
||||||
|
async def cleanup_stale(self, timeout_seconds: int = 3600) -> int:
|
||||||
|
async with self._lock:
|
||||||
|
stale_ids = [
|
||||||
|
uid for uid, s in self._sessions.items() if s.is_stale(timeout_seconds)
|
||||||
|
]
|
||||||
|
for uid in stale_ids:
|
||||||
|
del self._sessions[uid]
|
||||||
|
if stale_ids:
|
||||||
|
logger.info(f"Cleaned up {len(stale_ids)} stale sessions")
|
||||||
|
return len(stale_ids)
|
||||||
|
|
||||||
|
async def get_all(self) -> list[UserSession]:
|
||||||
|
async with self._lock:
|
||||||
|
return list(self._sessions.values())
|
||||||
|
|
||||||
|
async def count(self) -> int:
|
||||||
|
async with self._lock:
|
||||||
|
return len(self._sessions)
|
||||||
|
|
||||||
|
async def username_exists(self, username: str) -> bool:
|
||||||
|
async with self._lock:
|
||||||
|
return any(s.username == username for s in self._sessions.values())
|
||||||
138
cmd_chat/server/views.py
Normal file
138
cmd_chat/server/views.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
from dataclasses import asdict
|
||||||
|
from uuid import uuid4
|
||||||
|
import json
|
||||||
|
|
||||||
|
import rsa
|
||||||
|
from sanic import Sanic, Request, response, Websocket
|
||||||
|
from sanic.response import HTTPResponse, json as json_response
|
||||||
|
|
||||||
|
from .models import Message, UserSession
|
||||||
|
from .logger import logger
|
||||||
|
from .helpers import (
|
||||||
|
require_auth,
|
||||||
|
extract_pubkey,
|
||||||
|
get_client_ip,
|
||||||
|
get_param,
|
||||||
|
verify_password,
|
||||||
|
send_state,
|
||||||
|
utcnow,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_key(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
|
if err := require_auth(request, app):
|
||||||
|
return err
|
||||||
|
|
||||||
|
pubkey_bytes = extract_pubkey(request)
|
||||||
|
if not pubkey_bytes:
|
||||||
|
return response.text("Bad request: pubkey is required", status=400)
|
||||||
|
|
||||||
|
try:
|
||||||
|
public_key = rsa.PublicKey.load_pkcs1(pubkey_bytes)
|
||||||
|
if public_key.n.bit_length() < 2048:
|
||||||
|
raise ValueError("RSA key must be at least 2048 bits")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Invalid public key: {e}")
|
||||||
|
return response.text(f"Bad pubkey: {e}", status=400)
|
||||||
|
|
||||||
|
username = get_param(request, "username") or "unknown"
|
||||||
|
|
||||||
|
if await app.ctx.session_store.username_exists(username):
|
||||||
|
return response.text(f"Username '{username}' is already taken", status=409)
|
||||||
|
|
||||||
|
session = UserSession(
|
||||||
|
user_id=str(uuid4()),
|
||||||
|
ip=get_client_ip(request),
|
||||||
|
username=get_param(request, "username") or "unknown",
|
||||||
|
fernet_key=app.ctx.fernet_key,
|
||||||
|
)
|
||||||
|
await app.ctx.session_store.add(session)
|
||||||
|
|
||||||
|
try:
|
||||||
|
encrypted_key = rsa.encrypt(app.ctx.fernet_key, public_key)
|
||||||
|
logger.info(f"Key exchange: user={session.username}, session={session.user_id}")
|
||||||
|
|
||||||
|
return response.raw(
|
||||||
|
encrypted_key,
|
||||||
|
content_type="application/octet-stream",
|
||||||
|
headers={"X-User-Id": session.user_id},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Encryption failed: {e}")
|
||||||
|
return response.text("Key encryption failed", status=500)
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_ws(request: Request, ws: Websocket, app: Sanic) -> None:
|
||||||
|
user_id = request.args.get("user_id")
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
await ws.close(code=4002, reason="user_id required")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not verify_password(request.args.get("password"), app.ctx.admin_password):
|
||||||
|
await ws.close(code=4001, reason="Unauthorized")
|
||||||
|
return
|
||||||
|
|
||||||
|
session = await app.ctx.session_store.get(user_id)
|
||||||
|
if not session:
|
||||||
|
await ws.close(code=4002, reason="Invalid session")
|
||||||
|
return
|
||||||
|
|
||||||
|
manager = app.ctx.connection_manager
|
||||||
|
await manager.connect(user_id, ws)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await send_state(ws, app)
|
||||||
|
|
||||||
|
async for data in ws:
|
||||||
|
if data is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
await app.ctx.session_store.update_activity(user_id)
|
||||||
|
|
||||||
|
message = Message(
|
||||||
|
text=str(data),
|
||||||
|
user_ip=session.ip,
|
||||||
|
username=session.username,
|
||||||
|
)
|
||||||
|
await app.ctx.message_store.add(message)
|
||||||
|
|
||||||
|
await manager.broadcast(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"data": asdict(message),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket error for {user_id}: {e}")
|
||||||
|
finally:
|
||||||
|
await manager.disconnect(user_id)
|
||||||
|
await manager.broadcast(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "user_left",
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def health(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
|
return json_response(
|
||||||
|
{
|
||||||
|
"status": "ok",
|
||||||
|
"messages": await app.ctx.message_store.count(),
|
||||||
|
"users": await app.ctx.session_store.count(),
|
||||||
|
"timestamp": utcnow().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def clear_messages(request: Request, app: Sanic) -> HTTPResponse:
|
||||||
|
if err := require_auth(request, app):
|
||||||
|
return err
|
||||||
|
await app.ctx.message_store.clear()
|
||||||
|
return json_response({"status": "cleared"})
|
||||||
|
|
@ -6,4 +6,6 @@ colorama
|
||||||
pydantic
|
pydantic
|
||||||
websocket-client
|
websocket-client
|
||||||
flask
|
flask
|
||||||
rich
|
rich
|
||||||
|
sanic_ext
|
||||||
|
websockets
|
||||||
16
setup.py
16
setup.py
|
|
@ -14,13 +14,9 @@ setuptools.setup(
|
||||||
long_description=description,
|
long_description=description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
url="https://github.com/dinosaurtirex/cmd-chat",
|
url="https://github.com/dinosaurtirex/cmd-chat",
|
||||||
license='MIT',
|
license="MIT",
|
||||||
python_requires='>=3.10',
|
python_requires=">=3.10",
|
||||||
entry_points={
|
entry_points={"console_scripts": ["cmd_chat = cmd_chat:main"]},
|
||||||
'console_scripts': [
|
|
||||||
'cmd_chat = cmd_chat:main'
|
|
||||||
]
|
|
||||||
},
|
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"sanic",
|
"sanic",
|
||||||
"requests",
|
"requests",
|
||||||
|
|
@ -28,6 +24,6 @@ setuptools.setup(
|
||||||
"cryptography",
|
"cryptography",
|
||||||
"colorama",
|
"colorama",
|
||||||
"pydantic",
|
"pydantic",
|
||||||
"websocket-client"
|
"websocket-client",
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user