Tier A/B/C wins for the CPU-only Ollama box (no GPU → optimize TTFT and tokens/sec, not VRAM):

- Separate qwen2.5-coder provider for the sandbox `!task` path; chat keeps the general model. Auto-selected when chat is Ollama and a coder build is present; override with --code-model.
- OllamaProvider num_ctx default 8192→4096 (8192 was a GPU-mindset default that inflates prefill/TTFT on CPU); expose num_thread; add --num-ctx, --num-thread, --num-predict. token_budget default 3000→2000 to fit.
- OllamaProvider.stream() generator over Ollama's stream=True chat endpoint (provider half of token streaming; agent/Rust rendering is a follow-up).
- Few-shot request→shell exemplars in SANDBOX_SYSTEM to anchor the small model's fenced-command output.
- Matryoshka embedding truncation: OllamaEmbedder truncate_dim=256 (--embed-dim) for faster pure-Python cosine and less RAM; query+stored share the dim.
- docs/ai-perf-plan.md records all 8 items with status and the server-side env (OLLAMA_NUM_PARALLEL=1, keep_alive) that must be set where `ollama serve` runs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
274 lines
10 KiB
Python
274 lines
10 KiB
Python
"""Model-agnostic provider interface for the hack-house AI agent bridge.
|
|
|
|
A Provider turns a system prompt + conversation into a single reply string.
|
|
The bundled adapters speak plain HTTP via ``requests`` (already a dependency),
|
|
so no extra SDKs are required and any backend can be plugged in — including a
|
|
custom one via the ``module:Class`` spec.
|
|
"""
|
|
|
|
from __future__ import annotations

import importlib
import json
import os
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

import requests
|
|
|
|
|
|
@dataclass
class Msg:
    """A single conversation turn passed to a provider's ``complete``."""

    # Who produced the turn: "system", "user" or "assistant".
    role: str
    # The turn's text.
    content: str
|
|
|
|
|
|
@runtime_checkable
class Provider(Protocol):
    """Structural interface every chat backend implements.

    A provider turns a system prompt plus a conversation into one reply string
    via ``complete``. ``name`` identifies the backend and ``model`` the model
    it requests.

    Model discovery is intentionally NOT a required member here: a
    ``runtime_checkable`` protocol makes ``isinstance`` check every declared
    member, so declaring ``available_models`` would reject bespoke endpoints
    that cannot enumerate models. Such backends are probed with ``getattr``
    instead; those that do enumerate also satisfy ``DiscoverableProvider``.
    """

    name: str
    model: str

    def complete(self, system: str, messages: list[Msg]) -> str:
        """Return the assistant's reply to ``messages`` under prompt ``system``."""
        ...


@runtime_checkable
class DiscoverableProvider(Provider, Protocol):
    """A ``Provider`` that can also list the models its backend serves."""

    def available_models(self) -> list[str]:
        """Return the model identifiers the backend can serve (discovery/preflight)."""
        ...
