diff --git a/cmd_chat/agent/__main__.py b/cmd_chat/agent/__main__.py index 1b9ca2e..61ba968 100644 --- a/cmd_chat/agent/__main__.py +++ b/cmd_chat/agent/__main__.py @@ -76,7 +76,10 @@ def main() -> None: ap.add_argument("--model", default=None, help="model name (provider default if omitted)") ap.add_argument("--base-url", default=None, help="endpoint for openai-compatible providers") ap.add_argument("--system", default=None, help="override the system prompt") - ap.add_argument("--context-window", type=int, default=12) + ap.add_argument("--context-window", type=int, default=12, + help="max prior messages fed to the model per reply") + ap.add_argument("--token-budget", type=int, default=3000, + help="approx token cap on the context window (whichever is smaller wins)") ap.add_argument("--list-models", action="store_true", help="list models the backend can serve, then exit") ap.add_argument("--check", action="store_true", @@ -113,6 +116,7 @@ def main() -> None: args.server, args.port, name=args.name, provider=provider, password=args.password, insecure=args.insecure, no_tls=args.no_tls, system_prompt=args.system, context_window=args.context_window, + token_budget=args.token_budget, ) try: bridge.run() diff --git a/cmd_chat/agent/bridge.py b/cmd_chat/agent/bridge.py index 25d3798..7cb9716 100644 --- a/cmd_chat/agent/bridge.py +++ b/cmd_chat/agent/bridge.py @@ -65,19 +65,67 @@ MAX_BYTES = 8192 class AgentBridge(Client): def __init__(self, server: str, port: int, name: str, provider: Provider, password: str | None = None, insecure: bool = False, no_tls: bool = False, - system_prompt: str | None = None, context_window: int = 12): + system_prompt: str | None = None, context_window: int = 12, + token_budget: int = 3000): super().__init__(server, port, username=name, password=password, insecure=insecure, no_tls=no_tls) self.name = name self.provider = provider self.system_prompt = (system_prompt or DEFAULT_SYSTEM).format(name=name) self.context_window = context_window + # Soft 
# AgentBridge methods: context-window selection and history backfill.

@staticmethod
def _est_tokens(text: str) -> int:
    """Return a cheap token estimate for *text* (~4 characters per token).

    The result is always at least 1, so even an empty message has a cost.
    This is a heuristic used to budget the window without pulling in a
    tokenizer dependency; it is not an exact count.
    """
    return len(text) // 4 + 1


def _window(self) -> "list[Msg]":
    """Return the slice of the transcript to feed the model.

    Selection rules:
      * at most ``self.context_window`` messages are considered (hard cap on
        message count); a non-positive ``context_window`` yields an empty
        window;
      * walking newest -> oldest, messages are kept until adding the next one
        would exceed ``self.token_budget`` (estimated via ``_est_tokens``);
      * the newest message is always kept, even when it alone exceeds the
        budget, so a tight budget never produces an empty window.

    The returned list is in chronological order (oldest first).
    """
    limit = self.context_window
    if limit <= 0:
        # Guard the slice below: ``xs[-0:]`` is the *whole* list, and a
        # negative limit would drop the oldest items instead of keeping the
        # newest ones. Neither honours "cap at context_window messages".
        return []

    picked: list[Msg] = []
    used = 0
    for m in reversed(self.transcript[-limit:]):
        cost = self._est_tokens(m.content)
        # ``picked`` being empty means this is the newest message: keep it
        # unconditionally so the model always sees the latest line.
        if picked and used + cost > self.token_budget:
            break
        picked.append(m)
        used += cost
    picked.reverse()  # restore chronological order
    return picked


def _seed_transcript(self, messages: list[dict]) -> None:
    """Backfill conversational context from the server's ``init`` history.

    Each entry is expected to be an encrypted ``{text, username}`` dict (the
    same shape live messages use). Every entry is decrypted with
    ``self.decrypt_message`` and appended to ``self.transcript`` as a
    ``user`` message, except that the following are skipped:

      * entries that are not dicts (malformed history must not crash us);
      * our own lines (``username == self.name``);
      * entries whose decrypted text is empty, not a string, equal to
        ``"[decrypt failed]"``, or a control frame (text starting ``{"_``).

    Afterwards the transcript is trimmed to the newest
    ``2 * context_window`` messages. This only mutates ``self.transcript``
    in memory; the method itself writes nothing to disk.

    Args:
        messages: history entries from the ``init`` frame; ``None`` is
            treated as an empty history.
    """
    seeded = 0
    for raw in messages or ():
        if not isinstance(raw, dict):
            continue  # malformed entry: skip rather than raise on .get
        sender = raw.get("username", "?")
        if sender == self.name:
            continue
        # Decrypt a copy so the caller's frame is left untouched.
        dec = self.decrypt_message(dict(raw))
        text = dec.get("text", "")
        if not isinstance(text, str) or not text:
            continue
        if text == "[decrypt failed]" or text.startswith('{"_'):
            continue
        self.transcript.append(Msg("user", f"{sender}: {text}"))
        seeded += 1

    # Rolling bound, applied once. A plain ``[-keep:]`` would keep the whole
    # list when keep == 0 (``-0 == 0``), so handle that case explicitly.
    keep = self.context_window * 2
    self.transcript = self.transcript[-keep:] if keep > 0 else []

    if seeded:
        self.info(f"backfilled {seeded} prior message(s) for context")