New storage scheme

This commit is contained in:
mirai 2026-01-06 21:42:50 +08:00
parent 264d19e932
commit 467d942877
9 changed files with 1287 additions and 23 deletions

View File

@ -29,37 +29,66 @@ every "secure" messenger still stores metadata somewhere. this doesn't. it's jus
│ │─────── POST /srp/init {username, A} ───────► │ │ │ │─────── POST /srp/init {username, A} ───────► │ │
│ │ (A = client public ephemeral) │ │ │ │ (A = client public ephemeral) │ │
│ │ │ │ │ │ │ │
│ │◄────────── {user_id, B, salt} ────────────── │ │ │ │◄──── {user_id, B, salt, room_salt} ───────── │ │
│ │ (B = server public ephemeral) │ │ │ │ (B = server public ephemeral) │ │
│ │ (room_salt = E2E key derivation) │ │
│ │ │ │ │ │ │ │
│ │ [client derives room_key via HKDF: │ │
│ │ room_key = HKDF(password, room_salt)] │ │
│ │ │ │ │ │ │ │
│ │ [both sides compute shared session key │ │ │ │ [both sides compute SRP session key │ │
│ │ using password + ephemeral values] │ │ │ │ using password + ephemeral values] │ │
│ │ │ │ │ │ │ │
│ │ │ │ │ │─────── POST /srp/verify {user_id, M} ──────► │ │
│ │─────── POST /srp/verify {user_id, M} ──────► │ |
│ │ (M = client proof) │ │ │ │ (M = client proof) │ │
│ │ │ │ │ │ │ │
│ │◄────────── {H_AMK, session_key} ──────────── │ │ │ │◄────────── {H_AMK, session_key} ──────────── │ │
│ │ (H_AMK = server proof) │ │ │ │ (H_AMK = server proof) │ │
│ │ │ │ │ │ │ │
│ │ [password never transmitted] │ │ │ │ [password never transmitted] │ │
│ │ [MITM can't derive session key] │ | │ │ [MITM can't derive session key] │
│ │ │ │ │ │ │ │
├──────────────────────────────────────────────────────────────────┤ ├──────────────────────────────────────────────────────────────────┤
ENCRYPTED CHAT E2E ENCRYPTED CHAT
├──────────────────────────────────────────────────────────────────┤ ├──────────────────────────────────────────────────────────────────┤
│ │ │ │ │ │ │ │
│ │═══════ WebSocket /ws/chat?user_id ═════════► │ │ │ │═══════ WebSocket /ws/chat?user_id ═════════► │ │
│ │ (authenticated session) │ │ │ │ (authenticated session) │ │
│ │ │ │ │ │ │ │
│ │◄═══════════ AES-encrypted messages ════════► │ │
│ │ (Fernet = AES-128-CBC + HMAC) │ │
│ │ │ │ │ │ │ │
│ ┌─┴─┐ ┌──┴──┐ │
│ │ C │──── encrypt(msg, room_key) ───────────►│ S │ │
│ │ L │ │ E │ │
│ │ I │◄─── ciphertext (broadcast) ────────────│ R │ │
│ │ E │ │ V │ │
│ │ N │ decrypt(ciphertext, room_key) │ E │ │
│ │ T │ │ R │ │
│ └─┬─┘ └──┬──┘ │
│ │ │ │
│ │ [server stores ONLY ciphertext] │ │
│ │ [server CANNOT read messages] │ │
│ │ [all clients with same password │ │
│ │ derive identical room_key] │ │
│ │ │ │
│ │ Encryption: Fernet (AES-128-CBC + HMAC) │ │
│ │ Key derivation: HKDF-SHA256 │ │
│ │ │ │ │ │ │ │
│ │ [on disconnect: keys wiped from RAM] │ │ │ │ [on disconnect: keys wiped from RAM] │ │
│ │ │ │ │ │ │ │
└──────────────────────────────────────────────────────────────────┘ └──────────────────────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────┐
│ KEY HIERARCHY │
├──────────────────────────────────────────────────────────────────┤
│ │
│ password ──┬──► SRP ──► session_key (per-user, auth only) │
│ │ │
│ └──► HKDF(password, room_salt) ──► room_key (shared) │
│ │
│ room_salt: generated once at server start │
│ room_key: deterministic, same for all clients with same pwd │
│ │
└──────────────────────────────────────────────────────────────────┘
``` ```
**SRP (Secure Remote Password)** — password is never sent over the network. both sides prove they know it via zero-knowledge proof, then derive identical session keys. **SRP (Secure Remote Password)** — password is never sent over the network. both sides prove they know it via zero-knowledge proof, then derive identical session keys.
@ -67,8 +96,6 @@ every "secure" messenger still stores metadata somewhere. this doesn't. it's jus
## install ## install
```bash ```bash
git clone https://github.com/emilycodestar/cmd-chat.git
cd cmd-chat
python -m venv venv && source venv/bin/activate && pip install -r requirements.txt python -m venv venv && source venv/bin/activate && pip install -r requirements.txt
``` ```