|
|
|
|
|
|
class OllamaProvider:
    """Local Ollama (default, recommended). No API key — privacy-preserving.

    Replies come from Ollama's ``/api/chat`` endpoint: ``complete`` returns the
    whole reply in one request, ``stream`` yields it piece by piece. Model
    discovery uses ``/api/tags``. ``host`` falls back to ``$OLLAMA_HOST`` and
    then ``http://localhost:11434``; a trailing slash is removed.
    """

    name = "ollama"

    def __init__(self, model: str = "llama3", host: str | None = None, timeout: int = 120,
                 num_ctx: int = 4096, num_predict: int = 512, num_thread: int | None = None,
                 keep_alive: str = "30m"):
        self.model = model
        self.host = (host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")).rstrip("/")
        self.timeout = timeout
        # On CPU, time-to-first-token is O(num_ctx) prefill, so keep the window
        # modest (4096) rather than a GPU-mindset 8192. keep_alive pins the model
        # so the next /ai doesn't pay a cold reload. num_thread defaults to
        # Ollama's own (≈physical cores); set it explicitly to benchmark 4/6/8.
        self.num_ctx = num_ctx
        self.num_predict = num_predict
        self.num_thread = num_thread
        self.keep_alive = keep_alive

    def _options(self) -> dict:
        """Per-request ``options``; ``num_thread`` is sent only when explicitly set."""
        opts = {"num_ctx": self.num_ctx, "num_predict": self.num_predict}
        if self.num_thread is not None:
            opts["num_thread"] = self.num_thread
        return opts

    def _payload(self, system: str, messages: list[Msg], *, stream: bool) -> dict:
        """Build the ``/api/chat`` request body shared by ``complete`` and ``stream``.

        The system prompt is sent as the first message, followed by the history.
        """
        return {
            "model": self.model,
            "stream": stream,
            "keep_alive": self.keep_alive,
            "options": self._options(),
            "messages": [{"role": "system", "content": system}]
            + [{"role": m.role, "content": m.content} for m in messages],
        }

    @staticmethod
    def _reply_text(body: dict) -> str:
        """Return ``message.content`` from an ``/api/chat`` object, or '' if absent/null."""
        message = body.get("message") or {}
        return message.get("content") or ""

    def complete(self, system: str, messages: list[Msg]) -> str:
        """Return the full reply (whitespace-stripped) from one non-streaming request.

        Raises ``requests.HTTPError`` on a non-2xx response.
        """
        r = requests.post(f"{self.host}/api/chat",
                          json=self._payload(system, messages, stream=False),
                          timeout=self.timeout)
        r.raise_for_status()
        return self._reply_text(r.json()).strip()

    def stream(self, system: str, messages: list[Msg]) -> Iterator[str]:
        """Yield reply text incrementally as Ollama generates it. On CPU the
        perceived latency is TTFT, so streaming makes a slow reply feel live.

        Pieces are yielded unstripped so they concatenate to the full reply.
        Raises ``requests.HTTPError`` on a non-2xx response and ``RuntimeError``
        if the server reports an error inside the stream.
        """
        with requests.post(f"{self.host}/api/chat",
                           json=self._payload(system, messages, stream=True),
                           timeout=self.timeout, stream=True) as r:
            r.raise_for_status()
            for line in r.iter_lines():
                if not line:
                    continue
                chunk = json.loads(line)
                # Once the 200 header is sent, a failure can only arrive in-band
                # as {"error": ...}; surface it instead of silently ending with a
                # truncated reply.
                if chunk.get("error"):
                    raise RuntimeError(f"ollama stream error: {chunk['error']}")
                piece = self._reply_text(chunk)
                if piece:
                    yield piece
                if chunk.get("done"):
                    break

    def available_models(self) -> list[str]:
        """Return the ``name`` of every model reported by ``/api/tags``."""
        r = requests.get(f"{self.host}/api/tags", timeout=self.timeout)
        r.raise_for_status()
        return [m.get("name", "") for m in r.json().get("models", [])]
|
|
|
|
|
|
class OllamaEmbedder:
    """Local text embeddings via Ollama (default ``nomic-embed-text``), used for
    the agent's in-RAM semantic recall. Local + free, so it stays on by default
    regardless of which provider answers chat. No key, nothing persisted.

    Raises ``ValueError`` if ``truncate_dim`` is not ``None`` and below 1.
    """

    name = "ollama-embed"

    def __init__(self, model: str = "nomic-embed-text", host: str | None = None,
                 timeout: int = 60, truncate_dim: int | None = 256):
        # A non-positive dim would slice every vector to [] (0) or silently drop
        # trailing components (negative) — reject it rather than corrupt recall.
        if truncate_dim is not None and truncate_dim < 1:
            raise ValueError(
                f"truncate_dim must be a positive int or None, got {truncate_dim!r}"
            )
        self.model = model
        self.host = (host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")).rstrip("/")
        self.timeout = timeout
        # nomic-embed-text is Matryoshka (MRL)-trained, so its 768-dim vector can
        # be truncated to a shorter prefix with little quality loss — faster
        # pure-Python cosine and less RAM. Query + stored use the same dim, so
        # cosine stays correct. None keeps the full vector.
        self.truncate_dim = truncate_dim

    def _truncate(self, vec: list[float]) -> list[float]:
        """Keep the first ``truncate_dim`` components (the whole vector when None)."""
        if self.truncate_dim is None:
            return vec
        return vec[: self.truncate_dim]

    def embed(self, text: str) -> list[float]:
        """Embed ``text`` via ``/api/embeddings``; a missing embedding yields ``[]``.

        Raises ``requests.HTTPError`` on a non-2xx response.
        """
        r = requests.post(
            f"{self.host}/api/embeddings",
            json={"model": self.model, "prompt": text},
            timeout=self.timeout,
        )
        r.raise_for_status()
        return self._truncate(r.json().get("embedding") or [])
|
|
|
|
|
|
class AnthropicProvider:
    """Anthropic Messages API. Cloud — opt-in. Needs ANTHROPIC_API_KEY.

    Raises ``ValueError`` at construction when no API key is given or found in
    the environment.
    """

    name = "anthropic"

    _API = "https://api.anthropic.com/v1"
    _VERSION = "2023-06-01"

    def __init__(self, model: str = "claude-opus-4-6", api_key: str | None = None,
                 timeout: int = 120, max_tokens: int = 1024):
        self.model = model
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        self.timeout = timeout
        self.max_tokens = max_tokens
        if not self.api_key:
            raise ValueError("ANTHROPIC_API_KEY not set")

    def _headers(self) -> dict[str, str]:
        """Auth and API-version headers sent on every request."""
        return {"x-api-key": self.api_key, "anthropic-version": self._VERSION}

    @staticmethod
    def _messages(messages: list[Msg]) -> list[dict]:
        """Convert history into Messages-API turns.

        The system prompt travels in the top-level ``system`` field, so only
        user/assistant turns are forwarded; any other role is dropped.
        """
        return [
            {"role": m.role, "content": m.content}
            for m in messages
            if m.role in ("user", "assistant")
        ]

    def complete(self, system: str, messages: list[Msg]) -> str:
        """Return the concatenated text blocks of the reply, whitespace-stripped.

        Raises ``requests.HTTPError`` on a non-2xx response.
        """
        payload = {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "system": system,
            "messages": self._messages(messages),
        }
        r = requests.post(
            f"{self._API}/messages",
            json=payload,
            timeout=self.timeout,
            headers={**self._headers(), "content-type": "application/json"},
        )
        r.raise_for_status()
        blocks = r.json().get("content") or []
        return "".join(b.get("text", "") for b in blocks).strip()

    def available_models(self) -> list[str]:
        """Return every model id from ``GET /v1/models``, following pagination.

        The endpoint returns one page at a time (``has_more`` / ``last_id``), so
        reading only the first page could miss the configured model and make a
        preflight check fail falsely.
        """
        ids: list[str] = []
        params: dict[str, object] = {"limit": 1000}
        while True:
            r = requests.get(
                f"{self._API}/models",
                params=params,
                timeout=self.timeout,
                headers=self._headers(),
            )
            r.raise_for_status()
            body = r.json()
            ids.extend(m.get("id", "") for m in body.get("data") or [])
            last_id = body.get("last_id")
            if not body.get("has_more") or not last_id:
                return ids
            params["after_id"] = last_id
|
|
|
|
|
|
class OpenAICompatibleProvider:
    """OpenAI-style /chat/completions — OpenAI, Groq, Together, local vLLM, etc.

    ``api_key`` falls back to ``$OPENAI_API_KEY`` (empty means no auth header,
    for keyless local servers); ``base_url`` falls back to ``$OPENAI_BASE_URL``
    and then the public OpenAI endpoint, with any trailing slash removed.
    """

    name = "openai"

    def __init__(self, model: str = "gpt-4o-mini", api_key: str | None = None,
                 base_url: str | None = None, timeout: int = 120):
        self.model = model
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        self.base_url = (base_url or os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")).rstrip("/")
        self.timeout = timeout

    def _headers(self) -> dict[str, str]:
        """Bearer auth when a key is configured; none otherwise."""
        return {"authorization": f"Bearer {self.api_key}"} if self.api_key else {}

    @staticmethod
    def _reply_text(body: dict) -> str:
        """Return the first choice's message content, stripped, or '' if absent.

        ``content`` can legitimately be null (e.g. a refusal or a tool-call-only
        reply), so it must not be ``.strip()``-ed unguarded.
        """
        choices = body.get("choices") or []
        if not choices:
            return ""
        message = choices[0].get("message") or {}
        return (message.get("content") or "").strip()

    def complete(self, system: str, messages: list[Msg]) -> str:
        """Return the reply text; the system prompt is sent as the first message.

        Raises ``requests.HTTPError`` on a non-2xx response.
        """
        payload = {
            "model": self.model,
            "messages": [{"role": "system", "content": system}]
            + [{"role": m.role, "content": m.content} for m in messages],
        }
        headers = {"content-type": "application/json", **self._headers()}
        r = requests.post(
            f"{self.base_url}/chat/completions", json=payload, headers=headers, timeout=self.timeout
        )
        r.raise_for_status()
        return self._reply_text(r.json())

    def available_models(self) -> list[str]:
        """Return the ``id`` of every model listed by ``GET {base_url}/models``."""
        r = requests.get(f"{self.base_url}/models", headers=self._headers(), timeout=self.timeout)
        r.raise_for_status()
        return [m.get("id", "") for m in r.json().get("data", [])]
|
|
|
|
|
|
# Short names accepted by make_provider(). Insertion order matters: it is the
# order in which the "unknown provider" error lists the builtins.
_BUILTINS = dict(
    ollama=OllamaProvider,
    anthropic=AnthropicProvider,
    openai=OpenAICompatibleProvider,
)
|
|
|
|
|
|
def make_provider(spec: str, model: str | None = None, **opts) -> Provider:
    """Instantiate the provider named by ``spec``.

    ``spec`` is either a builtin short name (``ollama`` / ``anthropic`` /
    ``openai``) or a ``module:Class`` path to a custom Provider implementation.
    ``model``, when given, is passed as the ``model`` keyword; all other
    ``opts`` go to the constructor unchanged.

    Raises ``ValueError`` for an unknown builtin name; import/attribute errors
    from a ``module:Class`` spec propagate as-is.
    """
    module_path, colon, class_name = spec.partition(":")
    if colon:
        factory = getattr(importlib.import_module(module_path), class_name)
    else:
        factory = _BUILTINS.get(spec)
        if factory is None:
            known = ", ".join(_BUILTINS)
            raise ValueError(f"unknown provider '{spec}' (builtins: {known})")
    if model is not None:
        opts["model"] = model
    return factory(**opts)
|
|
|
|
|
|
def _model_is_listed(model: str, models: list[str]) -> bool:
    """Return True if ``model`` is in ``models``, allowing Ollama's implicit ``:latest``."""
    if model in models:
        return True
    # Ollama lists pulled models with an explicit tag ("llama3:latest") but
    # resolves an untagged name such as the default "llama3" to ":latest".
    return ":" not in model and f"{model}:latest" in models


def preflight(provider: Provider) -> tuple[bool, str]:
    """Cheap reachability + model-presence check before joining a room.

    Returns ``(ok, message)``. Lets ``/ai start`` fail fast with a clear reason
    (backend down / model not pulled / key missing) instead of erroring on the
    first question. Providers without ``available_models`` are assumed
    reachable, as are backends that report an empty (or ``None``) model list.
    """
    discover = getattr(provider, "available_models", None)
    if discover is None:
        return True, f"{provider.name}: no discovery endpoint — assuming reachable"
    try:
        # ``or []`` tolerates a discovery hook that returns None (e.g. an
        # unimplemented stub) instead of crashing on the membership test below.
        models = list(discover() or [])
    except Exception as e:  # noqa: BLE001 — any failure means "not reachable yet"
        return False, f"{provider.name}: cannot reach backend ({e})"
    if _model_is_listed(provider.model, models):
        return True, f"{provider.name}/{provider.model}: reachable"
    if models:
        sample = ", ".join(models[:8])
        more = "…" if len(models) > 8 else ""
        return False, (
            f"{provider.name}: model '{provider.model}' not available. "
            f"reachable models: {sample}{more}"
        )
    return True, f"{provider.name}: reachable (empty model list — skipping check)"
|