diff --git a/Cargo.lock b/Cargo.lock index 3873a2b..5bcd77e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,7 +100,7 @@ dependencies = [ [[package]] name = "exoshell" -version = "0.3.0" +version = "0.4.0" dependencies = [ "async-trait", "futures-util", diff --git a/Cargo.toml b/Cargo.toml index 7298dd5..e2e3b6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "exoshell" -version = "0.3.0" +version = "0.4.0" edition = "2024" license = "GPL-3.0-or-later" diff --git a/README.md b/README.md index 22ae8dc..20b8b3f 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,17 @@ Select an operating stance: cargo run -- --stance audit ``` +Configure model routing: + +```toml +[router] +enabled = true +model = "qwen2.5-coder:7b" +fallback_role = "coding" +``` + +The default router roles are `instant`, `coding`, `heavy`, and `conversational`. For Ollama model setup examples, see [khodges42/modelfiles](https://github.com/khodges42/modelfiles). + Exoshell suggests commands. It does not execute them. ## Quality Checks diff --git a/docs/phase2_interaction_model.md b/docs/phase2_interaction_model.md index d76d91e..f5c7da6 100644 --- a/docs/phase2_interaction_model.md +++ b/docs/phase2_interaction_model.md @@ -18,6 +18,8 @@ Command suggestion is a shell fenced code block such as `powershell`, `pwsh`, `s Transcript entry is a markdown record of user prompts, assistant responses, context events, budget warnings, stance changes, command suggestions, and command actions. +Model route is an optional provider decision made before a request is answered. When model routing is enabled, a fast router model chooses one configured role. The selected role, target model, and reason are recorded in the transcript. 
+ ## Prompt Assembly Prompt assembly is deterministic: @@ -113,9 +115,11 @@ Useful commands: ```text /panel +/keys /context /context stats /help +/help keys /help context /help stance /help commands @@ -123,6 +127,8 @@ Useful commands: `/panel` renders stance, shell family, provider/model, transcript state, context entries, and prompt estimates without requiring a TUI. +`/keys` documents the current line-REPL key actions and slash-command fallbacks. Advanced terminal key handling is not active yet; copy, explain, discard, context, and stance controls degrade to explicit slash commands. + ## Non-Goals Phase 2 does not make Exoshell an autonomous executor. @@ -141,4 +147,34 @@ Command parsing is intentionally simple and based on fenced blocks. Risk detection is heuristic and incomplete. -Advanced TUI keybindings and config profiles remain planned work. +Advanced full-screen TUI keybindings and config profiles remain planned work. + +## Model Routing + +The configurable model router lets a fast model inspect the prompt payload and choose the model role that should answer. + +Default roles: + +```text +instant qwen2.5-coder:7b +coding coder-g4-26b +heavy coder-g4-26b +conversational qwen2.5-coder:7b +``` + +Enable routing: + +```toml +[router] +enabled = true +model = "qwen2.5-coder:7b" +fallback_role = "coding" +``` + +The router asks for compact JSON: + +```json +{"role":"coding","reason":"source code change request"} +``` + +If the router fails or returns an unknown role, Exoshell uses the configured fallback role and records that reason in the transcript. diff --git a/docs/quickstart.md b/docs/quickstart.md index 7b4f657..45bb912 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -82,6 +82,60 @@ To replace the built-in defaults entirely: include_defaults = false ``` +## Configure Model Routing + +Exoshell can route each prompt through a fast router model before selecting the model that should answer. 
+ +Enable the default router: + +```toml +[router] +enabled = true +model = "qwen2.5-coder:7b" +fallback_role = "coding" +``` + +Default roles: + +```text +instant qwen2.5-coder:7b +coding coder-g4-26b +heavy coder-g4-26b +conversational qwen2.5-coder:7b +``` + +Override role models or behavior: + +```toml +[router] +enabled = true +model = "qwen2.5-coder:7b" +fallback_role = "coding" +behavior = "Prefer instant for short shell questions. Use heavy only for architecture or high-context analysis." + +[[router.roles]] +name = "instant" +model = "qwen2.5-coder:7b" +description = "fast responses for simple prompts" + +[[router.roles]] +name = "coding" +model = "coder-g4-26b" +description = "code edits, debugging, tests, and shell command construction" + +[[router.roles]] +name = "heavy" +model = "coder-g4-26b" +description = "complex reasoning and architecture" + +[[router.roles]] +name = "conversational" +model = "qwen2.5-coder:7b" +description = "general discussion and explanations" +``` + +For Ollama model setup examples, see [khodges42/modelfiles](https://github.com/khodges42/modelfiles). + ## Start Exoshell Run with defaults: @@ -246,6 +300,16 @@ exo> /panel The panel includes stance, shell family, provider/model, transcript state, context entries, and prompt estimates. +## Keybinding Fallbacks + +The current REPL is line-oriented. Use `/keys` to show the available key actions and their slash-command fallbacks: + +```text +exo> /keys +``` + +Copy, explain, discard, context, and stance actions degrade to explicit commands such as `/copy <id>`, `/explain <id>`, `/discard <id>`, `/context`, and `/stance`.
+ ## Multi-Line Prompts Use `/multi` for longer prompts: diff --git a/docs/versioning.md b/docs/versioning.md index 4f10b23..40b641a 100644 --- a/docs/versioning.md +++ b/docs/versioning.md @@ -102,3 +102,4 @@ Historical codenames should be tracked in docs/versioning.md below * 0.1.0 packet-kobold * 0.2.0 context-relic * 0.3.0 stance-lantern +* 0.4.0 switchboard-relic diff --git a/src/app.rs b/src/app.rs index 174a781..75a6fca 100644 --- a/src/app.rs +++ b/src/app.rs @@ -9,6 +9,7 @@ use crate::context::{ render_context_details, render_context_list, render_context_stats, }; use crate::formatting::render_assistant_output_with_policy; +use crate::keybindings::render_keybindings; use crate::prompts::{Stance, assemble_prompt, render_prompt_estimate}; use crate::providers::{ChatMessage, ChatRequest, ChatResponse, ChatRole, Provider, ProviderError}; use crate::repl::ReplError; @@ -72,6 +73,9 @@ impl App { Ok(Ok(ChatResponse::Complete(response))) => response, Ok(Ok(ChatResponse::Stream(chunks))) => chunks.concat(), Ok(Err(error)) => { + if let Some(route) = self.provider.last_model_route() { + self.transcript.record_model_route(&route); + } self.transcript.record_error(&error.to_string()); return Err(error.into()); } @@ -79,6 +83,9 @@ impl App { self.conversation .push(ChatMessage::new(ChatRole::Assistant, response.clone())); self.transcript.record_assistant(&response); + if let Some(route) = self.provider.last_model_route() { + self.transcript.record_model_route(&route); + } self.last_command_suggestions = parse_command_suggestions_with_policy(&response, &self.config.commands.risk); for suggestion in &self.last_command_suggestions { @@ -102,6 +109,10 @@ impl App { return Ok(help_overview().into()); } + if trimmed == "/keys" { + return Ok(render_keybindings()); + } + if let Some(topic) = trimmed.strip_prefix("/help ") { return Ok(help_topic(topic.trim()).into()); } @@ -328,10 +339,11 @@ impl App { fn render_panel(&self) -> String { format!( - "Exoshell session\nstance: 
{}\nshell: {}\nprovider: openai-compatible\nmodel: {}\ntranscript: {}\n\nContext\n{}\n\nPrompt estimate\n{}", + "Exoshell session\nstance: {}\nshell: {}\nprovider: openai-compatible\nmodel: {}\nrouter: {}\ntranscript: {}\n\nContext\n{}\n\nPrompt estimate\n{}", self.config.interaction.stance, self.config.shell.family, self.config.provider.model, + self.render_router_status(), if self.config.transcript.enabled { "enabled" } else { @@ -342,6 +354,25 @@ impl App { ) } + fn render_router_status(&self) -> String { + if !self.config.router.enabled { + return "disabled".into(); + } + + let roles = self + .config + .router + .roles + .iter() + .map(|role| format!("{}={}", role.name, role.model)) + .collect::>() + .join(", "); + format!( + "enabled model={} fallback={} roles=[{}]", + self.config.router.model, self.config.router.fallback_role, roles + ) + } + fn set_stance(&mut self, input: &str) -> Result { let stance = input .parse::() @@ -519,6 +550,7 @@ fn help_overview() -> &'static str { /explain explain a suggested command /discard mark a suggested command as discarded /panel show session, stance, provider, and context state +/keys show keybinding fallbacks for the line REPL /multi enter multi-line input /exit quit and write transcript if enabled @@ -536,7 +568,10 @@ fn help_topic(topic: &str) -> &'static str { "commands" => { "Suggested commands appear as fenced shell blocks and get IDs such as cmd-001. Use /copy, /explain, or /discard by ID. Copy prints the command when clipboard support is unavailable and never runs it." } - _ => "Unknown help topic. Try /help context, /help stance, or /help commands.", + "keys" => { + "The current line REPL does not install advanced terminal keybindings. Use /keys to see the predictable slash-command fallbacks for copy, explain, discard, context, and stance actions." + } + _ => "Unknown help topic. 
Try /help context, /help stance, /help commands, or /help keys.", } } @@ -861,6 +896,16 @@ mod tests { .expect("help") .contains("/copy") ); + assert!( + app.handle_command("/keys") + .expect("keys") + .contains("/discard ") + ); + assert!( + app.handle_command("/help keys") + .expect("help keys") + .contains("line REPL") + ); } #[tokio::test] @@ -961,6 +1006,7 @@ mod tests { model: "test-model".into(), request_timeout_seconds: 120, }, + router: crate::providers::router::ModelRouterConfig::default(), shell: ShellConfig { family: ShellFamily::PowerShell, }, diff --git a/src/commands.rs b/src/commands.rs index f9d4677..14e99ef 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -156,7 +156,8 @@ impl CommandRisk { } } -pub fn parse_command_suggestions(response: &str) -> Vec { +#[cfg(test)] +fn parse_command_suggestions(response: &str) -> Vec { parse_command_suggestions_with_policy(response, &CommandRiskPolicy::default()) } @@ -218,7 +219,8 @@ pub fn parse_command_suggestions_with_policy( suggestions } -pub fn detect_command_risk(command: &str, shell: CommandShell) -> CommandRisk { +#[cfg(test)] +fn detect_command_risk(command: &str, shell: CommandShell) -> CommandRisk { detect_command_risk_with_policy(command, shell, &CommandRiskPolicy::default()) } diff --git a/src/config.rs b/src/config.rs index c9bacbc..3ee8876 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,11 +8,13 @@ use crate::app::CliOptions; use crate::commands::{CommandRiskPolicy, CommandRiskRule}; use crate::context::ContextBudget; use crate::prompts::Stance; +use crate::providers::router::{ModelRouterConfig, ModelRouterRole}; use crate::shell::ShellFamily; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Config { pub provider: ProviderConfig, + pub router: ModelRouterConfig, pub shell: ShellConfig, pub interaction: InteractionConfig, pub commands: CommandConfig, @@ -68,6 +70,7 @@ impl ContextConfig { #[derive(Debug, Deserialize, Default)] struct RawConfig { provider: Option, + router: Option, 
shell: Option, interaction: Option, commands: Option, @@ -83,6 +86,22 @@ struct RawProviderConfig { request_timeout_seconds: Option, } +#[derive(Debug, Deserialize, Default)] +struct RawModelRouterConfig { + enabled: Option, + model: Option, + fallback_role: Option, + behavior: Option, + roles: Option>, +} + +#[derive(Debug, Deserialize, Default)] +struct RawModelRouterRole { + name: Option, + model: Option, + description: Option, +} + #[derive(Debug, Deserialize, Default)] struct RawShellConfig { family: Option, @@ -128,6 +147,7 @@ impl Config { fn from_raw(raw: RawConfig) -> Result { let provider = raw.provider.unwrap_or_default(); + let router = raw.router.unwrap_or_default(); let shell = raw.shell.unwrap_or_default(); let interaction = raw.interaction.unwrap_or_default(); let commands = raw.commands.unwrap_or_default(); @@ -141,6 +161,7 @@ impl Config { .api_key_env .unwrap_or_else(|| "OPENAI_API_KEY".into()); let api_key = provider_api_key(&base_url, &api_key_env)?; + let router = model_router_config(router)?; let family = shell.family.unwrap_or_else(default_shell_family); let family = family @@ -161,6 +182,7 @@ impl Config { model: provider.model.unwrap_or_else(|| "gpt-4.1-mini".into()), request_timeout_seconds: provider.request_timeout_seconds.unwrap_or(120), }, + router, shell: ShellConfig { family }, interaction: InteractionConfig { stance }, commands: CommandConfig { risk }, @@ -196,6 +218,61 @@ impl Config { } } +fn model_router_config(raw: RawModelRouterConfig) -> Result { + let mut config = ModelRouterConfig::default(); + if let Some(enabled) = raw.enabled { + config.enabled = enabled; + } + if let Some(model) = raw.model { + config.model = non_empty_config_value("router.model", model)?; + } + if let Some(fallback_role) = raw.fallback_role { + config.fallback_role = non_empty_config_value("router.fallback_role", fallback_role)?; + } + if let Some(behavior) = raw.behavior { + config.behavior = non_empty_config_value("router.behavior", behavior)?; + } + 
if let Some(roles) = raw.roles { + let mut parsed = Vec::new(); + for role in roles { + parsed.push(ModelRouterRole { + name: non_empty_config_value( + "router.roles.name", + role.name.ok_or_else(|| { + ConfigError::Invalid("router.roles entries require name".into()) + })?, + )?, + model: non_empty_config_value( + "router.roles.model", + role.model.ok_or_else(|| { + ConfigError::Invalid("router.roles entries require model".into()) + })?, + )?, + description: non_empty_config_value( + "router.roles.description", + role.description.ok_or_else(|| { + ConfigError::Invalid("router.roles entries require description".into()) + })?, + )?, + }); + } + config.roles = parsed; + } + + config + .validate() + .map_err(|error| ConfigError::Invalid(error.to_string()))?; + Ok(config) +} + +fn non_empty_config_value(name: &str, value: String) -> Result { + if value.trim().is_empty() { + Err(ConfigError::Invalid(format!("{name} cannot be empty"))) + } else { + Ok(value) + } +} + fn command_risk_policy( raw: Option, ) -> Result { @@ -385,6 +462,7 @@ mod tests { shell: Some(RawShellConfig { family: Some("cmd".into()), }), + router: None, interaction: None, commands: None, transcript: None, @@ -406,6 +484,22 @@ base_url = "http://localhost:11434/v1" model = "local-model" request_timeout_seconds = 45 +[router] +enabled = true +model = "qwen2.5-coder:7b" +fallback_role = "instant" +behavior = "Route to the smallest model that can answer well." 
+ +[[router.roles]] +name = "instant" +model = "qwen2.5-coder:7b" +description = "fast answers" + +[[router.roles]] +name = "heavy" +model = "coder-g4-26b" +description = "deep technical work" + [shell] family = "posix" @@ -436,6 +530,11 @@ max_estimated_tokens = 3000 assert_eq!(config.provider.api_key, "exoshell-local-provider"); assert_eq!(config.provider.model, "local-model"); assert_eq!(config.provider.request_timeout_seconds, 45); + assert!(config.router.enabled); + assert_eq!(config.router.model, "qwen2.5-coder:7b"); + assert_eq!(config.router.fallback_role, "instant"); + assert_eq!(config.router.roles.len(), 2); + assert_eq!(config.router.roles[1].model, "coder-g4-26b"); assert_eq!(config.shell.family, ShellFamily::Posix); assert_eq!(config.interaction.stance, Stance::Audit); assert!(config.commands.risk.include_defaults); @@ -470,6 +569,36 @@ max_estimated_tokens = 3000 assert_eq!(config.provider.request_timeout_seconds, 120); assert!(config.commands.risk.include_defaults); assert!(config.commands.risk.rules.is_empty()); + assert!(!config.router.enabled); + assert_eq!( + config + .router + .role("conversational") + .expect("conversational") + .model, + "qwen2.5-coder:7b" + ); + } + + #[test] + fn rejects_router_fallback_role_that_is_not_defined() { + let mut file = tempfile::NamedTempFile::new().expect("temp config"); + write!( + file, + r#" +[provider] +base_url = "http://localhost:11434/v1" + +[router] +enabled = true +fallback_role = "missing" +"# + ) + .expect("write config"); + + let error = Config::load(Some(file.path())).expect_err("config should fail"); + + assert!(error.to_string().contains("fallback role")); } #[test] diff --git a/src/formatting.rs b/src/formatting.rs index e13a5a1..3942b43 100644 --- a/src/formatting.rs +++ b/src/formatting.rs @@ -2,7 +2,8 @@ use crate::commands::{ CommandRiskPolicy, parse_command_suggestions_with_policy, render_suggestions, }; -pub fn render_assistant_output(response: &str) -> String { +#[cfg(test)] +fn 
/// One row of the keybinding reference: a key (or action name), what it
/// does, and the slash-command or behaviour to use in the line REPL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Keybinding {
    /// Key chord or action name shown to the user.
    pub key: &'static str,
    /// Human-readable description of what the binding does.
    pub action: &'static str,
    /// What to type or expect when no direct keybinding is installed.
    pub fallback: &'static str,
}

/// Key actions listed by `/keys`, in display order.
///
/// Commands that act on a suggested command show an `<id>` placeholder so the
/// usage is complete and no rendered line ends in whitespace.
pub const BASIC_KEYBINDINGS: &[Keybinding] = &[
    Keybinding {
        key: "Enter",
        action: "send the current prompt",
        fallback: "type a prompt and press Enter",
    },
    Keybinding {
        key: "Ctrl+C",
        action: "interrupt the current terminal operation",
        fallback: "keyboard interrupt remains handled by the terminal",
    },
    Keybinding {
        key: "copy",
        action: "copy or print a suggested command",
        fallback: "/copy <id>",
    },
    Keybinding {
        key: "explain",
        action: "explain a suggested command",
        fallback: "/explain <id>",
    },
    Keybinding {
        key: "discard",
        action: "discard a suggested command",
        fallback: "/discard <id>",
    },
    Keybinding {
        key: "context",
        action: "show attached context",
        fallback: "/context",
    },
    Keybinding {
        key: "stance",
        action: "show or change stance",
        fallback: "/stance",
    },
];

/// Renders the keybinding reference shown by `/keys`.
///
/// The output is a three-line header, a blank line, and one
/// `- key: action; fallback: command` bullet per entry of
/// [`BASIC_KEYBINDINGS`]. Trailing whitespace is removed, so the result never
/// ends with a newline.
pub fn render_keybindings() -> String {
    let mut rendered = String::from("Keybindings and fallbacks\n");
    rendered.push_str("Advanced terminal key handling is not active in the line REPL.\n");
    rendered.push_str("Use these slash commands when direct keybindings are unavailable.\n\n");

    for binding in BASIC_KEYBINDINGS {
        // Append the pieces directly instead of allocating a `format!` string per row.
        rendered.push_str("- ");
        rendered.push_str(binding.key);
        rendered.push_str(": ");
        rendered.push_str(binding.action);
        rendered.push_str("; fallback: ");
        rendered.push_str(binding.fallback);
        rendered.push('\n');
    }

    // Drop the final newline (and the header's blank line if the table is empty).
    let kept = rendered.trim_end().len();
    rendered.truncate(kept);
    rendered
}
assert!(output.contains("/discard ")); + assert!(output.contains("/context")); + assert!(output.contains("Ctrl+C")); + } +} diff --git a/src/main.rs b/src/main.rs index 14b823f..b8f0b58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ mod commands; mod config; pub mod context; mod formatting; +mod keybindings; mod prompts; mod providers; mod repl; @@ -14,6 +15,7 @@ use std::io::{IsTerminal, Read}; use crate::app::{App, CliOptions}; use crate::config::Config; use crate::providers::openai_compatible::OpenAiCompatibleProvider; +use crate::providers::router::ModelRouterProvider; use crate::repl::Repl; #[tokio::main] @@ -34,8 +36,16 @@ async fn run() -> Result<(), app::AppError> { let mut config = Config::load(options.config_path.as_deref())?; config.apply_cli_overrides(&options)?; - let provider = OpenAiCompatibleProvider::from_config(&config)?; - let mut app = App::new(config, Box::new(provider)); + let base_provider = OpenAiCompatibleProvider::from_config(&config)?; + let provider: Box = if config.router.enabled { + Box::new(ModelRouterProvider::new( + base_provider, + config.router.clone(), + )?) 
+ } else { + Box::new(base_provider) + }; + let mut app = App::new(config, provider); for note in options.context_notes { println!("{}", app.add_note_context(note)?); diff --git a/src/prompts.rs b/src/prompts.rs index 720a573..eb77440 100644 --- a/src/prompts.rs +++ b/src/prompts.rs @@ -5,9 +5,10 @@ use crate::context::{ContextBudget, ContextEntry, ContextSize, render_prompt_con use crate::providers::{ChatMessage, ChatRole}; use crate::shell::ShellFamily; -#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(rename_all = "snake_case")] pub enum Stance { + #[default] Operator, Audit, Teach, @@ -47,12 +48,6 @@ impl Stance { } } -impl Default for Stance { - fn default() -> Self { - Self::Operator - } -} - impl fmt::Display for Stance { fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 5414442..0ea56f2 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,4 +1,5 @@ pub mod openai_compatible; +pub mod router; #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct ChatMessage { @@ -35,9 +36,20 @@ pub enum ChatResponse { Stream(Vec), } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelRoute { + pub role: String, + pub model: String, + pub reason: String, +} + #[async_trait::async_trait] pub trait Provider: Send + Sync { async fn chat(&self, request: ChatRequest) -> Result; + + fn last_model_route(&self) -> Option { + None + } } #[derive(Debug, thiserror::Error)] diff --git a/src/providers/openai_compatible.rs b/src/providers/openai_compatible.rs index b162cf4..2932806 100644 --- a/src/providers/openai_compatible.rs +++ b/src/providers/openai_compatible.rs @@ -53,17 +53,25 @@ impl OpenAiCompatibleProvider { fn chat_url(&self) -> String { format!("{}/chat/completions", self.base_url) } -} 
-#[async_trait::async_trait] -impl Provider for OpenAiCompatibleProvider { - async fn chat(&self, request: ChatRequest) -> Result { + pub async fn chat_with_model( + &self, + request: ChatRequest, + model: &str, + ) -> Result { let payload = ChatCompletionRequest { - model: self.model.clone(), + model: model.to_string(), messages: request.messages, stream: request.stream, }; + self.send_chat_completion(payload).await + } + + async fn send_chat_completion( + &self, + payload: ChatCompletionRequest, + ) -> Result { let response = self .client .post(self.chat_url()) @@ -108,6 +116,19 @@ impl Provider for OpenAiCompatibleProvider { } } +#[async_trait::async_trait] +impl Provider for OpenAiCompatibleProvider { + async fn chat(&self, request: ChatRequest) -> Result { + let payload = ChatCompletionRequest { + model: self.model.clone(), + messages: request.messages, + stream: request.stream, + }; + + self.send_chat_completion(payload).await + } +} + async fn read_streaming_response( response: reqwest::Response, ) -> Result { diff --git a/src/providers/router.rs b/src/providers/router.rs new file mode 100644 index 0000000..68056f6 --- /dev/null +++ b/src/providers/router.rs @@ -0,0 +1,333 @@ +use std::fmt; +use std::sync::Mutex; + +use serde::Deserialize; + +use crate::providers::openai_compatible::OpenAiCompatibleProvider; +use crate::providers::{ + ChatMessage, ChatRequest, ChatResponse, ChatRole, ModelRoute, Provider, ProviderError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelRouterConfig { + pub enabled: bool, + pub model: String, + pub fallback_role: String, + pub behavior: String, + pub roles: Vec, +} + +impl ModelRouterConfig { + pub fn validate(&self) -> Result<(), ModelRouterConfigError> { + if self.model.trim().is_empty() { + return Err(ModelRouterConfigError::Invalid( + "router model is empty".into(), + )); + } + if self.fallback_role.trim().is_empty() { + return Err(ModelRouterConfigError::Invalid( + "router fallback role is empty".into(), + 
)); + } + if self.behavior.trim().is_empty() { + return Err(ModelRouterConfigError::Invalid( + "router behavior is empty".into(), + )); + } + if self.roles.is_empty() { + return Err(ModelRouterConfigError::Invalid( + "router requires at least one role".into(), + )); + } + + for role in &self.roles { + if role.name.trim().is_empty() { + return Err(ModelRouterConfigError::Invalid( + "router role name is empty".into(), + )); + } + if role.model.trim().is_empty() { + return Err(ModelRouterConfigError::Invalid(format!( + "router role '{}' has an empty model", + role.name + ))); + } + if role.description.trim().is_empty() { + return Err(ModelRouterConfigError::Invalid(format!( + "router role '{}' has an empty description", + role.name + ))); + } + } + + if self.role(&self.fallback_role).is_none() { + return Err(ModelRouterConfigError::Invalid(format!( + "router fallback role '{}' is not defined", + self.fallback_role + ))); + } + + Ok(()) + } + + pub fn role(&self, name: &str) -> Option<&ModelRouterRole> { + self.roles.iter().find(|role| role.name == name) + } + + fn fallback_route(&self, reason: impl Into) -> ModelRoute { + let role = self + .role(&self.fallback_role) + .expect("validated router fallback role exists"); + ModelRoute { + role: role.name.clone(), + model: role.model.clone(), + reason: reason.into(), + } + } +} + +impl Default for ModelRouterConfig { + fn default() -> Self { + Self { + enabled: false, + model: "qwen2.5-coder:7b".into(), + fallback_role: "coding".into(), + behavior: default_router_behavior().into(), + roles: default_router_roles(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelRouterRole { + pub name: String, + pub model: String, + pub description: String, +} + +#[derive(Debug, thiserror::Error, PartialEq, Eq)] +pub enum ModelRouterConfigError { + #[error("invalid model router config: {0}")] + Invalid(String), +} + +pub struct ModelRouterProvider { + provider: OpenAiCompatibleProvider, + config: ModelRouterConfig, + 
last_route: Mutex>, +} + +impl ModelRouterProvider { + pub fn new( + provider: OpenAiCompatibleProvider, + config: ModelRouterConfig, + ) -> Result { + config + .validate() + .map_err(|error| ProviderError::Configuration(error.to_string()))?; + Ok(Self { + provider, + config, + last_route: Mutex::new(None), + }) + } + + async fn route(&self, request: &ChatRequest) -> ModelRoute { + let router_request = ChatRequest { + messages: vec![ + ChatMessage::new(ChatRole::System, self.router_system_prompt()), + ChatMessage::new(ChatRole::User, router_user_prompt(request)), + ], + stream: false, + }; + + let response = self + .provider + .chat_with_model(router_request, &self.config.model) + .await; + + match response { + Ok(ChatResponse::Complete(content)) => parse_router_response(&content, &self.config) + .unwrap_or_else(|| { + self.config + .fallback_route(format!("router returned unrecognized role: {content}")) + }), + Ok(ChatResponse::Stream(chunks)) => { + let content = chunks.concat(); + parse_router_response(&content, &self.config).unwrap_or_else(|| { + self.config + .fallback_route(format!("router returned unrecognized role: {content}")) + }) + } + Err(error) => self + .config + .fallback_route(format!("router failed: {error}")), + } + } + + fn router_system_prompt(&self) -> String { + let mut prompt = String::new(); + prompt.push_str("You are Exoshell's model router. 
Choose exactly one role for the next assistant response.\n"); + prompt.push_str(&self.config.behavior); + prompt.push_str("\n\nAvailable roles:\n"); + for role in &self.config.roles { + prompt.push_str(&format!( + "- {}: {} (model: {})\n", + role.name, role.description, role.model + )); + } + prompt.push_str( + "\nRespond as compact JSON only: {\"role\":\"\",\"reason\":\"\"}", + ); + prompt + } +} + +#[async_trait::async_trait] +impl Provider for ModelRouterProvider { + async fn chat(&self, request: ChatRequest) -> Result { + let route = self.route(&request).await; + { + let mut last_route = self + .last_route + .lock() + .expect("model route lock should not be poisoned"); + *last_route = Some(route.clone()); + } + + self.provider.chat_with_model(request, &route.model).await + } + + fn last_model_route(&self) -> Option { + self.last_route + .lock() + .expect("model route lock should not be poisoned") + .clone() + } +} + +fn router_user_prompt(request: &ChatRequest) -> String { + let mut rendered = String::new(); + rendered.push_str("Route this request using the current prompt payload.\n\n"); + for message in &request.messages { + rendered.push_str(&format!("{:?}:\n{}\n\n", message.role, message.content)); + } + rendered +} + +fn parse_router_response(content: &str, config: &ModelRouterConfig) -> Option { + if let Ok(response) = serde_json::from_str::(content.trim()) + && let Some(role) = config.role(response.role.trim()) + { + return Some(ModelRoute { + role: role.name.clone(), + model: role.model.clone(), + reason: response + .reason + .unwrap_or_else(|| "router selected role".into()), + }); + } + + let lowered = content.to_ascii_lowercase(); + config.roles.iter().find_map(|role| { + if lowered.contains(&role.name.to_ascii_lowercase()) { + Some(ModelRoute { + role: role.name.clone(), + model: role.model.clone(), + reason: "router selected role from text response".into(), + }) + } else { + None + } + }) +} + +#[derive(Debug, Deserialize)] +struct RouterResponse { + 
role: String, + reason: Option, +} + +fn default_router_behavior() -> &'static str { + "Prefer the cheapest and fastest role that is likely to answer well. Use instant for simple routing, short shell questions, quick lookups, and low-risk responses. Use coding for source changes, debugging, tests, command construction, and repository work. Use heavy for complex architecture, multi-step reasoning, risky operational analysis, or large-context synthesis. Use conversational for general explanation, planning, and non-code discussion." +} + +fn default_router_roles() -> Vec { + vec![ + ModelRouterRole { + name: "instant".into(), + model: "qwen2.5-coder:7b".into(), + description: "fast responses for simple prompts and low-latency checks".into(), + }, + ModelRouterRole { + name: "coding".into(), + model: "coder-g4-26b".into(), + description: "code edits, debugging, tests, and shell command construction".into(), + }, + ModelRouterRole { + name: "heavy".into(), + model: "coder-g4-26b".into(), + description: "complex reasoning, architecture, and high-context technical work".into(), + }, + ModelRouterRole { + name: "conversational".into(), + model: "qwen2.5-coder:7b".into(), + description: "general discussion, planning, and explanations".into(), + }, + ] +} + +impl fmt::Debug for ModelRouterProvider { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter + .debug_struct("ModelRouterProvider") + .field("config", &self.config) + .finish_non_exhaustive() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_router_roles_match_expected_models() { + let config = ModelRouterConfig::default(); + + assert_eq!( + config.role("instant").expect("instant").model, + "qwen2.5-coder:7b" + ); + assert_eq!( + config.role("conversational").expect("conversational").model, + "qwen2.5-coder:7b" + ); + assert_eq!(config.role("coding").expect("coding").model, "coder-g4-26b"); + assert_eq!(config.role("heavy").expect("heavy").model, "coder-g4-26b"); + 
config.validate().expect("default config validates"); + } + + #[test] + fn parses_json_router_response() { + let config = ModelRouterConfig::default(); + let route = parse_router_response( + r#"{"role":"heavy","reason":"large context architecture"}"#, + &config, + ) + .expect("route"); + + assert_eq!(route.role, "heavy"); + assert_eq!(route.model, "coder-g4-26b"); + assert_eq!(route.reason, "large context architecture"); + } + + #[test] + fn rejects_missing_fallback_role() { + let config = ModelRouterConfig { + fallback_role: "missing".into(), + ..ModelRouterConfig::default() + }; + + assert!(config.validate().is_err()); + } +} diff --git a/src/repl.rs b/src/repl.rs index 3757e47..f9fef4c 100644 --- a/src/repl.rs +++ b/src/repl.rs @@ -50,6 +50,7 @@ impl Repl { || input.starts_with("/explain ") || input.starts_with("/discard ") || input.starts_with("/help") + || input == "/keys" || input == "/panel" || input.starts_with("/add-note ") || input.starts_with("/add-file ") diff --git a/src/transcripts.rs b/src/transcripts.rs index e77e6ca..417df59 100644 --- a/src/transcripts.rs +++ b/src/transcripts.rs @@ -5,6 +5,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use crate::commands::CommandSuggestion; use crate::context::{ContextEntry, redacted_provider_details}; use crate::prompts::Stance; +use crate::providers::ModelRoute; use crate::shell::ShellFamily; #[derive(Debug, Clone)] @@ -93,6 +94,14 @@ impl Transcript { }); } + pub fn record_model_route(&mut self, route: &ModelRoute) { + self.entries.push(TranscriptEntry::ModelRoute { + role: route.role.clone(), + model: route.model.clone(), + reason: route.reason.clone(), + }); + } + pub fn write_to_dir(&self, directory: &Path) -> Result { fs::create_dir_all(directory).map_err(|error| TranscriptError::CreateDir { path: directory.to_path_buf(), @@ -209,6 +218,16 @@ impl Transcript { markdown.push_str(&format!("- action: `{action}`\n")); markdown.push_str(&format!("- note: `{note}`\n\n")); } + TranscriptEntry::ModelRoute { + 
role, + model, + reason, + } => { + markdown.push_str("## Model Route\n\n"); + markdown.push_str(&format!("- role: `{role}`\n")); + markdown.push_str(&format!("- model: `{model}`\n")); + markdown.push_str(&format!("- reason: `{reason}`\n\n")); + } } } @@ -253,6 +272,11 @@ enum TranscriptEntry { action: String, note: String, }, + ModelRoute { + role: String, + model: String, + reason: String, + }, } fn unix_millis() -> u128 {