feat(ai): backfill context on join + token-budget window

The server already ships the full RAM message backlog in the init frame; the
agent was discarding it. _seed_transcript now decrypts that history with the
room key (skipping our own lines, control frames, and undecryptable blobs) so
the agent has context the moment it joins instead of starting amnesiac.

_window() replaces the fixed last-12 slice on both the answer and sandbox
paths: it walks newest-to-oldest and keeps messages up to --token-budget
(approx, ~4 chars/token), still capped at --context-window count. Keeps small
local models inside their effective context. Nothing touches disk.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
leetcrypt 2026-06-02 17:43:02 -07:00
parent bbb9e82425
commit 9b85255d80
2 changed files with 61 additions and 5 deletions

View File

@ -76,7 +76,10 @@ def main() -> None:
ap.add_argument("--model", default=None, help="model name (provider default if omitted)") ap.add_argument("--model", default=None, help="model name (provider default if omitted)")
ap.add_argument("--base-url", default=None, help="endpoint for openai-compatible providers") ap.add_argument("--base-url", default=None, help="endpoint for openai-compatible providers")
ap.add_argument("--system", default=None, help="override the system prompt") ap.add_argument("--system", default=None, help="override the system prompt")
ap.add_argument("--context-window", type=int, default=12) ap.add_argument("--context-window", type=int, default=12,
help="max prior messages fed to the model per reply")
ap.add_argument("--token-budget", type=int, default=3000,
help="approx token cap on the context window (whichever is smaller wins)")
ap.add_argument("--list-models", action="store_true", ap.add_argument("--list-models", action="store_true",
help="list models the backend can serve, then exit") help="list models the backend can serve, then exit")
ap.add_argument("--check", action="store_true", ap.add_argument("--check", action="store_true",
@ -113,6 +116,7 @@ def main() -> None:
args.server, args.port, name=args.name, provider=provider, args.server, args.port, name=args.name, provider=provider,
password=args.password, insecure=args.insecure, no_tls=args.no_tls, password=args.password, insecure=args.insecure, no_tls=args.no_tls,
system_prompt=args.system, context_window=args.context_window, system_prompt=args.system, context_window=args.context_window,
token_budget=args.token_budget,
) )
try: try:
bridge.run() bridge.run()

View File

@ -65,19 +65,67 @@ MAX_BYTES = 8192
class AgentBridge(Client): class AgentBridge(Client):
def __init__(self, server: str, port: int, name: str, provider: Provider, def __init__(self, server: str, port: int, name: str, provider: Provider,
password: str | None = None, insecure: bool = False, no_tls: bool = False, password: str | None = None, insecure: bool = False, no_tls: bool = False,
system_prompt: str | None = None, context_window: int = 12): system_prompt: str | None = None, context_window: int = 12,
token_budget: int = 3000):
super().__init__(server, port, username=name, password=password, super().__init__(server, port, username=name, password=password,
insecure=insecure, no_tls=no_tls) insecure=insecure, no_tls=no_tls)
self.name = name self.name = name
self.provider = provider self.provider = provider
self.system_prompt = (system_prompt or DEFAULT_SYSTEM).format(name=name) self.system_prompt = (system_prompt or DEFAULT_SYSTEM).format(name=name)
self.context_window = context_window self.context_window = context_window
# Soft cap (approx tokens) on how much transcript we feed the model per
# call. context_window stays a hard ceiling on message *count*; whichever
# is smaller wins. Keeps small local models inside their effective ctx.
self.token_budget = token_budget
self.transcript: list[Msg] = [] self.transcript: list[Msg] = []
# Sandbox-drive state, mirrored from the owner's `_perm:acl` broadcasts. # Sandbox-drive state, mirrored from the owner's `_perm:acl` broadcasts.
self.granted = False # may we type into the shared PTY? self.granted = False # may we type into the shared PTY?
self.can_sudo = False # does our VM account have sudo? self.can_sudo = False # does our VM account have sudo?
self._pending: list[str] | None = None # destructive plan awaiting /confirm self._pending: list[str] | None = None # destructive plan awaiting /confirm
@staticmethod
def _est_tokens(text: str) -> int:
"""Cheap token estimate (~4 chars/token) — good enough to budget a
window without pulling in a tokenizer dependency."""
return len(text) // 4 + 1
def _window(self) -> list[Msg]:
"""The slice of transcript to feed the model: the most recent messages
that fit the token budget, capped at context_window messages. Walk from
newest to oldest so we keep the freshest context when the budget is tight."""
out: list[Msg] = []
used = 0
for m in reversed(self.transcript[-self.context_window:]):
cost = self._est_tokens(m.content)
if out and used + cost > self.token_budget:
break
out.append(m)
used += cost
out.reverse()
return out
def _seed_transcript(self, messages: list[dict]) -> None:
"""Backfill conversational context from the server's RAM history (the
`init` frame). Messages arrive as encrypted {text, username} dicts the
same shape live messages use so we decrypt with the room key, skip our
own lines, control frames, and anything that won't decrypt. Pure RAM:
nothing is written to disk, and it's gone when the process exits."""
seeded = 0
for raw in messages:
sender = raw.get("username", "?")
if sender == self.name:
continue
dec = self.decrypt_message(dict(raw))
text = dec.get("text", "")
if not text or text == "[decrypt failed]" or text.startswith('{"_'):
continue
self.transcript.append(Msg("user", f"{sender}: {text}"))
seeded += 1
# Keep the same rolling bound the live path uses.
self.transcript = self.transcript[-(self.context_window * 2):]
if seeded:
self.info(f"backfilled {seeded} prior message(s) for context")
def _addressed_question(self, text: str) -> str | None: def _addressed_question(self, text: str) -> str | None:
"""Return the question if this ``/ai …`` line targets us, else None.""" """Return the question if this ``/ai …`` line targets us, else None."""
t = text.strip() t = text.strip()
@ -216,7 +264,7 @@ class AgentBridge(Client):
plan = await asyncio.to_thread( plan = await asyncio.to_thread(
self.provider.complete, self.provider.complete,
SANDBOX_SYSTEM.format(name=self.name), SANDBOX_SYSTEM.format(name=self.name),
self.transcript[-self.context_window:] self._window()
+ [Msg("user", f"{asker} wants this done in the shell: {task}")], + [Msg("user", f"{asker} wants this done in the shell: {task}")],
) )
except Exception as e: # noqa: BLE001 — surface provider failure in-room except Exception as e: # noqa: BLE001 — surface provider failure in-room
@ -259,7 +307,7 @@ class AgentBridge(Client):
reply = await asyncio.to_thread( reply = await asyncio.to_thread(
self.provider.complete, self.provider.complete,
self.system_prompt, self.system_prompt,
self.transcript[-self.context_window:], self._window(),
) )
except Exception as e: # noqa: BLE001 — surface any provider failure in-room except Exception as e: # noqa: BLE001 — surface any provider failure in-room
reply = f"[ai error: {e}]" reply = f"[ai error: {e}]"
@ -290,7 +338,11 @@ class AgentBridge(Client):
except json.JSONDecodeError: except json.JSONDecodeError:
continue continue
mtype = data.get("type") mtype = data.get("type")
if mtype in ("init", "roster"): if mtype == "init":
self.users = data.get("users", [])
self._seed_transcript(data.get("messages", []))
continue
if mtype == "roster":
self.users = data.get("users", []) self.users = data.get("users", [])
continue continue
if mtype != "message": if mtype != "message":