feat(ai): in-RAM semantic recall (RAG) for conversation context
Give the agent recall of things said beyond the verbatim window, without breaking the RAM-only philosophy — nothing is persisted to disk. - MemoryIndex: a capped, in-memory pool of embedded messages with pure-Python cosine search (no numpy). Retains far more than the rolling transcript so old lines can be surfaced on demand; oldest evicted past the cap to bound RAM. - OllamaEmbedder: local embeddings via nomic-embed-text, on by default and independent of the chat provider (reuses the Ollama host when chat is Ollama). - Bridge: captured room messages (live + backfilled) are embedded on a background worker so a slow embedder can't stall frame draining. On a /ai question the agent retrieves top-k relevant lines, drops weak (<min_score) and windowed-duplicate hits, and prepends them as a clearly-fenced "recalled context" preamble — kept at user role, never elevated to system, so untrusted room text informs without instructing. Falls back to recency-only if the embedder is unreachable. - CLI: --no-rag, --embed-model, --embed-host, --rag-top-k. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
85fde59292
commit
e5e1ad8dee
|
|
@ -32,7 +32,7 @@ import sys
|
|||
|
||||
from .bridge import AgentBridge
|
||||
from .profiles import load_profiles, provider_from_profile
|
||||
from .providers import make_provider, preflight
|
||||
from .providers import OllamaEmbedder, make_provider, preflight
|
||||
|
||||
|
||||
def _build_provider(args, ap):
|
||||
|
|
@ -80,6 +80,14 @@ def main() -> None:
|
|||
help="max prior messages fed to the model per reply")
|
||||
ap.add_argument("--token-budget", type=int, default=3000,
|
||||
help="approx token cap on the context window (whichever is smaller wins)")
|
||||
ap.add_argument("--no-rag", action="store_true",
|
||||
help="disable in-RAM semantic recall (recency-only context)")
|
||||
ap.add_argument("--embed-model", default="nomic-embed-text",
|
||||
help="Ollama model used to embed messages for recall")
|
||||
ap.add_argument("--embed-host", default=None,
|
||||
help="Ollama host for embeddings (default: chat host or $OLLAMA_HOST)")
|
||||
ap.add_argument("--rag-top-k", type=int, default=4,
|
||||
help="how many recalled messages to surface per reply")
|
||||
ap.add_argument("--list-models", action="store_true",
|
||||
help="list models the backend can serve, then exit")
|
||||
ap.add_argument("--check", action="store_true",
|
||||
|
|
@ -112,11 +120,21 @@ def main() -> None:
|
|||
if not ok:
|
||||
print(f"⚠ preflight: {msg}", file=sys.stderr)
|
||||
|
||||
# In-RAM semantic recall is on by default and local (Ollama embeddings),
|
||||
# independent of which provider answers chat. Reuse the chat host if it's an
|
||||
# Ollama provider so a single --host/profile covers both.
|
||||
embedder = None
|
||||
if not args.no_rag:
|
||||
embedder = OllamaEmbedder(
|
||||
model=args.embed_model,
|
||||
host=args.embed_host or getattr(provider, "host", None),
|
||||
)
|
||||
|
||||
bridge = AgentBridge(
|
||||
args.server, args.port, name=args.name, provider=provider,
|
||||
password=args.password, insecure=args.insecure, no_tls=args.no_tls,
|
||||
system_prompt=args.system, context_window=args.context_window,
|
||||
token_budget=args.token_budget,
|
||||
token_budget=args.token_budget, embedder=embedder, rag_top_k=args.rag_top_k,
|
||||
)
|
||||
try:
|
||||
bridge.run()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import re
|
|||
import websockets
|
||||
|
||||
from ..client.client import Client
|
||||
from .memory import MemoryIndex
|
||||
from .providers import Msg, Provider
|
||||
|
||||
DEFAULT_SYSTEM = (
|
||||
|
|
@ -66,7 +67,8 @@ class AgentBridge(Client):
|
|||
def __init__(self, server: str, port: int, name: str, provider: Provider,
|
||||
password: str | None = None, insecure: bool = False, no_tls: bool = False,
|
||||
system_prompt: str | None = None, context_window: int = 12,
|
||||
token_budget: int = 3000):
|
||||
token_budget: int = 3000, embedder=None, rag_top_k: int = 4,
|
||||
rag_min_score: float = 0.35):
|
||||
super().__init__(server, port, username=name, password=password,
|
||||
insecure=insecure, no_tls=no_tls)
|
||||
self.name = name
|
||||
|
|
@ -78,6 +80,15 @@ class AgentBridge(Client):
|
|||
# is smaller wins. Keeps small local models inside their effective ctx.
|
||||
self.token_budget = token_budget
|
||||
self.transcript: list[Msg] = []
|
||||
# In-RAM semantic recall (RAG). The long-term store keeps far more than
|
||||
# the verbatim window so we can surface relevant old lines on demand.
|
||||
# Disabled (embedder=None) → falls back to recency-only context.
|
||||
self.embedder = embedder
|
||||
self.memory = MemoryIndex() if embedder is not None else None
|
||||
self.rag_top_k = rag_top_k
|
||||
self.rag_min_score = rag_min_score # drop weak cosine matches as noise
|
||||
self._embed_q: asyncio.Queue[Msg] = asyncio.Queue()
|
||||
self._embed_warned = False # log embedder failure once, then stay quiet
|
||||
# Sandbox-drive state, mirrored from the owner's `_perm:acl` broadcasts.
|
||||
self.granted = False # may we type into the shared PTY?
|
||||
self.can_sudo = False # does our VM account have sudo?
|
||||
|
|
@ -119,13 +130,69 @@ class AgentBridge(Client):
|
|||
text = dec.get("text", "")
|
||||
if not text or text == "[decrypt failed]" or text.startswith('{"_'):
|
||||
continue
|
||||
self.transcript.append(Msg("user", f"{sender}: {text}"))
|
||||
msg = Msg("user", f"{sender}: {text}")
|
||||
self.transcript.append(msg)
|
||||
self._remember(msg)
|
||||
seeded += 1
|
||||
# Keep the same rolling bound the live path uses.
|
||||
self.transcript = self.transcript[-(self.context_window * 2):]
|
||||
if seeded:
|
||||
self.info(f"backfilled {seeded} prior message(s) for context")
|
||||
|
||||
# ── In-RAM semantic recall (RAG) ─────────────────────────────────────
|
||||
def _remember(self, msg: Msg) -> None:
|
||||
"""Queue a message for background embedding into the memory index. A
|
||||
no-op when RAG is disabled. Embedding happens off the recv loop so a
|
||||
slow embedder can never stall frame draining."""
|
||||
if self.memory is not None:
|
||||
self._embed_q.put_nowait(msg)
|
||||
|
||||
async def _embed_worker(self) -> None:
|
||||
"""Drain the embed queue, embedding each message and storing the vector.
|
||||
Eventually-consistent on purpose: a question may arrive before the most
|
||||
recent line is indexed — that line is still in the verbatim window, so
|
||||
nothing is lost. If the embedder is unreachable we say so once and keep
|
||||
accepting work (it may recover)."""
|
||||
while self.running:
|
||||
msg = await self._embed_q.get()
|
||||
try:
|
||||
vec = await asyncio.to_thread(self.embedder.embed, msg.content)
|
||||
self.memory.add(msg, vec)
|
||||
except Exception as e: # noqa: BLE001 — degrade to recency-only recall
|
||||
if not self._embed_warned:
|
||||
self.info(f"semantic recall unavailable (embedder: {e})")
|
||||
self._embed_warned = True
|
||||
finally:
|
||||
self._embed_q.task_done()
|
||||
|
||||
async def _retrieve(self, query: str, exclude: set[str]) -> list[Msg]:
|
||||
"""Top-k past messages semantically relevant to ``query``, minus weak
|
||||
matches and anything already in the recent verbatim window (``exclude``)."""
|
||||
if self.memory is None or len(self.memory) == 0:
|
||||
return []
|
||||
try:
|
||||
qvec = await asyncio.to_thread(self.embedder.embed, query)
|
||||
except Exception: # noqa: BLE001 — embedder down → just skip recall
|
||||
return []
|
||||
hits = self.memory.search(qvec, self.rag_top_k)
|
||||
return [
|
||||
m for score, m in hits
|
||||
if score >= self.rag_min_score and m.content not in exclude
|
||||
]
|
||||
|
||||
async def _model_messages(self, query: str) -> list[Msg]:
|
||||
"""Assemble the message list for a model call: a single recalled-context
|
||||
preamble (if RAG surfaced anything) followed by the recent verbatim
|
||||
window. Recalled lines are clearly fenced as stale, untrusted context —
|
||||
never elevated to system role — so they inform without instructing."""
|
||||
window = self._window()
|
||||
retrieved = await self._retrieve(query, exclude={m.content for m in window})
|
||||
if not retrieved:
|
||||
return window
|
||||
recall = "Relevant earlier messages (recalled context, may be stale):\n" + \
|
||||
"\n".join(f"- {m.content}" for m in retrieved)
|
||||
return [Msg("user", recall)] + window
|
||||
|
||||
def _addressed_question(self, text: str) -> str | None:
|
||||
"""Return the question if this ``/ai …`` line targets us, else None."""
|
||||
t = text.strip()
|
||||
|
|
@ -261,11 +328,11 @@ class AgentBridge(Client):
|
|||
return
|
||||
await self._send_typing(ws, True)
|
||||
try:
|
||||
context = await self._model_messages(task)
|
||||
plan = await asyncio.to_thread(
|
||||
self.provider.complete,
|
||||
SANDBOX_SYSTEM.format(name=self.name),
|
||||
self._window()
|
||||
+ [Msg("user", f"{asker} wants this done in the shell: {task}")],
|
||||
context + [Msg("user", f"{asker} wants this done in the shell: {task}")],
|
||||
)
|
||||
except Exception as e: # noqa: BLE001 — surface provider failure in-room
|
||||
await self._send_typing(ws, False)
|
||||
|
|
@ -304,10 +371,11 @@ class AgentBridge(Client):
|
|||
self.transcript.append(Msg("user", f"{asker}: {question}"))
|
||||
await self._send_typing(ws, True)
|
||||
try:
|
||||
context = await self._model_messages(question)
|
||||
reply = await asyncio.to_thread(
|
||||
self.provider.complete,
|
||||
self.system_prompt,
|
||||
self._window(),
|
||||
context,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001 — surface any provider failure in-room
|
||||
reply = f"[ai error: {e}]"
|
||||
|
|
@ -330,41 +398,55 @@ class AgentBridge(Client):
|
|||
)
|
||||
await ws.send(self.room_fernet.encrypt(announce.encode()).decode())
|
||||
self.success("agent online")
|
||||
async for raw in ws:
|
||||
if not self.running:
|
||||
break
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
mtype = data.get("type")
|
||||
if mtype == "init":
|
||||
self.users = data.get("users", [])
|
||||
self._seed_transcript(data.get("messages", []))
|
||||
continue
|
||||
if mtype == "roster":
|
||||
self.users = data.get("users", [])
|
||||
continue
|
||||
if mtype != "message":
|
||||
continue
|
||||
msg = self.decrypt_message(data.get("data", {}))
|
||||
text = msg.get("text", "")
|
||||
sender = msg.get("username", "?")
|
||||
if sender == self.name:
|
||||
continue # never react to our own messages
|
||||
if text.startswith('{"_'):
|
||||
self._handle_control(text) # track ACL grants; ignore other ctrl frames
|
||||
continue
|
||||
question = self._addressed_question(text)
|
||||
if question is None:
|
||||
# keep a short rolling transcript for context on future asks
|
||||
self.transcript.append(Msg("user", f"{sender}: {text}"))
|
||||
self.transcript = self.transcript[-(self.context_window * 2):]
|
||||
elif question.startswith("!"):
|
||||
self.info(f"{sender} → /ai !sbx: {question[1:].strip()}")
|
||||
await self._run_in_sandbox(ws, question[1:].strip(), sender)
|
||||
elif question.strip().lower() == "confirm":
|
||||
await self._confirm_pending(ws, sender)
|
||||
else:
|
||||
self.info(f"{sender} → /ai: {question}")
|
||||
await self._answer(ws, question, sender)
|
||||
embed_task = (
|
||||
asyncio.create_task(self._embed_worker())
|
||||
if self.memory is not None else None
|
||||
)
|
||||
try:
|
||||
await self._serve(ws)
|
||||
finally:
|
||||
if embed_task is not None:
|
||||
embed_task.cancel()
|
||||
|
||||
async def _serve(self, ws) -> None:
|
||||
async for raw in ws:
|
||||
if not self.running:
|
||||
break
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
mtype = data.get("type")
|
||||
if mtype == "init":
|
||||
self.users = data.get("users", [])
|
||||
self._seed_transcript(data.get("messages", []))
|
||||
continue
|
||||
if mtype == "roster":
|
||||
self.users = data.get("users", [])
|
||||
continue
|
||||
if mtype != "message":
|
||||
continue
|
||||
msg = self.decrypt_message(data.get("data", {}))
|
||||
text = msg.get("text", "")
|
||||
sender = msg.get("username", "?")
|
||||
if sender == self.name:
|
||||
continue # never react to our own messages
|
||||
if text.startswith('{"_'):
|
||||
self._handle_control(text) # track ACL grants; ignore other ctrl frames
|
||||
continue
|
||||
question = self._addressed_question(text)
|
||||
if question is None:
|
||||
# keep a short rolling transcript for context on future asks,
|
||||
# and feed the line to long-term semantic memory
|
||||
captured = Msg("user", f"{sender}: {text}")
|
||||
self.transcript.append(captured)
|
||||
self.transcript = self.transcript[-(self.context_window * 2):]
|
||||
self._remember(captured)
|
||||
elif question.startswith("!"):
|
||||
self.info(f"{sender} → /ai !sbx: {question[1:].strip()}")
|
||||
await self._run_in_sandbox(ws, question[1:].strip(), sender)
|
||||
elif question.strip().lower() == "confirm":
|
||||
await self._confirm_pending(ws, sender)
|
||||
else:
|
||||
self.info(f"{sender} → /ai: {question}")
|
||||
await self._answer(ws, question, sender)
|
||||
|
|
|
|||
58
cmd_chat/agent/memory.py
Normal file
58
cmd_chat/agent/memory.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""In-RAM semantic memory for the hack-house AI agent.
|
||||
|
||||
Holds embedded past messages in process memory only — no disk, no DB. The
|
||||
store is bounded and dies with the agent, exactly like the room's own history
|
||||
and the rolling transcript. Cosine similarity is computed in pure Python (the
|
||||
vectors are small and the store is capped), so there's no numpy dependency.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .providers import Msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Entry:
|
||||
msg: Msg
|
||||
vec: list[float]
|
||||
norm: float # precomputed ||vec|| so search is a dot product + divide
|
||||
|
||||
|
||||
class MemoryIndex:
|
||||
"""A capped, in-memory pool of embedded messages for semantic recall.
|
||||
|
||||
This is the *long-term* store — it deliberately retains far more than the
|
||||
verbatim transcript window, so the agent can recall something said long
|
||||
before the recent slice. Oldest entries are evicted past ``max_entries`` to
|
||||
bound RAM (≈3 MB at 500 × 768-float vectors).
|
||||
"""
|
||||
|
||||
def __init__(self, max_entries: int = 500):
|
||||
self.max_entries = max_entries
|
||||
self._entries: list[_Entry] = []
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._entries)
|
||||
|
||||
def add(self, msg: Msg, vec: list[float]) -> None:
|
||||
norm = math.sqrt(sum(x * x for x in vec)) if vec else 0.0
|
||||
if norm == 0.0:
|
||||
return # empty / failed embedding — skip rather than poison search
|
||||
self._entries.append(_Entry(msg, vec, norm))
|
||||
if len(self._entries) > self.max_entries:
|
||||
self._entries = self._entries[-self.max_entries:]
|
||||
|
||||
def search(self, qvec: list[float], k: int) -> list[tuple[float, Msg]]:
|
||||
"""Top-``k`` entries by cosine similarity, highest first."""
|
||||
qnorm = math.sqrt(sum(x * x for x in qvec)) if qvec else 0.0
|
||||
if qnorm == 0.0 or not self._entries:
|
||||
return []
|
||||
scored = [
|
||||
(sum(a * b for a, b in zip(qvec, e.vec)) / (qnorm * e.norm), e.msg)
|
||||
for e in self._entries
|
||||
]
|
||||
scored.sort(key=lambda t: t[0], reverse=True)
|
||||
return scored[:k]
|
||||
|
|
@ -72,6 +72,29 @@ class OllamaProvider:
|
|||
return [m.get("name", "") for m in r.json().get("models", [])]
|
||||
|
||||
|
||||
class OllamaEmbedder:
|
||||
"""Local text embeddings via Ollama (default ``nomic-embed-text``), used for
|
||||
the agent's in-RAM semantic recall. Local + free, so it stays on by default
|
||||
regardless of which provider answers chat. No key, nothing persisted."""
|
||||
|
||||
name = "ollama-embed"
|
||||
|
||||
def __init__(self, model: str = "nomic-embed-text", host: str | None = None,
|
||||
timeout: int = 60):
|
||||
self.model = model
|
||||
self.host = (host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")).rstrip("/")
|
||||
self.timeout = timeout
|
||||
|
||||
def embed(self, text: str) -> list[float]:
|
||||
r = requests.post(
|
||||
f"{self.host}/api/embeddings",
|
||||
json={"model": self.model, "prompt": text},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json().get("embedding") or []
|
||||
|
||||
|
||||
class AnthropicProvider:
|
||||
"""Anthropic Messages API. Cloud — opt-in. Needs ANTHROPIC_API_KEY."""
|
||||
|
||||
|
|
|
|||
|
|
@ -43,12 +43,16 @@ That is a RAM-only history the agent can backfill from on join at zero new cost.
|
|||
8. **Tune Ollama `options`** — explicit `num_ctx` (so the larger window in #1/#2
|
||||
is actually honored) and bounded `num_predict`. *(implementing)*
|
||||
|
||||
### Tier 2 — deeper context (next branch)
|
||||
3. **In-RAM semantic retrieval (RAG, no disk).** Embed each captured message
|
||||
with the already-present `nomic-embed-text`, hold vectors in a numpy array in
|
||||
memory; on a `/ai` question retrieve top-k by cosine and prepend to the
|
||||
recency window. Fully ephemeral.
|
||||
4. **In-RAM hierarchical compaction.** When over budget, summarize the oldest
|
||||
### Tier 2 — deeper context
|
||||
3. **In-RAM semantic retrieval (RAG, no disk).** *(done)* Each captured message
|
||||
is embedded with the already-present `nomic-embed-text` and held in a capped
|
||||
in-memory `MemoryIndex` (pure-Python cosine, no numpy). On a `/ai` question
|
||||
the agent embeds the query, retrieves top-k, drops weak/duplicate hits, and
|
||||
prepends them as a clearly-fenced "recalled context" preamble (never system
|
||||
role — keeps untrusted text from instructing). Embedding runs on a background
|
||||
worker so it can't stall the recv loop; if the embedder is unreachable it
|
||||
degrades to recency-only. Toggle with `--no-rag` / `--rag-top-k`.
|
||||
4. **In-RAM hierarchical compaction.** *(staged)* When over budget, summarize the oldest
|
||||
chunk into a single rolling `Msg("system", "earlier: …")` instead of dropping
|
||||
it — the Claude Code auto-compaction pattern, kept in RAM.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user