View File

@ -6,6 +6,8 @@ from typing import Optional
import srp import srp
import requests import requests
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
import websockets import websockets
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
@ -23,6 +25,7 @@ class Client:
self.password = (password or "").encode() self.password = (password or "").encode()
self.user_id: Optional[str] = None self.user_id: Optional[str] = None
self.fernet: Optional[Fernet] = None self.fernet: Optional[Fernet] = None
self.room_fernet: Optional[Fernet] = None
self.console = Console() self.console = Console()
self.messages: list[dict] = [] self.messages: list[dict] = []
@ -67,6 +70,16 @@ class Client:
self.user_id = init_data["user_id"] self.user_id = init_data["user_id"]
B = base64.b64decode(init_data["B"]) B = base64.b64decode(init_data["B"])
salt = base64.b64decode(init_data["salt"]) salt = base64.b64decode(init_data["salt"])
room_salt = base64.b64decode(init_data["room_salt"])
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(self.password)
self.room_fernet = Fernet(base64.urlsafe_b64encode(room_key))
M = usr.process_challenge(salt, B) M = usr.process_challenge(salt, B)
@ -96,6 +109,15 @@ class Client:
self.success(f"SRP authenticated (session: {self.user_id[:8]}...)") self.success(f"SRP authenticated (session: {self.user_id[:8]}...)")
def decrypt_message(self, msg: dict) -> dict:
if "text" in msg and msg["text"]:
try:
decrypted = self.room_fernet.decrypt(msg["text"].encode()).decode()
msg["text"] = decrypted
except Exception:
msg["text"] = "[decrypt failed]"
return msg
def render_messages(self) -> None: def render_messages(self) -> None:
self.console.clear() self.console.clear()
@ -131,12 +153,15 @@ class Client:
msg_type = data.get("type", "") msg_type = data.get("type", "")
if msg_type == "init": if msg_type == "init":
self.messages = data.get("messages", []) messages = [
self.decrypt_message(m) for m in data.get("messages", [])
]
self.messages = messages
self.users = data.get("users", []) self.users = data.get("users", [])
self.connected = True self.connected = True
self.render_messages() self.render_messages()
elif msg_type == "message": elif msg_type == "message":
msg_data = data.get("data", {}) msg_data = self.decrypt_message(data.get("data", {}))
self.messages.append(msg_data) self.messages.append(msg_data)
self.render_messages() self.render_messages()
elif msg_type == "user_left": elif msg_type == "user_left":
@ -156,7 +181,8 @@ class Client:
self.running = False self.running = False
break break
if text.strip(): if text.strip():
await ws.send(text) encrypted = self.room_fernet.encrypt(text.encode()).decode()
await ws.send(encrypted)
except (EOFError, KeyboardInterrupt): except (EOFError, KeyboardInterrupt):
self.running = False self.running = False
break break

View File

@ -1 +0,0 @@
MESSAGES_TO_SHOW = 5

View File

@ -3,7 +3,7 @@ from contextlib import suppress
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from sanic import Sanic from sanic import Sanic
from sanic_ext import Extend from sanic_ext import Extend
import os
from .managers import ConnectionManager from .managers import ConnectionManager
from .stores import MessageStore, UserSessionStore from .stores import MessageStore, UserSessionStore
from .srp_auth import SRPAuthManager from .srp_auth import SRPAuthManager
@ -19,7 +19,7 @@ def create_app(password: str = "", name: str = "cmd-chat-server") -> Sanic:
app.ctx.session_store = UserSessionStore() app.ctx.session_store = UserSessionStore()
app.ctx.connection_manager = ConnectionManager() app.ctx.connection_manager = ConnectionManager()
app.ctx.srp_manager = SRPAuthManager(password) app.ctx.srp_manager = SRPAuthManager(password)
app.ctx.fernet_key = Fernet.generate_key() app.ctx.room_salt = os.urandom(16)
app.ctx.cleanup_task = None app.ctx.cleanup_task = None
register_lifecycle(app) register_lifecycle(app)

View File

