From e5e1ad8dee8fd070cf090d9fca90bd057bb1cd5d Mon Sep 17 00:00:00 2001 From: leetcrypt Date: Tue, 2 Jun 2026 17:59:01 -0700 Subject: [PATCH] feat(ai): in-RAM semantic recall (RAG) for conversation context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Give the agent recall of things said beyond the verbatim window, without breaking the RAM-only philosophy — nothing is persisted to disk. - MemoryIndex: a capped, in-memory pool of embedded messages with pure-Python cosine search (no numpy). Retains far more than the rolling transcript so old lines can be surfaced on demand; oldest evicted past the cap to bound RAM. - OllamaEmbedder: local embeddings via nomic-embed-text, on by default and independent of the chat provider (reuses the Ollama host when chat is Ollama). - Bridge: captured room messages (live + backfilled) are embedded on a background worker so a slow embedder can't stall frame draining. On a /ai question the agent retrieves top-k relevant lines, drops weak ( --- cmd_chat/agent/__main__.py | 22 ++++- cmd_chat/agent/bridge.py | 168 +++++++++++++++++++++++++++--------- cmd_chat/agent/memory.py | 58 +++++++++++++ cmd_chat/agent/providers.py | 23 +++++ docs/ai-context-plan.md | 16 ++-- 5 files changed, 236 insertions(+), 51 deletions(-) create mode 100644 cmd_chat/agent/memory.py diff --git a/cmd_chat/agent/__main__.py b/cmd_chat/agent/__main__.py index 61ba968..54138a3 100644 --- a/cmd_chat/agent/__main__.py +++ b/cmd_chat/agent/__main__.py @@ -32,7 +32,7 @@ import sys from .bridge import AgentBridge from .profiles import load_profiles, provider_from_profile -from .providers import make_provider, preflight +from .providers import OllamaEmbedder, make_provider, preflight def _build_provider(args, ap): @@ -80,6 +80,14 @@ def main() -> None: help="max prior messages fed to the model per reply") ap.add_argument("--token-budget", type=int, default=3000, help="approx token cap on the context window (whichever is smaller wins)") + ap.add_argument("--no-rag", action="store_true", + help="disable in-RAM semantic recall (recency-only context)") + ap.add_argument("--embed-model", default="nomic-embed-text", + help="Ollama model used to embed messages for recall") + ap.add_argument("--embed-host", default=None, + help="Ollama host for embeddings (default: chat host or $OLLAMA_HOST)") + ap.add_argument("--rag-top-k", type=int, default=4, + help="how many recalled messages to surface per reply") ap.add_argument("--list-models", action="store_true", help="list models the backend can serve, then exit") ap.add_argument("--check", action="store_true", @@ -112,11 +120,21 @@ def main() -> None: if not ok: print(f"⚠ preflight: {msg}", file=sys.stderr) + # In-RAM semantic recall is on by default and local (Ollama embeddings), + # independent of which provider answers chat. Reuse the chat host if it's an + # Ollama provider so a single --host/profile covers both. + embedder = None + if not args.no_rag: + embedder = OllamaEmbedder( + model=args.embed_model, + host=args.embed_host or getattr(provider, "host", None), + ) + bridge = AgentBridge( args.server, args.port, name=args.name, provider=provider, password=args.password, insecure=args.insecure, no_tls=args.no_tls, system_prompt=args.system, context_window=args.context_window, - token_budget=args.token_budget, + token_budget=args.token_budget, embedder=embedder, rag_top_k=args.rag_top_k, ) try: bridge.run() diff --git a/cmd_chat/agent/bridge.py b/cmd_chat/agent/bridge.py index 7cb9716..86ae1e1 100644 --- a/cmd_chat/agent/bridge.py +++ b/cmd_chat/agent/bridge.py @@ -18,6 +18,7 @@ import re import websockets from ..client.client import Client +from .memory import MemoryIndex from .providers import Msg, Provider DEFAULT_SYSTEM = ( @@ -66,7 +67,8 @@ class AgentBridge(Client): def __init__(self, server: str, port: int, name: str, provider: Provider, password: str | None = None, insecure: bool = False, no_tls: bool = False, system_prompt: str | None = None, context_window: int = 12, - token_budget: int = 3000): + token_budget: int = 3000, embedder=None, rag_top_k: int = 4, + rag_min_score: float = 0.35): super().__init__(server, port, username=name, password=password, insecure=insecure, no_tls=no_tls) self.name = name @@ -78,6 +80,15 @@ class AgentBridge(Client): # is smaller wins. Keeps small local models inside their effective ctx. self.token_budget = token_budget self.transcript: list[Msg] = [] + # In-RAM semantic recall (RAG). The long-term store keeps far more than + # the verbatim window so we can surface relevant old lines on demand. + # Disabled (embedder=None) → falls back to recency-only context. + self.embedder = embedder + self.memory = MemoryIndex() if embedder is not None else None + self.rag_top_k = rag_top_k + self.rag_min_score = rag_min_score # drop weak cosine matches as noise + self._embed_q: asyncio.Queue[Msg] = asyncio.Queue() + self._embed_warned = False # log embedder failure once, then stay quiet # Sandbox-drive state, mirrored from the owner's `_perm:acl` broadcasts. self.granted = False # may we type into the shared PTY? self.can_sudo = False # does our VM account have sudo? @@ -119,13 +130,69 @@ class AgentBridge(Client): text = dec.get("text", "") if not text or text == "[decrypt failed]" or text.startswith('{"_'): continue - self.transcript.append(Msg("user", f"{sender}: {text}")) + msg = Msg("user", f"{sender}: {text}") + self.transcript.append(msg) + self._remember(msg) seeded += 1 # Keep the same rolling bound the live path uses. self.transcript = self.transcript[-(self.context_window * 2):] if seeded: self.info(f"backfilled {seeded} prior message(s) for context") + # ── In-RAM semantic recall (RAG) ───────────────────────────────────── + def _remember(self, msg: Msg) -> None: + """Queue a message for background embedding into the memory index. A + no-op when RAG is disabled. Embedding happens off the recv loop so a + slow embedder can never stall frame draining.""" + if self.memory is not None: + self._embed_q.put_nowait(msg) + + async def _embed_worker(self) -> None: + """Drain the embed queue, embedding each message and storing the vector. + Eventually-consistent on purpose: a question may arrive before the most + recent line is indexed — that line is still in the verbatim window, so + nothing is lost. If the embedder is unreachable we say so once and keep + accepting work (it may recover).""" + while self.running: + msg = await self._embed_q.get() + try: + vec = await asyncio.to_thread(self.embedder.embed, msg.content) + self.memory.add(msg, vec) + except Exception as e: # noqa: BLE001 — degrade to recency-only recall + if not self._embed_warned: + self.info(f"semantic recall unavailable (embedder: {e})") + self._embed_warned = True + finally: + self._embed_q.task_done() + + async def _retrieve(self, query: str, exclude: set[str]) -> list[Msg]: + """Top-k past messages semantically relevant to ``query``, minus weak + matches and anything already in the recent verbatim window (``exclude``).""" + if self.memory is None or len(self.memory) == 0: + return [] + try: + qvec = await asyncio.to_thread(self.embedder.embed, query) + except Exception: # noqa: BLE001 — embedder down → just skip recall + return [] + hits = self.memory.search(qvec, self.rag_top_k) + return [ + m for score, m in hits + if score >= self.rag_min_score and m.content not in exclude + ] + + async def _model_messages(self, query: str) -> list[Msg]: + """Assemble the message list for a model call: a single recalled-context + preamble (if RAG surfaced anything) followed by the recent verbatim + window. Recalled lines are clearly fenced as stale, untrusted context — + never elevated to system role — so they inform without instructing.""" + window = self._window() + retrieved = await self._retrieve(query, exclude={m.content for m in window}) + if not retrieved: + return window + recall = "Relevant earlier messages (recalled context, may be stale):\n" + \ + "\n".join(f"- {m.content}" for m in retrieved) + return [Msg("user", recall)] + window + def _addressed_question(self, text: str) -> str | None: """Return the question if this ``/ai …`` line targets us, else None.""" t = text.strip() @@ -261,11 +328,11 @@ class AgentBridge(Client): return await self._send_typing(ws, True) try: + context = await self._model_messages(task) plan = await asyncio.to_thread( self.provider.complete, SANDBOX_SYSTEM.format(name=self.name), - self._window() - + [Msg("user", f"{asker} wants this done in the shell: {task}")], + context + [Msg("user", f"{asker} wants this done in the shell: {task}")], ) except Exception as e: # noqa: BLE001 — surface provider failure in-room await self._send_typing(ws, False) @@ -304,10 +371,11 @@ class AgentBridge(Client): self.transcript.append(Msg("user", f"{asker}: {question}")) await self._send_typing(ws, True) try: + context = await self._model_messages(question) reply = await asyncio.to_thread( self.provider.complete, self.system_prompt, - self._window(), + context, ) except Exception as e: # noqa: BLE001 — surface any provider failure in-room reply = f"[ai error: {e}]" @@ -330,41 +398,55 @@ class AgentBridge(Client): ) await ws.send(self.room_fernet.encrypt(announce.encode()).decode()) self.success("agent online") - async for raw in ws: - if not self.running: - break - try: - data = json.loads(raw) - except json.JSONDecodeError: - continue - mtype = data.get("type") - if mtype == "init": - self.users = data.get("users", []) - self._seed_transcript(data.get("messages", [])) - continue - if mtype == "roster": - self.users = data.get("users", []) - continue - if mtype != "message": - continue - msg = self.decrypt_message(data.get("data", {})) - text = msg.get("text", "") - sender = msg.get("username", "?") - if sender == self.name: - continue # never react to our own messages - if text.startswith('{"_'): - self._handle_control(text) # track ACL grants; ignore other ctrl frames - continue - question = self._addressed_question(text) - if question is None: - # keep a short rolling transcript for context on future asks - self.transcript.append(Msg("user", f"{sender}: {text}")) - self.transcript = self.transcript[-(self.context_window * 2):] - elif question.startswith("!"): - self.info(f"{sender} → /ai !sbx: {question[1:].strip()}") - await self._run_in_sandbox(ws, question[1:].strip(), sender) - elif question.strip().lower() == "confirm": - await self._confirm_pending(ws, sender) - else: - self.info(f"{sender} → /ai: {question}") - await self._answer(ws, question, sender) + embed_task = ( + asyncio.create_task(self._embed_worker()) + if self.memory is not None else None + ) + try: + await self._serve(ws) + finally: + if embed_task is not None: + embed_task.cancel() + + async def _serve(self, ws) -> None: + async for raw in ws: + if not self.running: + break + try: + data = json.loads(raw) + except json.JSONDecodeError: + continue + mtype = data.get("type") + if mtype == "init": + self.users = data.get("users", []) + self._seed_transcript(data.get("messages", [])) + continue + if mtype == "roster": + self.users = data.get("users", []) + continue + if mtype != "message": + continue + msg = self.decrypt_message(data.get("data", {})) + text = msg.get("text", "") + sender = msg.get("username", "?") + if sender == self.name: + continue # never react to our own messages + if text.startswith('{"_'): + self._handle_control(text) # track ACL grants; ignore other ctrl frames + continue + question = self._addressed_question(text) + if question is None: + # keep a short rolling transcript for context on future asks, + # and feed the line to long-term semantic memory + captured = Msg("user", f"{sender}: {text}") + self.transcript.append(captured) + self.transcript = self.transcript[-(self.context_window * 2):] + self._remember(captured) + elif question.startswith("!"): + self.info(f"{sender} → /ai !sbx: {question[1:].strip()}") + await self._run_in_sandbox(ws, question[1:].strip(), sender) + elif question.strip().lower() == "confirm": + await self._confirm_pending(ws, sender) + else: + self.info(f"{sender} → /ai: {question}") + await self._answer(ws, question, sender) diff --git a/cmd_chat/agent/memory.py b/cmd_chat/agent/memory.py new file mode 100644 index 0000000..c6a8cbb --- /dev/null +++ b/cmd_chat/agent/memory.py @@ -0,0 +1,58 @@ +"""In-RAM semantic memory for the hack-house AI agent. + +Holds embedded past messages in process memory only — no disk, no DB. The +store is bounded and dies with the agent, exactly like the room's own history +and the rolling transcript. Cosine similarity is computed in pure Python (the +vectors are small and the store is capped), so there's no numpy dependency. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass + +from .providers import Msg + + +@dataclass +class _Entry: + msg: Msg + vec: list[float] + norm: float # precomputed ||vec|| so search is a dot product + divide + + +class MemoryIndex: + """A capped, in-memory pool of embedded messages for semantic recall. + + This is the *long-term* store — it deliberately retains far more than the + verbatim transcript window, so the agent can recall something said long + before the recent slice. Oldest entries are evicted past ``max_entries`` to + bound RAM (≈3 MB at 500 × 768-float vectors). + """ + + def __init__(self, max_entries: int = 500): + self.max_entries = max_entries + self._entries: list[_Entry] = [] + + def __len__(self) -> int: + return len(self._entries) + + def add(self, msg: Msg, vec: list[float]) -> None: + norm = math.sqrt(sum(x * x for x in vec)) if vec else 0.0 + if norm == 0.0: + return # empty / failed embedding — skip rather than poison search + self._entries.append(_Entry(msg, vec, norm)) + if len(self._entries) > self.max_entries: + self._entries = self._entries[-self.max_entries:] + + def search(self, qvec: list[float], k: int) -> list[tuple[float, Msg]]: + """Top-``k`` entries by cosine similarity, highest first.""" + qnorm = math.sqrt(sum(x * x for x in qvec)) if qvec else 0.0 + if qnorm == 0.0 or not self._entries: + return [] + scored = [ + (sum(a * b for a, b in zip(qvec, e.vec)) / (qnorm * e.norm), e.msg) + for e in self._entries + ] + scored.sort(key=lambda t: t[0], reverse=True) + return scored[:k] diff --git a/cmd_chat/agent/providers.py b/cmd_chat/agent/providers.py index 7bbedfa..3a1c2b6 100644 --- a/cmd_chat/agent/providers.py +++ b/cmd_chat/agent/providers.py @@ -72,6 +72,29 @@ class OllamaProvider: return [m.get("name", "") for m in r.json().get("models", [])] +class OllamaEmbedder: + """Local text embeddings via Ollama (default ``nomic-embed-text``), used for + the agent's in-RAM semantic recall. Local + free, so it stays on by default + regardless of which provider answers chat. No key, nothing persisted.""" + + name = "ollama-embed" + + def __init__(self, model: str = "nomic-embed-text", host: str | None = None, + timeout: int = 60): + self.model = model + self.host = (host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")).rstrip("/") + self.timeout = timeout + + def embed(self, text: str) -> list[float]: + r = requests.post( + f"{self.host}/api/embeddings", + json={"model": self.model, "prompt": text}, + timeout=self.timeout, + ) + r.raise_for_status() + return r.json().get("embedding") or [] + + class AnthropicProvider: """Anthropic Messages API. Cloud — opt-in. Needs ANTHROPIC_API_KEY.""" diff --git a/docs/ai-context-plan.md b/docs/ai-context-plan.md index 8a5b3f2..1d5a3c5 100644 --- a/docs/ai-context-plan.md +++ b/docs/ai-context-plan.md @@ -43,12 +43,16 @@ That is a RAM-only history the agent can backfill from on join at zero new cost. 8. **Tune Ollama `options`** — explicit `num_ctx` (so the larger window in #1/#2 is actually honored) and bounded `num_predict`. *(implementing)* -### Tier 2 — deeper context (next branch) -3. **In-RAM semantic retrieval (RAG, no disk).** Embed each captured message - with the already-present `nomic-embed-text`, hold vectors in a numpy array in - memory; on a `/ai` question retrieve top-k by cosine and prepend to the - recency window. Fully ephemeral. -4. **In-RAM hierarchical compaction.** When over budget, summarize the oldest +### Tier 2 — deeper context +3. **In-RAM semantic retrieval (RAG, no disk).** *(done)* Each captured message + is embedded with the already-present `nomic-embed-text` and held in a capped + in-memory `MemoryIndex` (pure-Python cosine, no numpy). On a `/ai` question + the agent embeds the query, retrieves top-k, drops weak/duplicate hits, and + prepends them as a clearly-fenced "recalled context" preamble (never system + role — keeps untrusted text from instructing). Embedding runs on a background + worker so it can't stall the recv loop; if the embedder is unreachable it + degrades to recency-only. Toggle with `--no-rag` / `--rag-top-k`. +4. **In-RAM hierarchical compaction.** *(staged)* When over budget, summarize the oldest chunk into a single rolling `Msg("system", "earlier: …")` instead of dropping it — the Claude Code auto-compaction pattern, kept in RAM.