mirror of
https://github.com/khodges42/exoshell.git
synced 2026-06-14 18:08:37 +00:00
470 lines
14 KiB
Rust
470 lines
14 KiB
Rust
use std::env;
|
|
use std::fs;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use serde::Deserialize;
|
|
|
|
use crate::app::CliOptions;
|
|
use crate::context::ContextBudget;
|
|
use crate::prompts::Stance;
|
|
use crate::shell::ShellFamily;
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct Config {
|
|
pub provider: ProviderConfig,
|
|
pub shell: ShellConfig,
|
|
pub interaction: InteractionConfig,
|
|
pub transcript: TranscriptConfig,
|
|
pub context: ContextConfig,
|
|
}
|
|
|
|
/// Connection settings for the model provider.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written out by `{:?}` (for example when a whole `Config` is logged or
/// included in a panic message). Every other field is printed normally.
#[derive(Clone, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Base URL of the provider API (default: `https://api.openai.com/v1`).
    pub base_url: String,
    /// API key sent to the provider. It is read from the variable named by
    /// `api_key_env`, or is a fixed placeholder for local providers when that
    /// variable is unavailable.
    pub api_key: String,
    /// Name of the environment variable holding the API key (default: `OPENAI_API_KEY`).
    pub api_key_env: String,
    /// Model identifier (default: `gpt-4.1-mini`).
    pub model: String,
    /// Request timeout in seconds (default: 120).
    pub request_timeout_seconds: u64,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("base_url", &self.base_url)
            // Redacted: the key is a secret and must not leak through debug output.
            .field("api_key", &format_args!("<redacted>"))
            .field("api_key_env", &self.api_key_env)
            .field("model", &self.model)
            .field("request_timeout_seconds", &self.request_timeout_seconds)
            .finish()
    }
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct ShellConfig {
|
|
pub family: ShellFamily,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct InteractionConfig {
|
|
pub stance: Stance,
|
|
}
|
|
|
|
/// Transcript settings.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TranscriptConfig {
    /// Directory transcripts go to. When the config file omits it,
    /// `Config::load` uses a per-platform `exoshell/transcripts` directory.
    pub directory: PathBuf,
    /// Whether transcripts are enabled (default: `true`).
    pub enabled: bool,
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
|
pub struct ContextConfig {
|
|
pub max_characters: Option<usize>,
|
|
pub max_estimated_tokens: Option<usize>,
|
|
}
|
|
|
|
impl ContextConfig {
|
|
pub fn budget(&self) -> ContextBudget {
|
|
ContextBudget {
|
|
max_characters: self.max_characters,
|
|
max_estimated_tokens: self.max_estimated_tokens,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Default)]
|
|
struct RawConfig {
|
|
provider: Option<RawProviderConfig>,
|
|
shell: Option<RawShellConfig>,
|
|
interaction: Option<RawInteractionConfig>,
|
|
transcript: Option<RawTranscriptConfig>,
|
|
context: Option<RawContextConfig>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Default)]
|
|
struct RawProviderConfig {
|
|
base_url: Option<String>,
|
|
api_key_env: Option<String>,
|
|
model: Option<String>,
|
|
request_timeout_seconds: Option<u64>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Default)]
|
|
struct RawShellConfig {
|
|
family: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Default)]
|
|
struct RawInteractionConfig {
|
|
stance: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Default)]
|
|
struct RawTranscriptConfig {
|
|
directory: Option<PathBuf>,
|
|
enabled: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Default)]
|
|
struct RawContextConfig {
|
|
max_characters: Option<usize>,
|
|
max_estimated_tokens: Option<usize>,
|
|
}
|
|
|
|
impl Config {
|
|
pub fn load(path: Option<&Path>) -> Result<Self, ConfigError> {
|
|
let raw = match path {
|
|
Some(path) => RawConfig::from_path(path)?,
|
|
None => RawConfig::from_default_path()?.unwrap_or_default(),
|
|
};
|
|
|
|
Self::from_raw(raw)
|
|
}
|
|
|
|
fn from_raw(raw: RawConfig) -> Result<Self, ConfigError> {
|
|
let provider = raw.provider.unwrap_or_default();
|
|
let shell = raw.shell.unwrap_or_default();
|
|
let interaction = raw.interaction.unwrap_or_default();
|
|
let transcript = raw.transcript.unwrap_or_default();
|
|
let context = raw.context.unwrap_or_default();
|
|
|
|
let base_url = provider
|
|
.base_url
|
|
.unwrap_or_else(|| "https://api.openai.com/v1".into());
|
|
let api_key_env = provider
|
|
.api_key_env
|
|
.unwrap_or_else(|| "OPENAI_API_KEY".into());
|
|
let api_key = provider_api_key(&base_url, &api_key_env)?;
|
|
|
|
let family = shell.family.unwrap_or_else(default_shell_family);
|
|
let family = family
|
|
.parse::<ShellFamily>()
|
|
.map_err(|error| ConfigError::Invalid(error.to_string()))?;
|
|
let stance = interaction
|
|
.stance
|
|
.unwrap_or_else(|| Stance::default().to_string())
|
|
.parse::<Stance>()
|
|
.map_err(|error| ConfigError::Invalid(error.to_string()))?;
|
|
|
|
Ok(Self {
|
|
provider: ProviderConfig {
|
|
base_url,
|
|
api_key,
|
|
api_key_env,
|
|
model: provider.model.unwrap_or_else(|| "gpt-4.1-mini".into()),
|
|
request_timeout_seconds: provider.request_timeout_seconds.unwrap_or(120),
|
|
},
|
|
shell: ShellConfig { family },
|
|
interaction: InteractionConfig { stance },
|
|
transcript: TranscriptConfig {
|
|
directory: transcript.directory.unwrap_or_else(default_transcript_dir),
|
|
enabled: transcript.enabled.unwrap_or(true),
|
|
},
|
|
context: ContextConfig {
|
|
max_characters: context.max_characters,
|
|
max_estimated_tokens: context.max_estimated_tokens,
|
|
},
|
|
})
|
|
}
|
|
|
|
pub fn apply_cli_overrides(&mut self, options: &CliOptions) -> Result<(), ConfigError> {
|
|
if let Some(shell_family) = options.shell_family {
|
|
self.shell.family = shell_family;
|
|
}
|
|
|
|
if let Some(stance) = options.stance {
|
|
self.interaction.stance = stance;
|
|
}
|
|
|
|
if let Some(transcript_enabled) = options.transcript_enabled {
|
|
self.transcript.enabled = transcript_enabled;
|
|
}
|
|
|
|
if let Some(transcript_directory) = &options.transcript_directory {
|
|
self.transcript.directory = transcript_directory.clone();
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl RawConfig {
|
|
fn from_path(path: &Path) -> Result<Self, ConfigError> {
|
|
let contents = fs::read_to_string(path).map_err(|error| ConfigError::Read {
|
|
path: path.to_path_buf(),
|
|
error,
|
|
})?;
|
|
|
|
toml::from_str(&contents).map_err(ConfigError::Parse)
|
|
}
|
|
|
|
fn from_default_path() -> Result<Option<Self>, ConfigError> {
|
|
let path = default_config_path();
|
|
if !path.exists() {
|
|
return Ok(None);
|
|
}
|
|
|
|
Self::from_path(&path).map(Some)
|
|
}
|
|
}
|
|
|
|
fn default_shell_family() -> String {
|
|
ShellFamily::default_for_platform().to_string()
|
|
}
|
|
|
|
fn provider_api_key(base_url: &str, api_key_env: &str) -> Result<String, ConfigError> {
|
|
match env::var(api_key_env) {
|
|
Ok(value) => Ok(value),
|
|
Err(_) if is_local_provider_url(base_url) => Ok("exoshell-local-provider".into()),
|
|
Err(_) => Err(ConfigError::MissingApiKey(format!(
|
|
"set {api_key_env} or configure provider.api_key_env"
|
|
))),
|
|
}
|
|
}
|
|
|
|
/// Returns `true` when `base_url` targets the local machine.
///
/// A URL counts as local when its host is either:
/// - `localhost` (case-insensitive, optional trailing dot), or
/// - an IP address that is loopback (all of `127.0.0.0/8`, `::1`, and
///   IPv4-mapped forms of these) or unspecified (`0.0.0.0`, `::`).
///
/// The unspecified addresses are included because `0.0.0.0` was always
/// treated as local here; `::` is its IPv6 counterpart.
fn is_local_provider_url(base_url: &str) -> bool {
    let Some(host) = provider_host(base_url) else {
        return false;
    };

    if host == "localhost" {
        return true;
    }

    host.parse::<std::net::IpAddr>()
        // Unwrap IPv4-mapped IPv6 (`::ffff:127.0.0.1`) so it is classified as IPv4.
        .map(|ip| ip.to_canonical())
        .is_ok_and(|ip| ip.is_loopback() || ip.is_unspecified())
}

/// Extracts the lower-cased host from a URL-like string.
///
/// Parsing steps:
/// - Any `scheme://` prefix is removed, case-insensitively and for any valid scheme.
/// - A `user:password@` prefix is dropped.
/// - The port, path, query and fragment are ignored.
/// - A trailing dot is removed.
/// - Bracketed IPv6 literals (`[::1]`) are returned without brackets.
///
/// Returns `None` when no non-empty host can be found.
fn provider_host(base_url: &str) -> Option<String> {
    let url = base_url.trim();
    let after_scheme = match url.split_once("://") {
        Some((scheme, rest)) if is_url_scheme(scheme) => rest,
        _ => url,
    };

    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()?
        .trim();
    // Userinfo precedes the host; keep only what follows the last `@`.
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, rest)| rest);

    let host = match host_and_port.strip_prefix('[') {
        // Bracketed IPv6 literal: the host ends at `]`, and any port follows it.
        Some(bracketed) => bracketed.split(']').next()?,
        None => host_and_port.split(':').next()?,
    };
    // A trailing dot marks a fully-qualified name: `localhost.` is `localhost`.
    let host = host.trim_end_matches('.');

    (!host.is_empty()).then(|| host.to_ascii_lowercase())
}

/// Returns `true` if `candidate` is a syntactically valid RFC 3986 URI scheme:
/// `ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`.
fn is_url_scheme(candidate: &str) -> bool {
    let mut chars = candidate.chars();
    chars.next().is_some_and(|first| first.is_ascii_alphabetic())
        && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
}
|
|
|
|
/// Path of the config file used when no explicit path is given.
///
/// - Windows: `%APPDATA%\exoshell\config.toml`.
/// - Elsewhere: `$XDG_CONFIG_HOME/exoshell/config.toml`, falling back to
///   `$HOME/.config/exoshell/config.toml`.
///
/// If no base directory can be determined, the result is relative to the
/// current directory.
fn default_config_path() -> PathBuf {
    let base = if cfg!(windows) {
        app_data_dir()
    } else {
        xdg_base_dir("XDG_CONFIG_HOME", ".config")
    };
    base.join("exoshell").join("config.toml")
}

/// Directory transcripts are written to when the config file does not set one.
///
/// - Windows: `%APPDATA%\exoshell\transcripts`.
/// - Elsewhere: `$XDG_DATA_HOME/exoshell/transcripts`, falling back to
///   `$HOME/.local/share/exoshell/transcripts`.
///
/// If no base directory can be determined, the result is relative to the
/// current directory.
fn default_transcript_dir() -> PathBuf {
    let base = if cfg!(windows) {
        app_data_dir()
    } else {
        xdg_base_dir("XDG_DATA_HOME", ".local/share")
    };
    base.join("exoshell").join("transcripts")
}

/// Reads environment variable `name` as a path, treating an empty value as unset.
///
/// An empty value would otherwise turn `join` results into paths relative to
/// the current directory.
fn non_empty_env_path(name: &str) -> Option<PathBuf> {
    env::var_os(name)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}

/// `%APPDATA%`, or the current directory when it is unset or empty.
fn app_data_dir() -> PathBuf {
    non_empty_env_path("APPDATA").unwrap_or_else(|| PathBuf::from("."))
}

/// Resolves an XDG base directory.
///
/// The XDG Base Directory Specification says `xdg_var` must be ignored when it
/// is unset, empty, or not an absolute path. In that case
/// `$HOME/<home_relative>` is used, and if `HOME` is also unset or empty, the
/// current directory.
fn xdg_base_dir(xdg_var: &str, home_relative: &str) -> PathBuf {
    non_empty_env_path(xdg_var)
        .filter(|path| path.is_absolute())
        .or_else(|| non_empty_env_path("HOME").map(|home| home.join(home_relative)))
        .unwrap_or_else(|| PathBuf::from("."))
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum ConfigError {
|
|
#[error("failed to read config at {path}: {error}")]
|
|
Read {
|
|
path: PathBuf,
|
|
#[source]
|
|
error: std::io::Error,
|
|
},
|
|
#[error("failed to parse config: {0}")]
|
|
Parse(toml::de::Error),
|
|
#[error("missing provider API key: {0}")]
|
|
MissingApiKey(String),
|
|
#[error("invalid config: {0}")]
|
|
Invalid(String),
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::io::Write;

    use super::*;

    /// Environment variable the tests point `provider.api_key_env` at.
    const TEST_KEY_ENV: &str = "EXOSHELL_TEST_KEY";

    /// Exports a known API key so a hosted provider's key can be resolved.
    fn export_test_key() {
        // SAFETY: `set_var` is `unsafe` because concurrent environment access
        // from other threads is not synchronized on every platform. Every test
        // writes the same value. NOTE(review): parallel test threads may still
        // read the environment concurrently; serialize these tests if that
        // ever causes flakiness.
        unsafe {
            env::set_var(TEST_KEY_ENV, "secret");
        }
    }

    /// Raw config whose provider reads its key from [`TEST_KEY_ENV`]; all else defaulted.
    fn raw_with_test_key() -> RawConfig {
        RawConfig {
            provider: Some(RawProviderConfig {
                api_key_env: Some(TEST_KEY_ENV.into()),
                ..RawProviderConfig::default()
            }),
            ..RawConfig::default()
        }
    }

    /// Raw config for a provider at `base_url` whose key lives in `key_env`.
    fn raw_with_provider(base_url: &str, key_env: &str) -> RawConfig {
        RawConfig {
            provider: Some(RawProviderConfig {
                base_url: Some(base_url.into()),
                api_key_env: Some(key_env.into()),
                ..RawProviderConfig::default()
            }),
            ..RawConfig::default()
        }
    }

    #[test]
    fn loads_defaults_from_environment() {
        export_test_key();

        let config = Config::from_raw(raw_with_test_key()).expect("config loads");

        assert_eq!(config.provider.api_key, "secret");
        assert_eq!(config.provider.base_url, "https://api.openai.com/v1");
        assert_eq!(config.provider.model, "gpt-4.1-mini");
    }

    #[test]
    fn rejects_unknown_shell_family() {
        export_test_key();

        let raw = RawConfig {
            shell: Some(RawShellConfig {
                family: Some("cmd".into()),
            }),
            ..raw_with_test_key()
        };
        let error = Config::from_raw(raw).expect_err("shell family should be rejected");

        assert!(matches!(error, ConfigError::Invalid(_)));
    }

    #[test]
    fn loads_toml_config_file() {
        let mut file = tempfile::NamedTempFile::new().expect("temp config");
        write!(
            file,
            r#"
[provider]
base_url = "http://localhost:11434/v1"
model = "local-model"
request_timeout_seconds = 45

[shell]
family = "posix"

[interaction]
stance = "audit"

[transcript]
enabled = false

[context]
max_characters = 12000
max_estimated_tokens = 3000
"#
        )
        .expect("write config");

        let config = Config::load(Some(file.path())).expect("config loads");

        assert_eq!(config.provider.base_url, "http://localhost:11434/v1");
        assert_eq!(config.provider.api_key, "exoshell-local-provider");
        assert_eq!(config.provider.model, "local-model");
        assert_eq!(config.provider.request_timeout_seconds, 45);
        assert_eq!(config.shell.family, ShellFamily::Posix);
        assert_eq!(config.interaction.stance, Stance::Audit);
        assert!(!config.transcript.enabled);
        assert_eq!(config.context.max_characters, Some(12000));
        assert_eq!(config.context.max_estimated_tokens, Some(3000));
        assert_eq!(config.context.budget().max_characters, Some(12000));
    }

    #[test]
    fn context_budget_defaults_to_unlimited() {
        export_test_key();

        let config = Config::from_raw(raw_with_test_key()).expect("config loads");

        assert_eq!(config.context.max_characters, None);
        assert_eq!(config.context.max_estimated_tokens, None);
        assert_eq!(config.provider.request_timeout_seconds, 120);
    }

    #[test]
    fn applies_cli_overrides() {
        export_test_key();

        let mut config = Config::from_raw(raw_with_test_key()).expect("config loads");
        let override_dir = PathBuf::from("manual-transcripts");
        let options = CliOptions {
            shell_family: Some(ShellFamily::Posix),
            stance: Some(Stance::Teach),
            transcript_enabled: Some(false),
            transcript_directory: Some(override_dir.clone()),
            ..CliOptions::default()
        };

        config.apply_cli_overrides(&options).expect("overrides");

        assert_eq!(config.shell.family, ShellFamily::Posix);
        assert_eq!(config.interaction.stance, Stance::Teach);
        assert!(!config.transcript.enabled);
        assert_eq!(config.transcript.directory, override_dir);
    }

    #[test]
    fn local_provider_urls_do_not_require_api_key_env() {
        let raw = raw_with_provider("http://127.0.0.1:11434/v1", "EXOSHELL_MISSING_LOCAL_KEY");

        let config = Config::from_raw(raw).expect("local config should load without key");

        assert_eq!(config.provider.api_key, "exoshell-local-provider");
    }

    #[test]
    fn hosted_provider_urls_require_api_key_env() {
        let raw = raw_with_provider("https://api.openai.com/v1", "EXOSHELL_MISSING_HOSTED_KEY");

        let error = Config::from_raw(raw).expect_err("hosted config should require key");

        assert!(matches!(error, ConfigError::MissingApiKey(_)));
    }

    #[test]
    fn detects_local_provider_hosts() {
        for local in [
            "http://localhost:11434/v1",
            "http://127.0.0.1:11434/v1",
            "http://[::1]:11434/v1",
        ] {
            assert!(is_local_provider_url(local));
        }
        assert!(!is_local_provider_url("https://api.openai.com/v1"));
    }
}
|