@ -9,7 +9,6 @@ class MessageStore:
def add(self, message: Message) -> None: def add(self, message: Message) -> None:
self._messages.append(message) self._messages.append(message)
def get_all(self) -> list[Message]: def get_all(self) -> list[Message]:
return self._messages.copy() return self._messages.copy()
@ -17,7 +16,6 @@ class MessageStore:
count = len(self._messages) count = len(self._messages)
self._messages.clear() self._messages.clear()
def count(self) -> int: def count(self) -> int:
return len(self._messages) return len(self._messages)
@ -29,7 +27,6 @@ class UserSessionStore:
def add(self, session: UserSession) -> None: def add(self, session: UserSession) -> None:
self._sessions[session.user_id] = session self._sessions[session.user_id] = session
def get(self, user_id: str) -> Optional[UserSession]: def get(self, user_id: str) -> Optional[UserSession]:
return self._sessions.get(user_id) return self._sessions.get(user_id)
@ -41,7 +38,6 @@ class UserSessionStore:
if user_id in self._sessions: if user_id in self._sessions:
del self._sessions[user_id] del self._sessions[user_id]
def cleanup_stale(self, timeout_seconds: int = 3600) -> int: def cleanup_stale(self, timeout_seconds: int = 3600) -> int:
stale_ids = [ stale_ids = [
uid for uid, s in self._sessions.items() if s.is_stale(timeout_seconds) uid for uid, s in self._sessions.items() if s.is_stale(timeout_seconds)

View File

@ -35,6 +35,7 @@ async def srp_init(request: Request, app: Sanic) -> HTTPResponse:
"user_id": user_id, "user_id": user_id,
"B": base64.b64encode(B).decode(), "B": base64.b64encode(B).decode(),
"salt": base64.b64encode(salt).decode(), "salt": base64.b64encode(salt).decode(),
"room_salt": base64.b64encode(app.ctx.room_salt).decode(),
} }
) )
@ -66,7 +67,6 @@ async def srp_verify(request: Request, app: Sanic) -> HTTPResponse:
) )
app.ctx.session_store.add(session) app.ctx.session_store.add(session)
return response.json( return response.json(
{ {
"H_AMK": base64.b64encode(H_AMK).decode(), "H_AMK": base64.b64encode(H_AMK).decode(),

View File

@ -6,9 +6,9 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
import uuid import uuid
import pytest import pytest
from sanic_testing import TestManager from sanic_testing import TestManager
import os
from sanic import Sanic from sanic import Sanic
from sanic_ext import Extend from sanic_ext import Extend
from cryptography.fernet import Fernet
from cmd_chat.server.managers import ConnectionManager from cmd_chat.server.managers import ConnectionManager
from cmd_chat.server.stores import MessageStore, UserSessionStore from cmd_chat.server.stores import MessageStore, UserSessionStore
@ -27,7 +27,7 @@ def app():
app.ctx.session_store = UserSessionStore() app.ctx.session_store = UserSessionStore()
app.ctx.connection_manager = ConnectionManager() app.ctx.connection_manager = ConnectionManager()
app.ctx.srp_manager = SRPAuthManager("testpassword") app.ctx.srp_manager = SRPAuthManager("testpassword")
app.ctx.fernet_key = Fernet.generate_key() app.ctx.room_salt = os.urandom(16)
app.ctx.cleanup_task = None app.ctx.cleanup_task = None
register_routes(app) register_routes(app)

365
tests/test_client.py Normal file
View File

@ -0,0 +1,365 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import os
import uuid
import base64
import json
import pytest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cmd_chat.client.client import Client
@pytest.fixture
def client():
return Client(
server="127.0.0.1",
port=3000,
username="testuser",
password="testpassword",
)
@pytest.fixture
def room_salt():
return os.urandom(16)
@pytest.fixture
def room_fernet(room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(b"testpassword")
return Fernet(base64.urlsafe_b64encode(room_key))
class TestClientInit:
def test_client_creation(self, client):
assert client.server == "127.0.0.1"
assert client.port == 3000
assert client.username == "testuser"
assert client.password == b"testpassword"
assert client.user_id is None
assert client.fernet is None
assert client.room_fernet is None
assert client.connected is False
assert client.running is False
def test_client_urls(self, client):
assert client.base_url == "http://127.0.0.1:3000"
assert client.ws_url == "ws://127.0.0.1:3000"
def test_client_empty_password(self):
client = Client("localhost", 8080, "user", None)
assert client.password == b""
class TestEncryption:
def test_decrypt_message_success(self, client, room_salt, room_fernet):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(client.password)
client.room_fernet = Fernet(base64.urlsafe_b64encode(room_key))
original_text = "Hello, World!"
encrypted = room_fernet.encrypt(original_text.encode()).decode()
msg = {"text": encrypted, "username": "other"}
decrypted_msg = client.decrypt_message(msg)
assert decrypted_msg["text"] == original_text
assert decrypted_msg["username"] == "other"
def test_decrypt_message_failure(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
msg = {"text": "not-valid-ciphertext", "username": "other"}
decrypted_msg = client.decrypt_message(msg)
assert decrypted_msg["text"] == "[decrypt failed]"
def test_decrypt_message_empty_text(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
msg = {"text": "", "username": "other"}
result = client.decrypt_message(msg)
assert result["text"] == ""
def test_decrypt_message_no_text_field(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
msg = {"username": "other"}
result = client.decrypt_message(msg)
assert "text" not in result
def test_hkdf_deterministic(self, room_salt):
password = b"testpassword"
hkdf1 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key1 = hkdf1.derive(password)
hkdf2 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key2 = hkdf2.derive(password)
assert key1 == key2
def test_hkdf_different_passwords(self, room_salt):
hkdf1 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key1 = hkdf1.derive(b"password1")
hkdf2 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key2 = hkdf2.derive(b"password2")
assert key1 != key2
def test_hkdf_different_salts(self):
password = b"testpassword"
hkdf1 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=os.urandom(16),
info=b"cmd-chat-room-key",
)
key1 = hkdf1.derive(password)
hkdf2 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=os.urandom(16),
info=b"cmd-chat-room-key",
)
key2 = hkdf2.derive(password)
assert key1 != key2
class TestMessageHandling:
def test_render_messages_empty(self, client, capsys):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = []
client.users = []
with patch.object(client.console, "clear"):
client.render_messages()
def test_render_messages_with_data(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = [
{
"username": "testuser",
"text": "Hello",
"timestamp": "2024-01-01T12:00:00",
},
{"username": "other", "text": "Hi", "timestamp": "2024-01-01T12:01:00"},
]
client.users = [
{"user_id": "1", "username": "testuser"},
{"user_id": "2", "username": "other"},
]
with patch.object(client.console, "clear"):
client.render_messages()
def test_messages_limit_15(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = [
{"username": "user", "text": f"msg{i}", "timestamp": "2024-01-01T12:00:00"}
for i in range(20)
]
client.users = []
with patch.object(client.console, "clear"):
with patch.object(client.console, "print") as mock_print:
client.render_messages()
msg_calls = [
call
for call in mock_print.call_args_list
if any("msg" in str(arg) for arg in call[0])
]
assert len(msg_calls) == 15
class TestReceiveLoop:
@pytest.mark.asyncio
async def test_receive_init_message(self, client, room_fernet):
client.room_fernet = room_fernet
client.running = True
encrypted_text = room_fernet.encrypt(b"Hello").decode()
init_data = json.dumps(
{
"type": "init",
"messages": [{"text": encrypted_text, "username": "other"}],
"users": [{"user_id": "123", "username": "other"}],
}
)
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = [init_data]
with patch.object(client, "render_messages"):
await client.receive_loop(mock_ws)
assert client.connected is True
assert len(client.messages) == 1
assert client.messages[0]["text"] == "Hello"
assert len(client.users) == 1
@pytest.mark.asyncio
async def test_receive_message(self, client, room_fernet):
client.room_fernet = room_fernet
client.running = True
client.messages = []
encrypted_text = room_fernet.encrypt(b"New message").decode()
msg_data = json.dumps(
{
"type": "message",
"data": {"text": encrypted_text, "username": "sender"},
}
)
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = [msg_data]
with patch.object(client, "render_messages"):
await client.receive_loop(mock_ws)
assert len(client.messages) == 1
assert client.messages[0]["text"] == "New message"
@pytest.mark.asyncio
async def test_receive_user_left(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
client.users = [
{"user_id": "123", "username": "user1"},
{"user_id": "456", "username": "user2"},
]
left_data = json.dumps(
{
"type": "user_left",
"user_id": "123",
}
)
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = [left_data]
with patch.object(client, "render_messages"):
await client.receive_loop(mock_ws)
assert len(client.users) == 1
assert client.users[0]["user_id"] == "456"
class TestInputLoop:
@pytest.mark.asyncio
async def test_send_encrypted_message(self, client, room_fernet):
client.room_fernet = room_fernet
client.running = True
mock_ws = AsyncMock()
sent_messages = []
mock_ws.send = AsyncMock(side_effect=lambda m: sent_messages.append(m))
inputs = iter(["hello", "q"])
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(side_effect=lambda _, __: next(inputs))
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
assert len(sent_messages) == 1
decrypted = room_fernet.decrypt(sent_messages[0].encode()).decode()
assert decrypted == "hello"
@pytest.mark.asyncio
async def test_quit_command(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
mock_ws = AsyncMock()
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(return_value="quit")
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
assert client.running is False
@pytest.mark.asyncio
async def test_empty_message_not_sent(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
mock_ws = AsyncMock()
inputs = iter(["", " ", "q"])
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(side_effect=lambda _, __: next(inputs))
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
mock_ws.send.assert_not_called()
class TestConsoleOutput:
def test_success_message(self, client, capsys):
client.success("Test success")
def test_error_message(self, client, capsys):
client.error("Test error")
def test_info_message(self, client, capsys):
client.info("Test info")

View File

@ -0,0 +1,851 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import os
import base64
import json
import pytest
from unittest.mock import patch, AsyncMock, MagicMock, PropertyMock
import requests
import websockets
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cmd_chat.client.client import Client
@pytest.fixture
def client():
return Client(
server="127.0.0.1",
port=3000,
username="testuser",
password="testpassword",
)
@pytest.fixture
def room_salt():
return os.urandom(16)
@pytest.fixture
def room_fernet(room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(b"testpassword")
return Fernet(base64.urlsafe_b64encode(room_key))
class TestClientProperties:
def test_base_url_different_ports(self):
client = Client("example.com", 8080, "user", "pass")
assert client.base_url == "http://example.com:8080"
def test_ws_url_different_ports(self):
client = Client("example.com", 8080, "user", "pass")
assert client.ws_url == "ws://example.com:8080"
def test_base_url_localhost(self):
client = Client("localhost", 443, "user", "pass")
assert client.base_url == "http://localhost:443"
def test_password_encoding_unicode(self):
client = Client("localhost", 3000, "user", "пароль123")
assert client.password == "пароль123".encode()
def test_password_encoding_special_chars(self):
client = Client("localhost", 3000, "user", "p@$$w0rd!#%")
assert client.password == b"p@$$w0rd!#%"
class TestSRPAuthentication:
@patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_success(self, mock_post, client, room_salt):
import srp
init_response = MagicMock()
init_response.json.return_value = {
"user_id": "test-user-id-12345",
"B": base64.b64encode(os.urandom(256)).decode(),
"salt": base64.b64encode(os.urandom(16)).decode(),
"room_salt": base64.b64encode(room_salt).decode(),
}
init_response.raise_for_status = MagicMock()
verify_response = MagicMock()
verify_response.json.return_value = {
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
}
verify_response.raise_for_status = MagicMock()
mock_post.side_effect = [init_response, verify_response]
with patch("cmd_chat.client.client.srp.User") as mock_srp_user:
mock_usr = MagicMock()
mock_usr.start_authentication.return_value = (None, os.urandom(256))
mock_usr.process_challenge.return_value = os.urandom(32)
mock_usr.verify_session.return_value = None
mock_usr.authenticated.return_value = True
mock_srp_user.return_value = mock_usr
client.srp_authenticate()
assert client.user_id == "test-user-id-12345"
assert client.room_fernet is not None
assert client.fernet is not None
@patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_init_fails(self, mock_post, client):
mock_post.side_effect = requests.exceptions.HTTPError(
response=MagicMock(status_code=500, text="Server error")
)
with pytest.raises(requests.exceptions.HTTPError):
client.srp_authenticate()
@patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_verify_fails(self, mock_post, client, room_salt):
init_response = MagicMock()
init_response.json.return_value = {
"user_id": "test-user-id",
"B": base64.b64encode(os.urandom(256)).decode(),
"salt": base64.b64encode(os.urandom(16)).decode(),
"room_salt": base64.b64encode(room_salt).decode(),
}
init_response.raise_for_status = MagicMock()
verify_response = MagicMock()
verify_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=MagicMock(status_code=401, text="Invalid proof")
)
mock_post.side_effect = [init_response, verify_response]
with patch("cmd_chat.client.client.srp.User") as mock_srp_user:
mock_usr = MagicMock()
mock_usr.start_authentication.return_value = (None, os.urandom(256))
mock_usr.process_challenge.return_value = os.urandom(32)
mock_srp_user.return_value = mock_usr
with pytest.raises(requests.exceptions.HTTPError):
client.srp_authenticate()
@patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_challenge_none(self, mock_post, client, room_salt):
init_response = MagicMock()
init_response.json.return_value = {
"user_id": "test-user-id",
"B": base64.b64encode(os.urandom(256)).decode(),
"salt": base64.b64encode(os.urandom(16)).decode(),
"room_salt": base64.b64encode(room_salt).decode(),
}
init_response.raise_for_status = MagicMock()
mock_post.return_value = init_response
with patch("cmd_chat.client.client.srp.User") as mock_srp_user:
mock_usr = MagicMock()
mock_usr.start_authentication.return_value = (None, os.urandom(256))
mock_usr.process_challenge.return_value = None
mock_srp_user.return_value = mock_usr
with pytest.raises(ValueError, match="SRP challenge processing failed"):
client.srp_authenticate()
@patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_server_not_authenticated(
self, mock_post, client, room_salt
):
init_response = MagicMock()
init_response.json.return_value = {
"user_id": "test-user-id",
"B": base64.b64encode(os.urandom(256)).decode(),
"salt": base64.b64encode(os.urandom(16)).decode(),
"room_salt": base64.b64encode(room_salt).decode(),
}
init_response.raise_for_status = MagicMock()
verify_response = MagicMock()
verify_response.json.return_value = {
"H_AMK": base64.b64encode(os.urandom(32)).decode(),
"session_key": base64.b64encode(Fernet.generate_key()).decode(),
}
verify_response.raise_for_status = MagicMock()
mock_post.side_effect = [init_response, verify_response]
with patch("cmd_chat.client.client.srp.User") as mock_srp_user:
mock_usr = MagicMock()
mock_usr.start_authentication.return_value = (None, os.urandom(256))
mock_usr.process_challenge.return_value = os.urandom(32)
mock_usr.verify_session.return_value = None
mock_usr.authenticated.return_value = False
mock_srp_user.return_value = mock_usr
with pytest.raises(ValueError, match="Server authentication failed"):
client.srp_authenticate()
@patch("cmd_chat.client.client.requests.post")
def test_srp_authenticate_connection_timeout(self, mock_post, client):
mock_post.side_effect = requests.exceptions.Timeout()
with pytest.raises(requests.exceptions.Timeout):
client.srp_authenticate()
class TestDecryptMessage:
def test_decrypt_multiple_messages(self, client, room_fernet, room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(client.password)
client.room_fernet = Fernet(base64.urlsafe_b64encode(room_key))
messages = ["Hello", "World", "Test123", "Привет мир"]
for original in messages:
encrypted = room_fernet.encrypt(original.encode()).decode()
msg = {"text": encrypted, "username": "other"}
decrypted = client.decrypt_message(msg)
assert decrypted["text"] == original
def test_decrypt_preserves_other_fields(self, client, room_fernet, room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(client.password)
client.room_fernet = Fernet(base64.urlsafe_b64encode(room_key))
encrypted = room_fernet.encrypt(b"test").decode()
msg = {
"text": encrypted,
"username": "sender",
"timestamp": "2024-01-01T12:00:00",
"id": "msg-123",
"user_ip": "192.168.1.1",
}
decrypted = client.decrypt_message(msg)
assert decrypted["text"] == "test"
assert decrypted["username"] == "sender"
assert decrypted["timestamp"] == "2024-01-01T12:00:00"
assert decrypted["id"] == "msg-123"
assert decrypted["user_ip"] == "192.168.1.1"
def test_decrypt_wrong_key_marks_failed(self, client):
fernet1 = Fernet(Fernet.generate_key())
encrypted = fernet1.encrypt(b"secret").decode()
client.room_fernet = Fernet(Fernet.generate_key())
msg = {"text": encrypted, "username": "other"}
decrypted = client.decrypt_message(msg)
assert decrypted["text"] == "[decrypt failed]"
def test_decrypt_corrupted_ciphertext(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
msg = {"text": "YWJjZGVmZ2hpamtsbW5vcA==", "username": "other"}
decrypted = client.decrypt_message(msg)
assert decrypted["text"] == "[decrypt failed]"
def test_decrypt_none_text(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
msg = {"text": None, "username": "other"}
result = client.decrypt_message(msg)
assert result["text"] is None
class TestReceiveLoopExtended:
@pytest.mark.asyncio
async def test_receive_multiple_messages_sequence(self, client, room_fernet):
client.room_fernet = room_fernet
client.running = True
client.messages = []
msg1 = room_fernet.encrypt(b"First").decode()
msg2 = room_fernet.encrypt(b"Second").decode()
msg3 = room_fernet.encrypt(b"Third").decode()
messages = [
json.dumps(
{"type": "message", "data": {"text": msg1, "username": "user1"}}
),
json.dumps(
{"type": "message", "data": {"text": msg2, "username": "user2"}}
),
json.dumps(
{"type": "message", "data": {"text": msg3, "username": "user1"}}
),
]
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = messages
with patch.object(client, "render_messages"):
await client.receive_loop(mock_ws)
assert len(client.messages) == 3
assert client.messages[0]["text"] == "First"
assert client.messages[1]["text"] == "Second"
assert client.messages[2]["text"] == "Third"
@pytest.mark.asyncio
async def test_receive_stops_when_not_running(self, client, room_fernet):
client.room_fernet = room_fernet
client.running = False
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = [
json.dumps({"type": "message", "data": {"text": "test", "username": "u"}})
]
with patch.object(client, "render_messages") as mock_render:
await client.receive_loop(mock_ws)
mock_render.assert_not_called()
@pytest.mark.asyncio
async def test_receive_handles_connection_closed(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
client.connected = True
mock_ws = AsyncMock()
mock_ws.__aiter__.side_effect = websockets.ConnectionClosed(None, None)
await client.receive_loop(mock_ws)
assert client.connected is False
@pytest.mark.asyncio
async def test_receive_unknown_message_type(self, client, room_fernet):
client.room_fernet = room_fernet
client.running = True
client.messages = []
unknown_msg = json.dumps({"type": "unknown_type", "data": {}})
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = [unknown_msg]
with patch.object(client, "render_messages") as mock_render:
await client.receive_loop(mock_ws)
mock_render.assert_not_called()
@pytest.mark.asyncio
async def test_receive_user_joined_updates_list(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
client.users = []
init_msg = json.dumps(
{
"type": "init",
"messages": [],
"users": [
{"user_id": "1", "username": "alice"},
{"user_id": "2", "username": "bob"},
],
}
)
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = [init_msg]
with patch.object(client, "render_messages"):
await client.receive_loop(mock_ws)
assert len(client.users) == 2
assert client.users[0]["username"] == "alice"
assert client.users[1]["username"] == "bob"
@pytest.mark.asyncio
async def test_receive_multiple_users_leave(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
client.users = [
{"user_id": "1", "username": "alice"},
{"user_id": "2", "username": "bob"},
{"user_id": "3", "username": "charlie"},
]
leave_msgs = [
json.dumps({"type": "user_left", "user_id": "1"}),
json.dumps({"type": "user_left", "user_id": "3"}),
]
mock_ws = AsyncMock()
mock_ws.__aiter__.return_value = leave_msgs
with patch.object(client, "render_messages"):
await client.receive_loop(mock_ws)
assert len(client.users) == 1
assert client.users[0]["username"] == "bob"
class TestInputLoopExtended:
@pytest.mark.asyncio
async def test_input_keyboard_interrupt(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
mock_ws = AsyncMock()
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(side_effect=KeyboardInterrupt())
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
assert client.running is False
@pytest.mark.asyncio
async def test_input_eof_error(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
mock_ws = AsyncMock()
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(side_effect=EOFError())
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
assert client.running is False
@pytest.mark.asyncio
async def test_input_exit_command(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
mock_ws = AsyncMock()
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(return_value="exit")
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
assert client.running is False
@pytest.mark.asyncio
async def test_input_case_insensitive_quit(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
mock_ws = AsyncMock()
inputs = iter(["QUIT"])
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(side_effect=lambda _, __: next(inputs))
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
assert client.running is False
@pytest.mark.asyncio
async def test_input_multiple_messages_then_quit(self, client, room_fernet):
client.room_fernet = room_fernet
client.running = True
mock_ws = AsyncMock()
sent = []
mock_ws.send = AsyncMock(side_effect=lambda m: sent.append(m))
inputs = iter(["msg1", "msg2", "msg3", "q"])
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(side_effect=lambda _, __: next(inputs))
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
assert len(sent) == 3
assert room_fernet.decrypt(sent[0].encode()).decode() == "msg1"
assert room_fernet.decrypt(sent[1].encode()).decode() == "msg2"
assert room_fernet.decrypt(sent[2].encode()).decode() == "msg3"
@pytest.mark.asyncio
async def test_input_whitespace_only_not_sent(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.running = True
mock_ws = AsyncMock()
inputs = iter(["\t", "\n", " \t ", "q"])
with patch("asyncio.get_event_loop") as mock_loop:
mock_executor = AsyncMock(side_effect=lambda _, __: next(inputs))
mock_loop.return_value.run_in_executor = mock_executor
await client.input_loop(mock_ws)
mock_ws.send.assert_not_called()
class TestRunAsync:
@pytest.mark.asyncio
async def test_run_connection_error(self, client):
with patch.object(client, "srp_authenticate") as mock_auth:
mock_auth.side_effect = requests.exceptions.ConnectionError()
with patch.object(client.console, "clear"):
with patch.object(client.console, "print"):
await client.run_async()
@pytest.mark.asyncio
async def test_run_http_error(self, client):
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.text = "Forbidden"
with patch.object(client, "srp_authenticate") as mock_auth:
mock_auth.side_effect = requests.exceptions.HTTPError(
response=mock_response
)
with patch.object(client.console, "clear"):
with patch.object(client.console, "print"):
await client.run_async()
@pytest.mark.asyncio
async def test_run_value_error(self, client):
with patch.object(client, "srp_authenticate") as mock_auth:
mock_auth.side_effect = ValueError("Auth failed")
with patch.object(client.console, "clear"):
with patch.object(client.console, "print"):
await client.run_async()
@pytest.mark.asyncio
async def test_run_generic_exception(self, client):
with patch.object(client, "srp_authenticate") as mock_auth:
mock_auth.side_effect = RuntimeError("Unexpected")
with patch.object(client.console, "clear"):
with patch.object(client.console, "print"):
await client.run_async()
@pytest.mark.asyncio
async def test_run_successful_connection_and_disconnect(self, client):
client.user_id = "test-id-123"
with patch.object(client, "srp_authenticate"):
with patch("cmd_chat.client.client.websockets.connect") as mock_connect:
mock_ws = AsyncMock()
mock_connect.return_value.__aenter__.return_value = mock_ws
with patch.object(
client, "receive_loop", new_callable=AsyncMock
) as mock_recv:
with patch.object(
client, "input_loop", new_callable=AsyncMock
) as mock_input:
mock_input.return_value = None
mock_recv.return_value = None
with patch.object(client.console, "clear"):
with patch.object(client.console, "print"):
await client.run_async()
class TestRenderMessagesExtended:
def test_render_own_message_green(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.username = "testuser"
client.messages = [
{
"username": "testuser",
"text": "my msg",
"timestamp": "2024-01-01T12:00:00",
}
]
client.users = []
printed = []
with patch.object(client.console, "clear"):
with patch.object(
client.console, "print", side_effect=lambda x: printed.append(x)
):
client.render_messages()
msg_output = [p for p in printed if "my msg" in str(p)]
assert len(msg_output) == 1
assert "green" in str(msg_output[0])
def test_render_other_message_cyan(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.username = "testuser"
client.messages = [
{
"username": "other",
"text": "their msg",
"timestamp": "2024-01-01T12:00:00",
}
]
client.users = []
printed = []
with patch.object(client.console, "clear"):
with patch.object(
client.console, "print", side_effect=lambda x: printed.append(x)
):
client.render_messages()
msg_output = [p for p in printed if "their msg" in str(p)]
assert len(msg_output) == 1
assert "cyan" in str(msg_output[0])
def test_render_timestamp_formatting(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = [
{
"username": "user",
"text": "test",
"timestamp": "2024-01-15T14:30:45.123456",
}
]
client.users = []
printed = []
with patch.object(client.console, "clear"):
with patch.object(
client.console, "print", side_effect=lambda x: printed.append(x)
):
client.render_messages()
msg_output = [p for p in printed if "2024-01-15 14:30:45" in str(p)]
assert len(msg_output) == 1
def test_render_users_online_display(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = []
client.users = [
{"user_id": "1", "username": "alice"},
{"user_id": "2", "username": "bob"},
{"user_id": "3", "username": "charlie"},
]
printed = []
with patch.object(client.console, "clear"):
with patch.object(
client.console, "print", side_effect=lambda x: printed.append(x)
):
client.render_messages()
online_line = [p for p in printed if "Online:" in str(p)]
assert len(online_line) == 1
assert "alice" in str(online_line[0])
assert "bob" in str(online_line[0])
assert "charlie" in str(online_line[0])
def test_render_no_users_shows_none(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = []
client.users = []
printed = []
with patch.object(client.console, "clear"):
with patch.object(
client.console, "print", side_effect=lambda x: printed.append(x)
):
client.render_messages()
online_line = [p for p in printed if "Online:" in str(p)]
assert "none" in str(online_line[0])
def test_render_missing_username_shows_unknown(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = [{"text": "test", "timestamp": "2024-01-01T12:00:00"}]
client.users = []
printed = []
with patch.object(client.console, "clear"):
with patch.object(
client.console, "print", side_effect=lambda x: printed.append(x)
):
client.render_messages()
msg_output = [p for p in printed if "unknown" in str(p)]
assert len(msg_output) >= 1
def test_render_missing_timestamp(self, client):
client.room_fernet = Fernet(Fernet.generate_key())
client.messages = [{"username": "user", "text": "test"}]
client.users = []
with patch.object(client.console, "clear"):
with patch.object(client.console, "print"):
client.render_messages()
class TestE2EEncryptionFlow:
def test_same_password_same_key(self, room_salt):
password = b"shared_secret"
hkdf1 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key1 = base64.urlsafe_b64encode(hkdf1.derive(password))
hkdf2 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key2 = base64.urlsafe_b64encode(hkdf2.derive(password))
fernet1 = Fernet(key1)
fernet2 = Fernet(key2)
ciphertext = fernet1.encrypt(b"Hello from client 1")
plaintext = fernet2.decrypt(ciphertext)
assert plaintext == b"Hello from client 1"
def test_different_password_cannot_decrypt(self, room_salt):
hkdf1 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key1 = base64.urlsafe_b64encode(hkdf1.derive(b"correct_password"))
hkdf2 = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
key2 = base64.urlsafe_b64encode(hkdf2.derive(b"wrong_password"))
fernet1 = Fernet(key1)
fernet2 = Fernet(key2)
ciphertext = fernet1.encrypt(b"Secret message")
with pytest.raises(Exception):
fernet2.decrypt(ciphertext)
def test_server_cannot_read_without_password(self, room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
client_key = base64.urlsafe_b64encode(hkdf.derive(b"client_password"))
client_fernet = Fernet(client_key)
ciphertext = client_fernet.encrypt(b"Private message")
server_random_key = Fernet.generate_key()
server_fernet = Fernet(server_random_key)
with pytest.raises(Exception):
server_fernet.decrypt(ciphertext)
class TestEdgeCases:
def test_empty_username(self):
client = Client("localhost", 3000, "", "password")
assert client.username == ""
def test_very_long_message(self, client, room_fernet, room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(client.password)
client.room_fernet = Fernet(base64.urlsafe_b64encode(room_key))
long_message = "x" * 10000
encrypted = room_fernet.encrypt(long_message.encode()).decode()
msg = {"text": encrypted, "username": "other"}
decrypted = client.decrypt_message(msg)
assert decrypted["text"] == long_message
def test_unicode_message(self, client, room_fernet, room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(client.password)
client.room_fernet = Fernet(base64.urlsafe_b64encode(room_key))
unicode_msg = "Привет 世界 🎉 مرحبا"
encrypted = room_fernet.encrypt(unicode_msg.encode()).decode()
msg = {"text": encrypted, "username": "other"}
decrypted = client.decrypt_message(msg)
assert decrypted["text"] == unicode_msg
def test_special_characters_in_message(self, client, room_fernet, room_salt):
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=room_salt,
info=b"cmd-chat-room-key",
)
room_key = hkdf.derive(client.password)
client.room_fernet = Fernet(base64.urlsafe_b64encode(room_key))
special_msg = '<script>alert("xss")</script> & "quotes" \'single\' \n\t\r'
encrypted = room_fernet.encrypt(special_msg.encode()).decode()
msg = {"text": encrypted, "username": "other"}
decrypted = client.decrypt_message(msg)
assert decrypted["text"] == special_msg
def test_port_zero(self):
client = Client("localhost", 0, "user", "pass")
assert client.port == 0
assert client.base_url == "http://localhost:0"
def test_ipv6_server(self):
client = Client("::1", 3000, "user", "pass")
assert client.base_url == "http://::1:3000"