mirror of
https://github.com/instructkr/claw-code.git
synced 2026-06-29 13:49:04 -04:00
Compare commits
12 Commits
8d4a739c05
...
rcc/memory
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
549deb9a89 | ||
|
|
d6341d54c1 | ||
|
|
863958b94c | ||
|
|
9455280f24 | ||
|
|
c92403994d | ||
|
|
e2f061fd08 | ||
|
|
c139fe9bee | ||
|
|
842abcfe85 | ||
|
|
807e29c8a1 | ||
|
|
32e89df631 | ||
|
|
1f8cfbce38 | ||
|
|
1e5002b521 |
1
.claude/sessions/session-1774998936453.json
Normal file
1
.claude/sessions/session-1774998936453.json
Normal file
@@ -0,0 +1 @@
|
||||
{"messages":[],"version":1}
|
||||
1
.claude/sessions/session-1774998994373.json
Normal file
1
.claude/sessions/session-1774998994373.json
Normal file
@@ -0,0 +1 @@
|
||||
{"messages":[{"blocks":[{"text":"Say hello in one sentence","type":"text"}],"role":"user"},{"blocks":[{"text":"Hello! I'm Claude, an AI assistant ready to help you with software engineering tasks, code analysis, debugging, or any other programming challenges you might have.","type":"text"}],"role":"assistant","usage":{"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"input_tokens":11,"output_tokens":32}}],"version":1}
|
||||
1
rust/Cargo.lock
generated
1
rust/Cargo.lock
generated
@@ -22,6 +22,7 @@ name = "api"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"reqwest",
|
||||
"runtime",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
|
||||
@@ -64,6 +64,26 @@ cd rust
|
||||
cargo run -p rusty-claude-cli -- --version
|
||||
```
|
||||
|
||||
### Login with OAuth
|
||||
|
||||
Configure `settings.json` with an `oauth` block containing `clientId`, `authorizeUrl`, `tokenUrl`, optional `callbackPort`, and optional `scopes`, then run:
|
||||
|
||||
```bash
|
||||
cd rust
|
||||
cargo run -p rusty-claude-cli -- login
|
||||
```
|
||||
|
||||
This opens the browser, listens on the configured localhost callback, exchanges the auth code for tokens, and stores OAuth credentials in `~/.claude/credentials.json` (or `$CLAUDE_CONFIG_HOME/credentials.json`).
|
||||
|
||||
### Logout
|
||||
|
||||
```bash
|
||||
cd rust
|
||||
cargo run -p rusty-claude-cli -- logout
|
||||
```
|
||||
|
||||
This removes only the stored OAuth credentials and preserves unrelated JSON fields in `credentials.json`.
|
||||
|
||||
## Usage examples
|
||||
|
||||
### 1) Prompt mode
|
||||
@@ -170,8 +190,9 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
|
||||
|
||||
### Anthropic/API
|
||||
|
||||
- `ANTHROPIC_AUTH_TOKEN` — preferred bearer token for API auth
|
||||
- `ANTHROPIC_API_KEY` — legacy API key fallback if auth token is unset
|
||||
- `ANTHROPIC_API_KEY` — highest-precedence API credential
|
||||
- `ANTHROPIC_AUTH_TOKEN` — bearer-token override used when no API key is set
|
||||
- Persisted OAuth credentials in `~/.claude/credentials.json` — used when neither env var is set
|
||||
- `ANTHROPIC_BASE_URL` — override the Anthropic API base URL
|
||||
- `ANTHROPIC_MODEL` — default model used by selected live integration tests
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||
runtime = { path = "../runtime" }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use runtime::{
|
||||
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
|
||||
OAuthTokenExchangeRequest,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::error::ApiError;
|
||||
@@ -81,11 +85,12 @@ impl AuthSource {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
|
||||
pub struct OAuthTokenSet {
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
pub expires_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub scopes: Vec<String>,
|
||||
}
|
||||
|
||||
@@ -131,7 +136,7 @@ impl AnthropicClient {
|
||||
}
|
||||
|
||||
pub fn from_env() -> Result<Self, ApiError> {
|
||||
Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url()))
|
||||
Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@@ -225,6 +230,46 @@ impl AnthropicClient {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn exchange_oauth_code(
|
||||
&self,
|
||||
config: &OAuthConfig,
|
||||
request: &OAuthTokenExchangeRequest,
|
||||
) -> Result<OAuthTokenSet, ApiError> {
|
||||
let response = self
|
||||
.http
|
||||
.post(&config.token_url)
|
||||
.header("content-type", "application/x-www-form-urlencoded")
|
||||
.form(&request.form_params())
|
||||
.send()
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
let response = expect_success(response).await?;
|
||||
response
|
||||
.json::<OAuthTokenSet>()
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
}
|
||||
|
||||
pub async fn refresh_oauth_token(
|
||||
&self,
|
||||
config: &OAuthConfig,
|
||||
request: &OAuthRefreshRequest,
|
||||
) -> Result<OAuthTokenSet, ApiError> {
|
||||
let response = self
|
||||
.http
|
||||
.post(&config.token_url)
|
||||
.header("content-type", "application/x-www-form-urlencoded")
|
||||
.form(&request.form_params())
|
||||
.send()
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
let response = expect_success(response).await?;
|
||||
response
|
||||
.json::<OAuthTokenSet>()
|
||||
.await
|
||||
.map_err(ApiError::from)
|
||||
}
|
||||
|
||||
async fn send_with_retry(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
@@ -304,6 +349,153 @@ impl AnthropicClient {
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthSource {
|
||||
pub fn from_env_or_saved() -> Result<Self, ApiError> {
|
||||
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
||||
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||
Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
|
||||
api_key,
|
||||
bearer_token,
|
||||
}),
|
||||
None => Ok(Self::ApiKey(api_key)),
|
||||
};
|
||||
}
|
||||
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||
return Ok(Self::BearerToken(bearer_token));
|
||||
}
|
||||
match load_saved_oauth_token() {
|
||||
Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
|
||||
if token_set.refresh_token.is_some() {
|
||||
Err(ApiError::Auth(
|
||||
"saved OAuth token is expired; load runtime OAuth config to refresh it"
|
||||
.to_string(),
|
||||
))
|
||||
} else {
|
||||
Err(ApiError::ExpiredOAuthToken)
|
||||
}
|
||||
}
|
||||
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
|
||||
Ok(None) => Err(ApiError::MissingApiKey),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
|
||||
token_set
|
||||
.expires_at
|
||||
.is_some_and(|expires_at| expires_at <= now_unix_timestamp())
|
||||
}
|
||||
|
||||
pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
|
||||
let Some(token_set) = load_saved_oauth_token()? else {
|
||||
return Ok(None);
|
||||
};
|
||||
resolve_saved_oauth_token_set(config, token_set).map(Some)
|
||||
}
|
||||
|
||||
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
|
||||
where
|
||||
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
|
||||
{
|
||||
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
||||
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||
Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
|
||||
api_key,
|
||||
bearer_token,
|
||||
}),
|
||||
None => Ok(AuthSource::ApiKey(api_key)),
|
||||
};
|
||||
}
|
||||
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||
return Ok(AuthSource::BearerToken(bearer_token));
|
||||
}
|
||||
|
||||
let Some(token_set) = load_saved_oauth_token()? else {
|
||||
return Err(ApiError::MissingApiKey);
|
||||
};
|
||||
if !oauth_token_is_expired(&token_set) {
|
||||
return Ok(AuthSource::BearerToken(token_set.access_token));
|
||||
}
|
||||
if token_set.refresh_token.is_none() {
|
||||
return Err(ApiError::ExpiredOAuthToken);
|
||||
}
|
||||
|
||||
let Some(config) = load_oauth_config()? else {
|
||||
return Err(ApiError::Auth(
|
||||
"saved OAuth token is expired; runtime OAuth config is missing".to_string(),
|
||||
));
|
||||
};
|
||||
Ok(AuthSource::from(resolve_saved_oauth_token_set(
|
||||
&config, token_set,
|
||||
)?))
|
||||
}
|
||||
|
||||
fn resolve_saved_oauth_token_set(
|
||||
config: &OAuthConfig,
|
||||
token_set: OAuthTokenSet,
|
||||
) -> Result<OAuthTokenSet, ApiError> {
|
||||
if !oauth_token_is_expired(&token_set) {
|
||||
return Ok(token_set);
|
||||
}
|
||||
let Some(refresh_token) = token_set.refresh_token.clone() else {
|
||||
return Err(ApiError::ExpiredOAuthToken);
|
||||
};
|
||||
let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
|
||||
let refreshed = client_runtime_block_on(async {
|
||||
client
|
||||
.refresh_oauth_token(
|
||||
config,
|
||||
&OAuthRefreshRequest::from_config(
|
||||
config,
|
||||
refresh_token,
|
||||
Some(token_set.scopes.clone()),
|
||||
),
|
||||
)
|
||||
.await
|
||||
})?;
|
||||
let resolved = OAuthTokenSet {
|
||||
access_token: refreshed.access_token,
|
||||
refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
|
||||
expires_at: refreshed.expires_at,
|
||||
scopes: refreshed.scopes,
|
||||
};
|
||||
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||
access_token: resolved.access_token.clone(),
|
||||
refresh_token: resolved.refresh_token.clone(),
|
||||
expires_at: resolved.expires_at,
|
||||
scopes: resolved.scopes.clone(),
|
||||
})
|
||||
.map_err(ApiError::from)?;
|
||||
Ok(resolved)
|
||||
}
|
||||
|
||||
fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
|
||||
where
|
||||
F: std::future::Future<Output = Result<T, ApiError>>,
|
||||
{
|
||||
tokio::runtime::Runtime::new()
|
||||
.map_err(ApiError::from)?
|
||||
.block_on(future)
|
||||
}
|
||||
|
||||
fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
|
||||
let token_set = load_oauth_credentials().map_err(ApiError::from)?;
|
||||
Ok(token_set.map(|token_set| OAuthTokenSet {
|
||||
access_token: token_set.access_token,
|
||||
refresh_token: token_set.refresh_token,
|
||||
expires_at: token_set.expires_at,
|
||||
scopes: token_set.scopes,
|
||||
}))
|
||||
}
|
||||
|
||||
fn now_unix_timestamp() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map_or(0, |duration| duration.as_secs())
|
||||
}
|
||||
|
||||
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
||||
match std::env::var(key) {
|
||||
Ok(value) if !value.is_empty() => Ok(Some(value)),
|
||||
@@ -314,7 +506,7 @@ fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
||||
|
||||
#[cfg(test)]
|
||||
fn read_api_key() -> Result<String, ApiError> {
|
||||
let auth = AuthSource::from_env()?;
|
||||
let auth = AuthSource::from_env_or_saved()?;
|
||||
auth.api_key()
|
||||
.or_else(|| auth.bearer_token())
|
||||
.map(ToOwned::to_owned)
|
||||
@@ -424,10 +616,18 @@ struct AnthropicErrorBody {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpListener;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::Duration;
|
||||
use std::thread;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::client::{AuthSource, OAuthTokenSet};
|
||||
use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
|
||||
|
||||
use crate::client::{
|
||||
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
|
||||
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
|
||||
};
|
||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
@@ -437,11 +637,53 @@ mod tests {
|
||||
.expect("env lock")
|
||||
}
|
||||
|
||||
fn temp_config_home() -> std::path::PathBuf {
|
||||
std::env::temp_dir().join(format!(
|
||||
"api-oauth-test-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
))
|
||||
}
|
||||
|
||||
fn sample_oauth_config(token_url: String) -> OAuthConfig {
|
||||
OAuthConfig {
|
||||
client_id: "runtime-client".to_string(),
|
||||
authorize_url: "https://console.test/oauth/authorize".to_string(),
|
||||
token_url,
|
||||
callback_port: Some(4545),
|
||||
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
|
||||
scopes: vec!["org:read".to_string(), "user:write".to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_token_server(response_body: &'static str) -> String {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
|
||||
let address = listener.local_addr().expect("local addr");
|
||||
thread::spawn(move || {
|
||||
let (mut stream, _) = listener.accept().expect("accept connection");
|
||||
let mut buffer = [0_u8; 4096];
|
||||
let _ = stream.read(&mut buffer).expect("read request");
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
|
||||
response_body.len(),
|
||||
response_body
|
||||
);
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.expect("write response");
|
||||
});
|
||||
format!("http://{address}/oauth/token")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_api_key_requires_presence() {
|
||||
let _guard = env_lock();
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
let error = super::read_api_key().expect_err("missing key should error");
|
||||
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
||||
}
|
||||
@@ -453,6 +695,7 @@ mod tests {
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
let error = super::read_api_key().expect_err("empty key should error");
|
||||
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -500,6 +743,166 @@ mod tests {
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_source_from_saved_oauth_when_env_absent() {
|
||||
let _guard = env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||
access_token: "saved-access-token".to_string(),
|
||||
refresh_token: Some("refresh".to_string()),
|
||||
expires_at: Some(now_unix_timestamp() + 300),
|
||||
scopes: vec!["scope:a".to_string()],
|
||||
})
|
||||
.expect("save oauth credentials");
|
||||
|
||||
let auth = AuthSource::from_env_or_saved().expect("saved auth");
|
||||
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
||||
|
||||
clear_oauth_credentials().expect("clear credentials");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_token_expiry_uses_expires_at_timestamp() {
|
||||
assert!(oauth_token_is_expired(&OAuthTokenSet {
|
||||
access_token: "access-token".to_string(),
|
||||
refresh_token: None,
|
||||
expires_at: Some(1),
|
||||
scopes: Vec::new(),
|
||||
}));
|
||||
assert!(!oauth_token_is_expired(&OAuthTokenSet {
|
||||
access_token: "access-token".to_string(),
|
||||
refresh_token: None,
|
||||
expires_at: Some(now_unix_timestamp() + 60),
|
||||
scopes: Vec::new(),
|
||||
}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_saved_oauth_token_refreshes_expired_credentials() {
|
||||
let _guard = env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||
access_token: "expired-access-token".to_string(),
|
||||
refresh_token: Some("refresh-token".to_string()),
|
||||
expires_at: Some(1),
|
||||
scopes: vec!["scope:a".to_string()],
|
||||
})
|
||||
.expect("save expired oauth credentials");
|
||||
|
||||
let token_url = spawn_token_server(
|
||||
"{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
||||
);
|
||||
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
||||
.expect("resolve refreshed token")
|
||||
.expect("token set present");
|
||||
assert_eq!(resolved.access_token, "refreshed-token");
|
||||
let stored = runtime::load_oauth_credentials()
|
||||
.expect("load stored credentials")
|
||||
.expect("stored token set");
|
||||
assert_eq!(stored.access_token, "refreshed-token");
|
||||
|
||||
clear_oauth_credentials().expect("clear credentials");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
|
||||
let _guard = env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||
access_token: "saved-access-token".to_string(),
|
||||
refresh_token: Some("refresh".to_string()),
|
||||
expires_at: Some(now_unix_timestamp() + 300),
|
||||
scopes: vec!["scope:a".to_string()],
|
||||
})
|
||||
.expect("save oauth credentials");
|
||||
|
||||
let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
|
||||
.expect("startup auth");
|
||||
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
||||
|
||||
clear_oauth_credentials().expect("clear credentials");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
|
||||
let _guard = env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||
access_token: "expired-access-token".to_string(),
|
||||
refresh_token: Some("refresh-token".to_string()),
|
||||
expires_at: Some(1),
|
||||
scopes: vec!["scope:a".to_string()],
|
||||
})
|
||||
.expect("save expired oauth credentials");
|
||||
|
||||
let error =
|
||||
resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
|
||||
assert!(
|
||||
matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
|
||||
);
|
||||
|
||||
let stored = runtime::load_oauth_credentials()
|
||||
.expect("load stored credentials")
|
||||
.expect("stored token set");
|
||||
assert_eq!(stored.access_token, "expired-access-token");
|
||||
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||
|
||||
clear_oauth_credentials().expect("clear credentials");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
|
||||
let _guard = env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||
access_token: "expired-access-token".to_string(),
|
||||
refresh_token: Some("refresh-token".to_string()),
|
||||
expires_at: Some(1),
|
||||
scopes: vec!["scope:a".to_string()],
|
||||
})
|
||||
.expect("save expired oauth credentials");
|
||||
|
||||
let token_url = spawn_token_server(
|
||||
"{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
||||
);
|
||||
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
||||
.expect("resolve refreshed token")
|
||||
.expect("token set present");
|
||||
assert_eq!(resolved.access_token, "refreshed-token");
|
||||
assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
|
||||
let stored = runtime::load_oauth_credentials()
|
||||
.expect("load stored credentials")
|
||||
.expect("stored token set");
|
||||
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||
|
||||
clear_oauth_credentials().expect("clear credentials");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_request_stream_helper_sets_stream_true() {
|
||||
let request = MessageRequest {
|
||||
@@ -517,7 +920,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn backoff_doubles_until_maximum() {
|
||||
let client = super::AnthropicClient::new("test-key").with_retry_policy(
|
||||
let client = AnthropicClient::new("test-key").with_retry_policy(
|
||||
3,
|
||||
Duration::from_millis(10),
|
||||
Duration::from_millis(25),
|
||||
|
||||
@@ -5,6 +5,8 @@ use std::time::Duration;
|
||||
#[derive(Debug)]
|
||||
pub enum ApiError {
|
||||
MissingApiKey,
|
||||
ExpiredOAuthToken,
|
||||
Auth(String),
|
||||
InvalidApiKeyEnv(VarError),
|
||||
Http(reqwest::Error),
|
||||
Io(std::io::Error),
|
||||
@@ -35,6 +37,8 @@ impl ApiError {
|
||||
Self::Api { retryable, .. } => *retryable,
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
||||
Self::MissingApiKey
|
||||
| Self::ExpiredOAuthToken
|
||||
| Self::Auth(_)
|
||||
| Self::InvalidApiKeyEnv(_)
|
||||
| Self::Io(_)
|
||||
| Self::Json(_)
|
||||
@@ -53,6 +57,13 @@ impl Display for ApiError {
|
||||
"ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API"
|
||||
)
|
||||
}
|
||||
Self::ExpiredOAuthToken => {
|
||||
write!(
|
||||
f,
|
||||
"saved OAuth token is expired and no refresh token is available"
|
||||
)
|
||||
}
|
||||
Self::Auth(message) => write!(f, "auth error: {message}"),
|
||||
Self::InvalidApiKeyEnv(error) => {
|
||||
write!(
|
||||
f,
|
||||
|
||||
@@ -3,7 +3,10 @@ mod error;
|
||||
mod sse;
|
||||
mod types;
|
||||
|
||||
pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet};
|
||||
pub use client::{
|
||||
oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source,
|
||||
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
|
||||
};
|
||||
pub use error::ApiError;
|
||||
pub use sse::{parse_frame, SseParser};
|
||||
pub use types::{
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
use std::fs;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -90,6 +93,7 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
||||
let preserved = session.messages[keep_from..].to_vec();
|
||||
let summary = summarize_messages(removed);
|
||||
let formatted_summary = format_compact_summary(&summary);
|
||||
persist_compact_summary(&formatted_summary);
|
||||
let continuation = get_compact_continuation_message(&summary, true, !preserved.is_empty());
|
||||
|
||||
let mut compacted_messages = vec![ConversationMessage {
|
||||
@@ -110,6 +114,35 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
||||
}
|
||||
}
|
||||
|
||||
fn persist_compact_summary(formatted_summary: &str) {
|
||||
if formatted_summary.trim().is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(cwd) = std::env::current_dir() else {
|
||||
return;
|
||||
};
|
||||
let memory_dir = cwd.join(".claude").join("memory");
|
||||
if fs::create_dir_all(&memory_dir).is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let path = memory_dir.join(compact_summary_filename());
|
||||
let _ = fs::write(path, render_memory_file(formatted_summary));
|
||||
}
|
||||
|
||||
fn compact_summary_filename() -> String {
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
format!("summary-{timestamp}.md")
|
||||
}
|
||||
|
||||
fn render_memory_file(formatted_summary: &str) -> String {
|
||||
format!("# Project memory\n\n{}\n", formatted_summary.trim())
|
||||
}
|
||||
|
||||
fn summarize_messages(messages: &[ConversationMessage]) -> String {
|
||||
let user_messages = messages
|
||||
.iter()
|
||||
@@ -378,14 +411,21 @@ fn collapse_blank_lines(content: &str) -> String {
|
||||
mod tests {
|
||||
use super::{
|
||||
collect_key_files, compact_session, estimate_session_tokens, format_compact_summary,
|
||||
infer_pending_work, should_compact, CompactionConfig,
|
||||
infer_pending_work, render_memory_file, should_compact, CompactionConfig,
|
||||
};
|
||||
use std::fs;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
|
||||
#[test]
|
||||
fn formats_compact_summary_like_upstream() {
|
||||
let summary = "<analysis>scratch</analysis>\n<summary>Kept work</summary>";
|
||||
assert_eq!(format_compact_summary(summary), "Summary:\nKept work");
|
||||
assert_eq!(
|
||||
render_memory_file("Summary:\nKept work"),
|
||||
"# Project memory\n\nSummary:\nKept work\n"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -402,6 +442,63 @@ mod tests {
|
||||
assert!(result.formatted_summary.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn persists_compacted_summaries_under_dot_claude_memory() {
|
||||
let _guard = crate::test_env_lock();
|
||||
let temp = std::env::temp_dir().join(format!(
|
||||
"runtime-compact-memory-{}",
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time after epoch")
|
||||
.as_nanos()
|
||||
));
|
||||
fs::create_dir_all(&temp).expect("temp dir");
|
||||
let previous = std::env::current_dir().expect("cwd");
|
||||
std::env::set_current_dir(&temp).expect("set cwd");
|
||||
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![
|
||||
ConversationMessage::user_text("one ".repeat(200)),
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "two ".repeat(200),
|
||||
}]),
|
||||
ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false),
|
||||
ConversationMessage {
|
||||
role: MessageRole::Assistant,
|
||||
blocks: vec![ContentBlock::Text {
|
||||
text: "recent".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let result = compact_session(
|
||||
&session,
|
||||
CompactionConfig {
|
||||
preserve_recent_messages: 2,
|
||||
max_estimated_tokens: 1,
|
||||
},
|
||||
);
|
||||
let memory_dir = temp.join(".claude").join("memory");
|
||||
let files = fs::read_dir(&memory_dir)
|
||||
.expect("memory dir exists")
|
||||
.flatten()
|
||||
.map(|entry| entry.path())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(result.removed_message_count, 2);
|
||||
assert_eq!(files.len(), 1);
|
||||
let persisted = fs::read_to_string(&files[0]).expect("memory file readable");
|
||||
|
||||
std::env::set_current_dir(previous).expect("restore cwd");
|
||||
fs::remove_dir_all(temp).expect("cleanup temp dir");
|
||||
|
||||
assert!(persisted.contains("# Project memory"));
|
||||
assert!(persisted.contains("Summary:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compacts_older_messages_into_a_system_summary() {
|
||||
let session = Session {
|
||||
|
||||
@@ -14,6 +14,13 @@ pub enum ConfigSource {
|
||||
Local,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ResolvedPermissionMode {
|
||||
ReadOnly,
|
||||
WorkspaceWrite,
|
||||
DangerFullAccess,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ConfigEntry {
|
||||
pub source: ConfigSource,
|
||||
@@ -31,6 +38,8 @@ pub struct RuntimeConfig {
|
||||
pub struct RuntimeFeatureConfig {
|
||||
mcp: McpConfigCollection,
|
||||
oauth: Option<OAuthConfig>,
|
||||
model: Option<String>,
|
||||
permission_mode: Option<ResolvedPermissionMode>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
@@ -165,11 +174,23 @@ impl ConfigLoader {
|
||||
|
||||
#[must_use]
|
||||
pub fn discover(&self) -> Vec<ConfigEntry> {
|
||||
let user_legacy_path = self.config_home.parent().map_or_else(
|
||||
|| PathBuf::from(".claude.json"),
|
||||
|parent| parent.join(".claude.json"),
|
||||
);
|
||||
vec![
|
||||
ConfigEntry {
|
||||
source: ConfigSource::User,
|
||||
path: user_legacy_path,
|
||||
},
|
||||
ConfigEntry {
|
||||
source: ConfigSource::User,
|
||||
path: self.config_home.join("settings.json"),
|
||||
},
|
||||
ConfigEntry {
|
||||
source: ConfigSource::Project,
|
||||
path: self.cwd.join(".claude.json"),
|
||||
},
|
||||
ConfigEntry {
|
||||
source: ConfigSource::Project,
|
||||
path: self.cwd.join(".claude").join("settings.json"),
|
||||
@@ -195,14 +216,15 @@ impl ConfigLoader {
|
||||
loaded_entries.push(entry);
|
||||
}
|
||||
|
||||
let merged_value = JsonValue::Object(merged.clone());
|
||||
|
||||
let feature_config = RuntimeFeatureConfig {
|
||||
mcp: McpConfigCollection {
|
||||
servers: mcp_servers,
|
||||
},
|
||||
oauth: parse_optional_oauth_config(
|
||||
&JsonValue::Object(merged.clone()),
|
||||
"merged settings.oauth",
|
||||
)?,
|
||||
oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?,
|
||||
model: parse_optional_model(&merged_value),
|
||||
permission_mode: parse_optional_permission_mode(&merged_value)?,
|
||||
};
|
||||
|
||||
Ok(RuntimeConfig {
|
||||
@@ -257,6 +279,16 @@ impl RuntimeConfig {
|
||||
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||
self.feature_config.oauth.as_ref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn model(&self) -> Option<&str> {
|
||||
self.feature_config.model.as_deref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
|
||||
self.feature_config.permission_mode
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeFeatureConfig {
|
||||
@@ -269,6 +301,16 @@ impl RuntimeFeatureConfig {
|
||||
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||
self.oauth.as_ref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn model(&self) -> Option<&str> {
|
||||
self.model.as_deref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
|
||||
self.permission_mode
|
||||
}
|
||||
}
|
||||
|
||||
impl McpConfigCollection {
|
||||
@@ -307,6 +349,7 @@ impl McpServerConfig {
|
||||
fn read_optional_json_object(
|
||||
path: &Path,
|
||||
) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
|
||||
let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claude.json");
|
||||
let contents = match fs::read_to_string(path) {
|
||||
Ok(contents) => contents,
|
||||
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
||||
@@ -317,14 +360,20 @@ fn read_optional_json_object(
|
||||
return Ok(Some(BTreeMap::new()));
|
||||
}
|
||||
|
||||
let parsed = JsonValue::parse(&contents)
|
||||
.map_err(|error| ConfigError::Parse(format!("{}: {error}", path.display())))?;
|
||||
let object = parsed.as_object().ok_or_else(|| {
|
||||
ConfigError::Parse(format!(
|
||||
let parsed = match JsonValue::parse(&contents) {
|
||||
Ok(parsed) => parsed,
|
||||
Err(error) if is_legacy_config => return Ok(None),
|
||||
Err(error) => return Err(ConfigError::Parse(format!("{}: {error}", path.display()))),
|
||||
};
|
||||
let Some(object) = parsed.as_object() else {
|
||||
if is_legacy_config {
|
||||
return Ok(None);
|
||||
}
|
||||
return Err(ConfigError::Parse(format!(
|
||||
"{}: top-level settings value must be a JSON object",
|
||||
path.display()
|
||||
))
|
||||
})?;
|
||||
)));
|
||||
};
|
||||
Ok(Some(object.clone()))
|
||||
}
|
||||
|
||||
@@ -355,6 +404,47 @@ fn merge_mcp_servers(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_optional_model(root: &JsonValue) -> Option<String> {
|
||||
root.as_object()
|
||||
.and_then(|object| object.get("model"))
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
fn parse_optional_permission_mode(
|
||||
root: &JsonValue,
|
||||
) -> Result<Option<ResolvedPermissionMode>, ConfigError> {
|
||||
let Some(object) = root.as_object() else {
|
||||
return Ok(None);
|
||||
};
|
||||
if let Some(mode) = object.get("permissionMode").and_then(JsonValue::as_str) {
|
||||
return parse_permission_mode_label(mode, "merged settings.permissionMode").map(Some);
|
||||
}
|
||||
let Some(mode) = object
|
||||
.get("permissions")
|
||||
.and_then(JsonValue::as_object)
|
||||
.and_then(|permissions| permissions.get("defaultMode"))
|
||||
.and_then(JsonValue::as_str)
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
parse_permission_mode_label(mode, "merged settings.permissions.defaultMode").map(Some)
|
||||
}
|
||||
|
||||
fn parse_permission_mode_label(
|
||||
mode: &str,
|
||||
context: &str,
|
||||
) -> Result<ResolvedPermissionMode, ConfigError> {
|
||||
match mode {
|
||||
"default" | "plan" | "read-only" => Ok(ResolvedPermissionMode::ReadOnly),
|
||||
"acceptEdits" | "auto" | "workspace-write" => Ok(ResolvedPermissionMode::WorkspaceWrite),
|
||||
"dontAsk" | "danger-full-access" => Ok(ResolvedPermissionMode::DangerFullAccess),
|
||||
other => Err(ConfigError::Parse(format!(
|
||||
"{context}: unsupported permission mode {other}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_optional_oauth_config(
|
||||
root: &JsonValue,
|
||||
context: &str,
|
||||
@@ -594,7 +684,8 @@ fn deep_merge_objects(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, ResolvedPermissionMode,
|
||||
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||
};
|
||||
use crate::json::JsonValue;
|
||||
use std::fs;
|
||||
@@ -635,14 +726,24 @@ mod tests {
|
||||
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
||||
fs::create_dir_all(&home).expect("home config dir");
|
||||
|
||||
fs::write(
|
||||
home.parent().expect("home parent").join(".claude.json"),
|
||||
r#"{"model":"haiku","env":{"A":"1"},"mcpServers":{"home":{"command":"uvx","args":["home"]}}}"#,
|
||||
)
|
||||
.expect("write user compat config");
|
||||
fs::write(
|
||||
home.join("settings.json"),
|
||||
r#"{"model":"sonnet","env":{"A":"1"},"hooks":{"PreToolUse":["base"]}}"#,
|
||||
r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan"}}"#,
|
||||
)
|
||||
.expect("write user settings");
|
||||
fs::write(
|
||||
cwd.join(".claude.json"),
|
||||
r#"{"model":"project-compat","env":{"B":"2"}}"#,
|
||||
)
|
||||
.expect("write project compat config");
|
||||
fs::write(
|
||||
cwd.join(".claude").join("settings.json"),
|
||||
r#"{"env":{"B":"2"},"hooks":{"PostToolUse":["project"]}}"#,
|
||||
r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#,
|
||||
)
|
||||
.expect("write project settings");
|
||||
fs::write(
|
||||
@@ -656,25 +757,37 @@ mod tests {
|
||||
.expect("config should load");
|
||||
|
||||
assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema");
|
||||
assert_eq!(loaded.loaded_entries().len(), 3);
|
||||
assert_eq!(loaded.loaded_entries().len(), 5);
|
||||
assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User);
|
||||
assert_eq!(
|
||||
loaded.get("model"),
|
||||
Some(&JsonValue::String("opus".to_string()))
|
||||
);
|
||||
assert_eq!(loaded.model(), Some("opus"));
|
||||
assert_eq!(
|
||||
loaded.permission_mode(),
|
||||
Some(ResolvedPermissionMode::WorkspaceWrite)
|
||||
);
|
||||
assert_eq!(
|
||||
loaded
|
||||
.get("env")
|
||||
.and_then(JsonValue::as_object)
|
||||
.expect("env object")
|
||||
.len(),
|
||||
2
|
||||
4
|
||||
);
|
||||
assert!(loaded
|
||||
.get("hooks")
|
||||
.and_then(JsonValue::as_object)
|
||||
.expect("hooks object")
|
||||
.contains_key("PreToolUse"));
|
||||
assert!(loaded
|
||||
.get("hooks")
|
||||
.and_then(JsonValue::as_object)
|
||||
.expect("hooks object")
|
||||
.contains_key("PostToolUse"));
|
||||
assert!(loaded.mcp().get("home").is_some());
|
||||
assert!(loaded.mcp().get("project").is_some());
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
@@ -408,13 +408,14 @@ mod tests {
|
||||
.sum::<i32>();
|
||||
Ok(total.to_string())
|
||||
});
|
||||
let permission_policy = PermissionPolicy::new(PermissionMode::Prompt);
|
||||
let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
|
||||
let system_prompt = SystemPromptBuilder::new()
|
||||
.with_project_context(ProjectContext {
|
||||
cwd: PathBuf::from("/tmp/project"),
|
||||
current_date: "2026-03-31".to_string(),
|
||||
git_status: None,
|
||||
instruction_files: Vec::new(),
|
||||
memory_files: Vec::new(),
|
||||
})
|
||||
.with_os("linux", "6.8")
|
||||
.build();
|
||||
@@ -487,7 +488,7 @@ mod tests {
|
||||
Session::new(),
|
||||
SingleCallApiClient,
|
||||
StaticToolExecutor::new(),
|
||||
PermissionPolicy::new(PermissionMode::Prompt),
|
||||
PermissionPolicy::new(PermissionMode::WorkspaceWrite),
|
||||
vec!["system".to_string()],
|
||||
);
|
||||
|
||||
@@ -536,7 +537,7 @@ mod tests {
|
||||
session,
|
||||
SimpleApi,
|
||||
StaticToolExecutor::new(),
|
||||
PermissionPolicy::new(PermissionMode::Allow),
|
||||
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||
vec!["system".to_string()],
|
||||
);
|
||||
|
||||
@@ -563,7 +564,7 @@ mod tests {
|
||||
Session::new(),
|
||||
SimpleApi,
|
||||
StaticToolExecutor::new(),
|
||||
PermissionPolicy::new(PermissionMode::Allow),
|
||||
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||
vec!["system".to_string()],
|
||||
);
|
||||
runtime.run_turn("a", None).expect("turn a");
|
||||
|
||||
@@ -25,7 +25,8 @@ pub use config::{
|
||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig,
|
||||
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
||||
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
||||
RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig,
|
||||
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||
};
|
||||
pub use conversation::{
|
||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
||||
@@ -46,14 +47,17 @@ pub use mcp_client::{
|
||||
};
|
||||
pub use mcp_stdio::{
|
||||
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
||||
McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
|
||||
McpListResourcesParams, McpListResourcesResult, McpListToolsParams, McpListToolsResult,
|
||||
McpReadResourceParams, McpReadResourceResult, McpResource, McpResourceContents,
|
||||
McpStdioProcess, McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult,
|
||||
ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult,
|
||||
McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
|
||||
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource,
|
||||
McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool,
|
||||
McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer,
|
||||
};
|
||||
pub use oauth::{
|
||||
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||
OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||
OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||
PkceChallengeMethod, PkceCodePair,
|
||||
};
|
||||
pub use permissions::{
|
||||
@@ -73,3 +77,11 @@ pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, Sessi
|
||||
pub use usage::{
|
||||
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
|
||||
LOCK.get_or_init(|| std::sync::Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ use serde_json::Value as JsonValue;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||
|
||||
use crate::config::{McpTransport, RuntimeConfig, ScopedMcpServerConfig};
|
||||
use crate::mcp::mcp_tool_name;
|
||||
use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
@@ -200,6 +202,374 @@ pub struct McpReadResourceResult {
|
||||
pub contents: Vec<McpResourceContents>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ManagedMcpTool {
|
||||
pub server_name: String,
|
||||
pub qualified_name: String,
|
||||
pub raw_name: String,
|
||||
pub tool: McpTool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct UnsupportedMcpServer {
|
||||
pub server_name: String,
|
||||
pub transport: McpTransport,
|
||||
pub reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum McpServerManagerError {
|
||||
Io(io::Error),
|
||||
JsonRpc {
|
||||
server_name: String,
|
||||
method: &'static str,
|
||||
error: JsonRpcError,
|
||||
},
|
||||
InvalidResponse {
|
||||
server_name: String,
|
||||
method: &'static str,
|
||||
details: String,
|
||||
},
|
||||
UnknownTool {
|
||||
qualified_name: String,
|
||||
},
|
||||
UnknownServer {
|
||||
server_name: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl std::fmt::Display for McpServerManagerError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Io(error) => write!(f, "{error}"),
|
||||
Self::JsonRpc {
|
||||
server_name,
|
||||
method,
|
||||
error,
|
||||
} => write!(
|
||||
f,
|
||||
"MCP server `{server_name}` returned JSON-RPC error for {method}: {} ({})",
|
||||
error.message, error.code
|
||||
),
|
||||
Self::InvalidResponse {
|
||||
server_name,
|
||||
method,
|
||||
details,
|
||||
} => write!(
|
||||
f,
|
||||
"MCP server `{server_name}` returned invalid response for {method}: {details}"
|
||||
),
|
||||
Self::UnknownTool { qualified_name } => {
|
||||
write!(f, "unknown MCP tool `{qualified_name}`")
|
||||
}
|
||||
Self::UnknownServer { server_name } => write!(f, "unknown MCP server `{server_name}`"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for McpServerManagerError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Io(error) => Some(error),
|
||||
Self::JsonRpc { .. }
|
||||
| Self::InvalidResponse { .. }
|
||||
| Self::UnknownTool { .. }
|
||||
| Self::UnknownServer { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for McpServerManagerError {
|
||||
fn from(value: io::Error) -> Self {
|
||||
Self::Io(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct ToolRoute {
|
||||
server_name: String,
|
||||
raw_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ManagedMcpServer {
|
||||
bootstrap: McpClientBootstrap,
|
||||
process: Option<McpStdioProcess>,
|
||||
initialized: bool,
|
||||
}
|
||||
|
||||
impl ManagedMcpServer {
|
||||
fn new(bootstrap: McpClientBootstrap) -> Self {
|
||||
Self {
|
||||
bootstrap,
|
||||
process: None,
|
||||
initialized: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct McpServerManager {
|
||||
servers: BTreeMap<String, ManagedMcpServer>,
|
||||
unsupported_servers: Vec<UnsupportedMcpServer>,
|
||||
tool_index: BTreeMap<String, ToolRoute>,
|
||||
next_request_id: u64,
|
||||
}
|
||||
|
||||
impl McpServerManager {
|
||||
#[must_use]
|
||||
pub fn from_runtime_config(config: &RuntimeConfig) -> Self {
|
||||
Self::from_servers(config.mcp().servers())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn from_servers(servers: &BTreeMap<String, ScopedMcpServerConfig>) -> Self {
|
||||
let mut managed_servers = BTreeMap::new();
|
||||
let mut unsupported_servers = Vec::new();
|
||||
|
||||
for (server_name, server_config) in servers {
|
||||
if server_config.transport() == McpTransport::Stdio {
|
||||
let bootstrap = McpClientBootstrap::from_scoped_config(server_name, server_config);
|
||||
managed_servers.insert(server_name.clone(), ManagedMcpServer::new(bootstrap));
|
||||
} else {
|
||||
unsupported_servers.push(UnsupportedMcpServer {
|
||||
server_name: server_name.clone(),
|
||||
transport: server_config.transport(),
|
||||
reason: format!(
|
||||
"transport {:?} is not supported by McpServerManager",
|
||||
server_config.transport()
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
servers: managed_servers,
|
||||
unsupported_servers,
|
||||
tool_index: BTreeMap::new(),
|
||||
next_request_id: 1,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn unsupported_servers(&self) -> &[UnsupportedMcpServer] {
|
||||
&self.unsupported_servers
|
||||
}
|
||||
|
||||
pub async fn discover_tools(&mut self) -> Result<Vec<ManagedMcpTool>, McpServerManagerError> {
|
||||
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
||||
let mut discovered_tools = Vec::new();
|
||||
|
||||
for server_name in server_names {
|
||||
self.ensure_server_ready(&server_name).await?;
|
||||
self.clear_routes_for_server(&server_name);
|
||||
|
||||
let mut cursor = None;
|
||||
loop {
|
||||
let request_id = self.take_request_id();
|
||||
let response = {
|
||||
let server = self.server_mut(&server_name)?;
|
||||
let process = server.process.as_mut().ok_or_else(|| {
|
||||
McpServerManagerError::InvalidResponse {
|
||||
server_name: server_name.clone(),
|
||||
method: "tools/list",
|
||||
details: "server process missing after initialization".to_string(),
|
||||
}
|
||||
})?;
|
||||
process
|
||||
.list_tools(
|
||||
request_id,
|
||||
Some(McpListToolsParams {
|
||||
cursor: cursor.clone(),
|
||||
}),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
if let Some(error) = response.error {
|
||||
return Err(McpServerManagerError::JsonRpc {
|
||||
server_name: server_name.clone(),
|
||||
method: "tools/list",
|
||||
error,
|
||||
});
|
||||
}
|
||||
|
||||
let result =
|
||||
response
|
||||
.result
|
||||
.ok_or_else(|| McpServerManagerError::InvalidResponse {
|
||||
server_name: server_name.clone(),
|
||||
method: "tools/list",
|
||||
details: "missing result payload".to_string(),
|
||||
})?;
|
||||
|
||||
for tool in result.tools {
|
||||
let qualified_name = mcp_tool_name(&server_name, &tool.name);
|
||||
self.tool_index.insert(
|
||||
qualified_name.clone(),
|
||||
ToolRoute {
|
||||
server_name: server_name.clone(),
|
||||
raw_name: tool.name.clone(),
|
||||
},
|
||||
);
|
||||
discovered_tools.push(ManagedMcpTool {
|
||||
server_name: server_name.clone(),
|
||||
qualified_name,
|
||||
raw_name: tool.name.clone(),
|
||||
tool,
|
||||
});
|
||||
}
|
||||
|
||||
match result.next_cursor {
|
||||
Some(next_cursor) => cursor = Some(next_cursor),
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(discovered_tools)
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
&mut self,
|
||||
qualified_tool_name: &str,
|
||||
arguments: Option<JsonValue>,
|
||||
) -> Result<JsonRpcResponse<McpToolCallResult>, McpServerManagerError> {
|
||||
let route = self
|
||||
.tool_index
|
||||
.get(qualified_tool_name)
|
||||
.cloned()
|
||||
.ok_or_else(|| McpServerManagerError::UnknownTool {
|
||||
qualified_name: qualified_tool_name.to_string(),
|
||||
})?;
|
||||
|
||||
self.ensure_server_ready(&route.server_name).await?;
|
||||
let request_id = self.take_request_id();
|
||||
let response =
|
||||
{
|
||||
let server = self.server_mut(&route.server_name)?;
|
||||
let process = server.process.as_mut().ok_or_else(|| {
|
||||
McpServerManagerError::InvalidResponse {
|
||||
server_name: route.server_name.clone(),
|
||||
method: "tools/call",
|
||||
details: "server process missing after initialization".to_string(),
|
||||
}
|
||||
})?;
|
||||
process
|
||||
.call_tool(
|
||||
request_id,
|
||||
McpToolCallParams {
|
||||
name: route.raw_name,
|
||||
arguments,
|
||||
meta: None,
|
||||
},
|
||||
)
|
||||
.await?
|
||||
};
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> {
|
||||
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
||||
for server_name in server_names {
|
||||
let server = self.server_mut(&server_name)?;
|
||||
if let Some(process) = server.process.as_mut() {
|
||||
process.shutdown().await?;
|
||||
}
|
||||
server.process = None;
|
||||
server.initialized = false;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear_routes_for_server(&mut self, server_name: &str) {
|
||||
self.tool_index
|
||||
.retain(|_, route| route.server_name != server_name);
|
||||
}
|
||||
|
||||
fn server_mut(
|
||||
&mut self,
|
||||
server_name: &str,
|
||||
) -> Result<&mut ManagedMcpServer, McpServerManagerError> {
|
||||
self.servers
|
||||
.get_mut(server_name)
|
||||
.ok_or_else(|| McpServerManagerError::UnknownServer {
|
||||
server_name: server_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn take_request_id(&mut self) -> JsonRpcId {
|
||||
let id = self.next_request_id;
|
||||
self.next_request_id = self.next_request_id.saturating_add(1);
|
||||
JsonRpcId::Number(id)
|
||||
}
|
||||
|
||||
async fn ensure_server_ready(
|
||||
&mut self,
|
||||
server_name: &str,
|
||||
) -> Result<(), McpServerManagerError> {
|
||||
let needs_spawn = self
|
||||
.servers
|
||||
.get(server_name)
|
||||
.map(|server| server.process.is_none())
|
||||
.ok_or_else(|| McpServerManagerError::UnknownServer {
|
||||
server_name: server_name.to_string(),
|
||||
})?;
|
||||
|
||||
if needs_spawn {
|
||||
let server = self.server_mut(server_name)?;
|
||||
server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?);
|
||||
server.initialized = false;
|
||||
}
|
||||
|
||||
let needs_initialize = self
|
||||
.servers
|
||||
.get(server_name)
|
||||
.map(|server| !server.initialized)
|
||||
.ok_or_else(|| McpServerManagerError::UnknownServer {
|
||||
server_name: server_name.to_string(),
|
||||
})?;
|
||||
|
||||
if needs_initialize {
|
||||
let request_id = self.take_request_id();
|
||||
let response = {
|
||||
let server = self.server_mut(server_name)?;
|
||||
let process = server.process.as_mut().ok_or_else(|| {
|
||||
McpServerManagerError::InvalidResponse {
|
||||
server_name: server_name.to_string(),
|
||||
method: "initialize",
|
||||
details: "server process missing before initialize".to_string(),
|
||||
}
|
||||
})?;
|
||||
process
|
||||
.initialize(request_id, default_initialize_params())
|
||||
.await?
|
||||
};
|
||||
|
||||
if let Some(error) = response.error {
|
||||
return Err(McpServerManagerError::JsonRpc {
|
||||
server_name: server_name.to_string(),
|
||||
method: "initialize",
|
||||
error,
|
||||
});
|
||||
}
|
||||
|
||||
if response.result.is_none() {
|
||||
return Err(McpServerManagerError::InvalidResponse {
|
||||
server_name: server_name.to_string(),
|
||||
method: "initialize",
|
||||
details: "missing result payload".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let server = self.server_mut(server_name)?;
|
||||
server.initialized = true;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct McpStdioProcess {
|
||||
child: Child,
|
||||
@@ -385,6 +755,14 @@ impl McpStdioProcess {
|
||||
pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
|
||||
self.child.wait().await
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) -> io::Result<()> {
|
||||
if self.child.try_wait()?.is_none() {
|
||||
self.child.kill().await?;
|
||||
}
|
||||
let _ = self.child.wait().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
|
||||
@@ -413,6 +791,17 @@ fn encode_frame(payload: &[u8]) -> Vec<u8> {
|
||||
framed
|
||||
}
|
||||
|
||||
fn default_initialize_params() -> McpInitializeParams {
|
||||
McpInitializeParams {
|
||||
protocol_version: "2025-03-26".to_string(),
|
||||
capabilities: JsonValue::Object(serde_json::Map::new()),
|
||||
client_info: McpInitializeClientInfo {
|
||||
name: "runtime".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
@@ -426,15 +815,17 @@ mod tests {
|
||||
use tokio::runtime::Builder;
|
||||
|
||||
use crate::config::{
|
||||
ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
|
||||
ConfigSource, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
|
||||
McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
|
||||
};
|
||||
use crate::mcp::mcp_tool_name;
|
||||
use crate::mcp_client::McpClientBootstrap;
|
||||
|
||||
use super::{
|
||||
spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
||||
McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
|
||||
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpStdioProcess, McpTool,
|
||||
McpToolCallParams,
|
||||
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpServerManager,
|
||||
McpServerManagerError, McpStdioProcess, McpTool, McpToolCallParams,
|
||||
};
|
||||
|
||||
fn temp_dir() -> PathBuf {
|
||||
@@ -628,6 +1019,110 @@ mod tests {
|
||||
script_path
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn write_manager_mcp_server_script() -> PathBuf {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("temp dir");
|
||||
let script_path = root.join("manager-mcp-server.py");
|
||||
let script = [
|
||||
"#!/usr/bin/env python3",
|
||||
"import json, os, sys",
|
||||
"",
|
||||
"LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
|
||||
"LOG_PATH = os.environ.get('MCP_LOG_PATH')",
|
||||
"initialize_count = 0",
|
||||
"",
|
||||
"def log(method):",
|
||||
" if LOG_PATH:",
|
||||
" with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
|
||||
" handle.write(f'{method}\\n')",
|
||||
"",
|
||||
"def read_message():",
|
||||
" header = b''",
|
||||
r" while not header.endswith(b'\r\n\r\n'):",
|
||||
" chunk = sys.stdin.buffer.read(1)",
|
||||
" if not chunk:",
|
||||
" return None",
|
||||
" header += chunk",
|
||||
" length = 0",
|
||||
r" for line in header.decode().split('\r\n'):",
|
||||
r" if line.lower().startswith('content-length:'):",
|
||||
r" length = int(line.split(':', 1)[1].strip())",
|
||||
" payload = sys.stdin.buffer.read(length)",
|
||||
" return json.loads(payload.decode())",
|
||||
"",
|
||||
"def send_message(message):",
|
||||
" payload = json.dumps(message).encode()",
|
||||
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
|
||||
" sys.stdout.buffer.flush()",
|
||||
"",
|
||||
"while True:",
|
||||
" request = read_message()",
|
||||
" if request is None:",
|
||||
" break",
|
||||
" method = request['method']",
|
||||
" log(method)",
|
||||
" if method == 'initialize':",
|
||||
" initialize_count += 1",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'protocolVersion': request['params']['protocolVersion'],",
|
||||
" 'capabilities': {'tools': {}},",
|
||||
" 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/list':",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'tools': [",
|
||||
" {",
|
||||
" 'name': 'echo',",
|
||||
" 'description': f'Echo tool for {LABEL}',",
|
||||
" 'inputSchema': {",
|
||||
" 'type': 'object',",
|
||||
" 'properties': {'text': {'type': 'string'}},",
|
||||
" 'required': ['text']",
|
||||
" }",
|
||||
" }",
|
||||
" ]",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/call':",
|
||||
" args = request['params'].get('arguments') or {}",
|
||||
" text = args.get('text', '')",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
|
||||
" 'structuredContent': {",
|
||||
" 'server': LABEL,",
|
||||
" 'echoed': text,",
|
||||
" 'initializeCount': initialize_count",
|
||||
" },",
|
||||
" 'isError': False",
|
||||
" }",
|
||||
" })",
|
||||
" else:",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
|
||||
" })",
|
||||
"",
|
||||
]
|
||||
.join("\n");
|
||||
fs::write(&script_path, script).expect("write script");
|
||||
let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
|
||||
permissions.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, permissions).expect("chmod");
|
||||
script_path
|
||||
}
|
||||
|
||||
fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
|
||||
let config = ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
@@ -653,6 +1148,27 @@ mod tests {
|
||||
fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
|
||||
}
|
||||
|
||||
fn manager_server_config(
|
||||
script_path: &Path,
|
||||
label: &str,
|
||||
log_path: &Path,
|
||||
) -> ScopedMcpServerConfig {
|
||||
ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
||||
command: "python3".to_string(),
|
||||
args: vec![script_path.to_string_lossy().into_owned()],
|
||||
env: BTreeMap::from([
|
||||
("MCP_SERVER_LABEL".to_string(), label.to_string()),
|
||||
(
|
||||
"MCP_LOG_PATH".to_string(),
|
||||
log_path.to_string_lossy().into_owned(),
|
||||
),
|
||||
]),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spawns_stdio_process_and_round_trips_io() {
|
||||
let runtime = Builder::new_current_thread()
|
||||
@@ -935,4 +1451,247 @@ mod tests {
|
||||
cleanup_script(&script_path);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_discovers_tools_from_stdio_config() {
|
||||
let runtime = Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("runtime");
|
||||
runtime.block_on(async {
|
||||
let script_path = write_manager_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("alpha.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &log_path),
|
||||
)]);
|
||||
let mut manager = McpServerManager::from_servers(&servers);
|
||||
|
||||
let tools = manager.discover_tools().await.expect("discover tools");
|
||||
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].server_name, "alpha");
|
||||
assert_eq!(tools[0].raw_name, "echo");
|
||||
assert_eq!(tools[0].qualified_name, mcp_tool_name("alpha", "echo"));
|
||||
assert_eq!(tools[0].tool.name, "echo");
|
||||
assert!(manager.unsupported_servers().is_empty());
|
||||
|
||||
manager.shutdown().await.expect("shutdown");
|
||||
cleanup_script(&script_path);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_routes_tool_calls_to_correct_server() {
|
||||
let runtime = Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("runtime");
|
||||
runtime.block_on(async {
|
||||
let script_path = write_manager_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let alpha_log = root.join("alpha.log");
|
||||
let beta_log = root.join("beta.log");
|
||||
let servers = BTreeMap::from([
|
||||
(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &alpha_log),
|
||||
),
|
||||
(
|
||||
"beta".to_string(),
|
||||
manager_server_config(&script_path, "beta", &beta_log),
|
||||
),
|
||||
]);
|
||||
let mut manager = McpServerManager::from_servers(&servers);
|
||||
|
||||
let tools = manager.discover_tools().await.expect("discover tools");
|
||||
assert_eq!(tools.len(), 2);
|
||||
|
||||
let alpha = manager
|
||||
.call_tool(
|
||||
&mcp_tool_name("alpha", "echo"),
|
||||
Some(json!({"text": "hello"})),
|
||||
)
|
||||
.await
|
||||
.expect("call alpha tool");
|
||||
let beta = manager
|
||||
.call_tool(
|
||||
&mcp_tool_name("beta", "echo"),
|
||||
Some(json!({"text": "world"})),
|
||||
)
|
||||
.await
|
||||
.expect("call beta tool");
|
||||
|
||||
assert_eq!(
|
||||
alpha
|
||||
.result
|
||||
.as_ref()
|
||||
.and_then(|result| result.structured_content.as_ref())
|
||||
.and_then(|value| value.get("server")),
|
||||
Some(&json!("alpha"))
|
||||
);
|
||||
assert_eq!(
|
||||
beta.result
|
||||
.as_ref()
|
||||
.and_then(|result| result.structured_content.as_ref())
|
||||
.and_then(|value| value.get("server")),
|
||||
Some(&json!("beta"))
|
||||
);
|
||||
|
||||
manager.shutdown().await.expect("shutdown");
|
||||
cleanup_script(&script_path);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_records_unsupported_non_stdio_servers_without_panicking() {
|
||||
let servers = BTreeMap::from([
|
||||
(
|
||||
"http".to_string(),
|
||||
ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
config: McpServerConfig::Http(McpRemoteServerConfig {
|
||||
url: "https://example.test/mcp".to_string(),
|
||||
headers: BTreeMap::new(),
|
||||
headers_helper: None,
|
||||
oauth: None,
|
||||
}),
|
||||
},
|
||||
),
|
||||
(
|
||||
"sdk".to_string(),
|
||||
ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
config: McpServerConfig::Sdk(McpSdkServerConfig {
|
||||
name: "sdk-server".to_string(),
|
||||
}),
|
||||
},
|
||||
),
|
||||
(
|
||||
"ws".to_string(),
|
||||
ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
config: McpServerConfig::Ws(McpWebSocketServerConfig {
|
||||
url: "wss://example.test/mcp".to_string(),
|
||||
headers: BTreeMap::new(),
|
||||
headers_helper: None,
|
||||
}),
|
||||
},
|
||||
),
|
||||
]);
|
||||
|
||||
let manager = McpServerManager::from_servers(&servers);
|
||||
let unsupported = manager.unsupported_servers();
|
||||
|
||||
assert_eq!(unsupported.len(), 3);
|
||||
assert_eq!(unsupported[0].server_name, "http");
|
||||
assert_eq!(unsupported[1].server_name, "sdk");
|
||||
assert_eq!(unsupported[2].server_name, "ws");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_shutdown_terminates_spawned_children_and_is_idempotent() {
|
||||
let runtime = Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("runtime");
|
||||
runtime.block_on(async {
|
||||
let script_path = write_manager_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("alpha.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &log_path),
|
||||
)]);
|
||||
let mut manager = McpServerManager::from_servers(&servers);
|
||||
|
||||
manager.discover_tools().await.expect("discover tools");
|
||||
manager.shutdown().await.expect("first shutdown");
|
||||
manager.shutdown().await.expect("second shutdown");
|
||||
|
||||
cleanup_script(&script_path);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_reuses_spawned_server_between_discovery_and_call() {
|
||||
let runtime = Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("runtime");
|
||||
runtime.block_on(async {
|
||||
let script_path = write_manager_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("alpha.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &log_path),
|
||||
)]);
|
||||
let mut manager = McpServerManager::from_servers(&servers);
|
||||
|
||||
manager.discover_tools().await.expect("discover tools");
|
||||
let response = manager
|
||||
.call_tool(
|
||||
&mcp_tool_name("alpha", "echo"),
|
||||
Some(json!({"text": "reuse"})),
|
||||
)
|
||||
.await
|
||||
.expect("call tool");
|
||||
|
||||
assert_eq!(
|
||||
response
|
||||
.result
|
||||
.as_ref()
|
||||
.and_then(|result| result.structured_content.as_ref())
|
||||
.and_then(|value| value.get("initializeCount")),
|
||||
Some(&json!(1))
|
||||
);
|
||||
|
||||
let log = fs::read_to_string(&log_path).expect("read log");
|
||||
assert_eq!(log.lines().filter(|line| *line == "initialize").count(), 1);
|
||||
assert_eq!(
|
||||
log.lines().collect::<Vec<_>>(),
|
||||
vec!["initialize", "tools/list", "tools/call"]
|
||||
);
|
||||
|
||||
manager.shutdown().await.expect("shutdown");
|
||||
cleanup_script(&script_path);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_reports_unknown_qualified_tool_name() {
|
||||
let runtime = Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("runtime");
|
||||
runtime.block_on(async {
|
||||
let script_path = write_manager_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("alpha.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &log_path),
|
||||
)]);
|
||||
let mut manager = McpServerManager::from_servers(&servers);
|
||||
|
||||
let error = manager
|
||||
.call_tool(
|
||||
&mcp_tool_name("alpha", "missing"),
|
||||
Some(json!({"text": "nope"})),
|
||||
)
|
||||
.await
|
||||
.expect_err("unknown qualified tool should fail");
|
||||
|
||||
match error {
|
||||
McpServerManagerError::UnknownTool { qualified_name } => {
|
||||
assert_eq!(qualified_name, mcp_tool_name("alpha", "missing"));
|
||||
}
|
||||
other => panic!("expected unknown tool error, got {other:?}"),
|
||||
}
|
||||
|
||||
cleanup_script(&script_path);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs::File;
|
||||
use std::fs::{self, File};
|
||||
use std::io::{self, Read};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::config::OAuthConfig;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct OAuthTokenSet {
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
@@ -65,6 +68,48 @@ pub struct OAuthRefreshRequest {
|
||||
pub scopes: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OAuthCallbackParams {
|
||||
pub code: Option<String>,
|
||||
pub state: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct StoredOAuthCredentials {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
expires_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
impl From<OAuthTokenSet> for StoredOAuthCredentials {
|
||||
fn from(value: OAuthTokenSet) -> Self {
|
||||
Self {
|
||||
access_token: value.access_token,
|
||||
refresh_token: value.refresh_token,
|
||||
expires_at: value.expires_at,
|
||||
scopes: value.scopes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StoredOAuthCredentials> for OAuthTokenSet {
|
||||
fn from(value: StoredOAuthCredentials) -> Self {
|
||||
Self {
|
||||
access_token: value.access_token,
|
||||
refresh_token: value.refresh_token,
|
||||
expires_at: value.expires_at,
|
||||
scopes: value.scopes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuthAuthorizationRequest {
|
||||
#[must_use]
|
||||
pub fn from_config(
|
||||
@@ -137,7 +182,6 @@ impl OAuthTokenExchangeRequest {
|
||||
verifier: impl Into<String>,
|
||||
redirect_uri: impl Into<String>,
|
||||
) -> Self {
|
||||
let _ = config;
|
||||
Self {
|
||||
grant_type: "authorization_code",
|
||||
code: code.into(),
|
||||
@@ -211,12 +255,116 @@ pub fn loopback_redirect_uri(port: u16) -> String {
|
||||
format!("http://localhost:{port}/callback")
|
||||
}
|
||||
|
||||
pub fn credentials_path() -> io::Result<PathBuf> {
|
||||
Ok(credentials_home_dir()?.join("credentials.json"))
|
||||
}
|
||||
|
||||
pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
|
||||
let path = credentials_path()?;
|
||||
let root = read_credentials_root(&path)?;
|
||||
let Some(oauth) = root.get("oauth") else {
|
||||
return Ok(None);
|
||||
};
|
||||
if oauth.is_null() {
|
||||
return Ok(None);
|
||||
}
|
||||
let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
Ok(Some(stored.into()))
|
||||
}
|
||||
|
||||
pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
|
||||
let path = credentials_path()?;
|
||||
let mut root = read_credentials_root(&path)?;
|
||||
root.insert(
|
||||
"oauth".to_string(),
|
||||
serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
|
||||
);
|
||||
write_credentials_root(&path, &root)
|
||||
}
|
||||
|
||||
pub fn clear_oauth_credentials() -> io::Result<()> {
|
||||
let path = credentials_path()?;
|
||||
let mut root = read_credentials_root(&path)?;
|
||||
root.remove("oauth");
|
||||
write_credentials_root(&path, &root)
|
||||
}
|
||||
|
||||
pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
|
||||
let (path, query) = target
|
||||
.split_once('?')
|
||||
.map_or((target, ""), |(path, query)| (path, query));
|
||||
if path != "/callback" {
|
||||
return Err(format!("unexpected callback path: {path}"));
|
||||
}
|
||||
parse_oauth_callback_query(query)
|
||||
}
|
||||
|
||||
pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
|
||||
let mut params = BTreeMap::new();
|
||||
for pair in query.split('&').filter(|pair| !pair.is_empty()) {
|
||||
let (key, value) = pair
|
||||
.split_once('=')
|
||||
.map_or((pair, ""), |(key, value)| (key, value));
|
||||
params.insert(percent_decode(key)?, percent_decode(value)?);
|
||||
}
|
||||
Ok(OAuthCallbackParams {
|
||||
code: params.get("code").cloned(),
|
||||
state: params.get("state").cloned(),
|
||||
error: params.get("error").cloned(),
|
||||
error_description: params.get("error_description").cloned(),
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
||||
let mut buffer = vec![0_u8; bytes];
|
||||
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
||||
Ok(base64url_encode(&buffer))
|
||||
}
|
||||
|
||||
fn credentials_home_dir() -> io::Result<PathBuf> {
|
||||
if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") {
|
||||
return Ok(PathBuf::from(path));
|
||||
}
|
||||
let home = std::env::var_os("HOME")
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
|
||||
Ok(PathBuf::from(home).join(".claude"))
|
||||
}
|
||||
|
||||
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
|
||||
match fs::read_to_string(path) {
|
||||
Ok(contents) => {
|
||||
if contents.trim().is_empty() {
|
||||
return Ok(Map::new());
|
||||
}
|
||||
serde_json::from_str::<Value>(&contents)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
|
||||
.as_object()
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"credentials file must contain a JSON object",
|
||||
)
|
||||
})
|
||||
}
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
let temp_path = path.with_extension("json.tmp");
|
||||
fs::write(&temp_path, format!("{rendered}\n"))?;
|
||||
fs::rename(temp_path, path)
|
||||
}
|
||||
|
||||
fn base64url_encode(bytes: &[u8]) -> String {
|
||||
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
let mut output = String::new();
|
||||
@@ -264,11 +412,49 @@ fn percent_encode(value: &str) -> String {
|
||||
encoded
|
||||
}
|
||||
|
||||
fn percent_decode(value: &str) -> Result<String, String> {
|
||||
let mut decoded = Vec::with_capacity(value.len());
|
||||
let bytes = value.as_bytes();
|
||||
let mut index = 0;
|
||||
while index < bytes.len() {
|
||||
match bytes[index] {
|
||||
b'%' if index + 2 < bytes.len() => {
|
||||
let hi = decode_hex(bytes[index + 1])?;
|
||||
let lo = decode_hex(bytes[index + 2])?;
|
||||
decoded.push((hi << 4) | lo);
|
||||
index += 3;
|
||||
}
|
||||
b'+' => {
|
||||
decoded.push(b' ');
|
||||
index += 1;
|
||||
}
|
||||
byte => {
|
||||
decoded.push(byte);
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
String::from_utf8(decoded).map_err(|error| error.to_string())
|
||||
}
|
||||
|
||||
fn decode_hex(byte: u8) -> Result<u8, String> {
|
||||
match byte {
|
||||
b'0'..=b'9' => Ok(byte - b'0'),
|
||||
b'a'..=b'f' => Ok(byte - b'a' + 10),
|
||||
b'A'..=b'F' => Ok(byte - b'A' + 10),
|
||||
_ => Err(format!("invalid percent-encoding byte: {byte}")),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::{
|
||||
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
||||
OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
|
||||
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||
OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||
};
|
||||
|
||||
fn sample_config() -> OAuthConfig {
|
||||
@@ -282,6 +468,21 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
crate::test_env_lock()
|
||||
}
|
||||
|
||||
fn temp_config_home() -> std::path::PathBuf {
|
||||
std::env::temp_dir().join(format!(
|
||||
"runtime-oauth-test-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn s256_challenge_matches_expected_vector() {
|
||||
assert_eq!(
|
||||
@@ -335,4 +536,54 @@ mod tests {
|
||||
Some("org:read user:write")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
|
||||
let _guard = env_lock();
|
||||
let config_home = temp_config_home();
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||
let path = credentials_path().expect("credentials path");
|
||||
std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
|
||||
std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");
|
||||
|
||||
let token_set = OAuthTokenSet {
|
||||
access_token: "access-token".to_string(),
|
||||
refresh_token: Some("refresh-token".to_string()),
|
||||
expires_at: Some(123),
|
||||
scopes: vec!["scope:a".to_string()],
|
||||
};
|
||||
save_oauth_credentials(&token_set).expect("save credentials");
|
||||
assert_eq!(
|
||||
load_oauth_credentials().expect("load credentials"),
|
||||
Some(token_set)
|
||||
);
|
||||
let saved = std::fs::read_to_string(&path).expect("read saved file");
|
||||
assert!(saved.contains("\"other\": \"value\""));
|
||||
assert!(saved.contains("\"oauth\""));
|
||||
|
||||
clear_oauth_credentials().expect("clear credentials");
|
||||
assert_eq!(load_oauth_credentials().expect("load cleared"), None);
|
||||
let cleared = std::fs::read_to_string(&path).expect("read cleared file");
|
||||
assert!(cleared.contains("\"other\": \"value\""));
|
||||
assert!(!cleared.contains("\"oauth\""));
|
||||
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_callback_query_and_target() {
|
||||
let params =
|
||||
parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
|
||||
.expect("parse query");
|
||||
assert_eq!(params.code.as_deref(), Some("abc123"));
|
||||
assert_eq!(params.state.as_deref(), Some("state-1"));
|
||||
assert_eq!(params.error_description.as_deref(), Some("needs login"));
|
||||
|
||||
let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
|
||||
.expect("parse callback target");
|
||||
assert_eq!(params.code.as_deref(), Some("abc"));
|
||||
assert_eq!(params.state.as_deref(), Some("xyz"));
|
||||
assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,29 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum PermissionMode {
|
||||
Allow,
|
||||
Deny,
|
||||
Prompt,
|
||||
ReadOnly,
|
||||
WorkspaceWrite,
|
||||
DangerFullAccess,
|
||||
}
|
||||
|
||||
impl PermissionMode {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::ReadOnly => "read-only",
|
||||
Self::WorkspaceWrite => "workspace-write",
|
||||
Self::DangerFullAccess => "danger-full-access",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PermissionRequest {
|
||||
pub tool_name: String,
|
||||
pub input: String,
|
||||
pub current_mode: PermissionMode,
|
||||
pub required_mode: PermissionMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -31,31 +44,41 @@ pub enum PermissionOutcome {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PermissionPolicy {
|
||||
default_mode: PermissionMode,
|
||||
tool_modes: BTreeMap<String, PermissionMode>,
|
||||
active_mode: PermissionMode,
|
||||
tool_requirements: BTreeMap<String, PermissionMode>,
|
||||
}
|
||||
|
||||
impl PermissionPolicy {
|
||||
#[must_use]
|
||||
pub fn new(default_mode: PermissionMode) -> Self {
|
||||
pub fn new(active_mode: PermissionMode) -> Self {
|
||||
Self {
|
||||
default_mode,
|
||||
tool_modes: BTreeMap::new(),
|
||||
active_mode,
|
||||
tool_requirements: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_tool_mode(mut self, tool_name: impl Into<String>, mode: PermissionMode) -> Self {
|
||||
self.tool_modes.insert(tool_name.into(), mode);
|
||||
pub fn with_tool_requirement(
|
||||
mut self,
|
||||
tool_name: impl Into<String>,
|
||||
required_mode: PermissionMode,
|
||||
) -> Self {
|
||||
self.tool_requirements
|
||||
.insert(tool_name.into(), required_mode);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn mode_for(&self, tool_name: &str) -> PermissionMode {
|
||||
self.tool_modes
|
||||
pub fn active_mode(&self) -> PermissionMode {
|
||||
self.active_mode
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn required_mode_for(&self, tool_name: &str) -> PermissionMode {
|
||||
self.tool_requirements
|
||||
.get(tool_name)
|
||||
.copied()
|
||||
.unwrap_or(self.default_mode)
|
||||
.unwrap_or(PermissionMode::DangerFullAccess)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@@ -65,23 +88,43 @@ impl PermissionPolicy {
|
||||
input: &str,
|
||||
mut prompter: Option<&mut dyn PermissionPrompter>,
|
||||
) -> PermissionOutcome {
|
||||
match self.mode_for(tool_name) {
|
||||
PermissionMode::Allow => PermissionOutcome::Allow,
|
||||
PermissionMode::Deny => PermissionOutcome::Deny {
|
||||
reason: format!("tool '{tool_name}' denied by permission policy"),
|
||||
},
|
||||
PermissionMode::Prompt => match prompter.as_mut() {
|
||||
Some(prompter) => match prompter.decide(&PermissionRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
input: input.to_string(),
|
||||
}) {
|
||||
let current_mode = self.active_mode();
|
||||
let required_mode = self.required_mode_for(tool_name);
|
||||
if current_mode >= required_mode {
|
||||
return PermissionOutcome::Allow;
|
||||
}
|
||||
|
||||
let request = PermissionRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
input: input.to_string(),
|
||||
current_mode,
|
||||
required_mode,
|
||||
};
|
||||
|
||||
if current_mode == PermissionMode::WorkspaceWrite
|
||||
&& required_mode == PermissionMode::DangerFullAccess
|
||||
{
|
||||
return match prompter.as_mut() {
|
||||
Some(prompter) => match prompter.decide(&request) {
|
||||
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
|
||||
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
|
||||
},
|
||||
None => PermissionOutcome::Deny {
|
||||
reason: format!("tool '{tool_name}' requires interactive approval"),
|
||||
reason: format!(
|
||||
"tool '{tool_name}' requires approval to escalate from {} to {}",
|
||||
current_mode.as_str(),
|
||||
required_mode.as_str()
|
||||
),
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
PermissionOutcome::Deny {
|
||||
reason: format!(
|
||||
"tool '{tool_name}' requires {} permission; current mode is {}",
|
||||
required_mode.as_str(),
|
||||
current_mode.as_str()
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,25 +136,92 @@ mod tests {
|
||||
PermissionPrompter, PermissionRequest,
|
||||
};
|
||||
|
||||
struct AllowPrompter;
|
||||
struct RecordingPrompter {
|
||||
seen: Vec<PermissionRequest>,
|
||||
allow: bool,
|
||||
}
|
||||
|
||||
impl PermissionPrompter for AllowPrompter {
|
||||
impl PermissionPrompter for RecordingPrompter {
|
||||
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
||||
assert_eq!(request.tool_name, "bash");
|
||||
PermissionPromptDecision::Allow
|
||||
self.seen.push(request.clone());
|
||||
if self.allow {
|
||||
PermissionPromptDecision::Allow
|
||||
} else {
|
||||
PermissionPromptDecision::Deny {
|
||||
reason: "not now".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uses_tool_specific_overrides() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::Deny)
|
||||
.with_tool_mode("bash", PermissionMode::Prompt);
|
||||
fn allows_tools_when_active_mode_meets_requirement() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||
.with_tool_requirement("read_file", PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||
|
||||
assert_eq!(
|
||||
policy.authorize("read_file", "{}", None),
|
||||
PermissionOutcome::Allow
|
||||
);
|
||||
assert_eq!(
|
||||
policy.authorize("write_file", "{}", None),
|
||||
PermissionOutcome::Allow
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn denies_read_only_escalations_without_prompt() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||
|
||||
let outcome = policy.authorize("bash", "echo hi", Some(&mut AllowPrompter));
|
||||
assert_eq!(outcome, PermissionOutcome::Allow);
|
||||
assert!(matches!(
|
||||
policy.authorize("edit", "x", None),
|
||||
PermissionOutcome::Deny { .. }
|
||||
policy.authorize("write_file", "{}", None),
|
||||
PermissionOutcome::Deny { reason } if reason.contains("requires workspace-write permission")
|
||||
));
|
||||
assert!(matches!(
|
||||
policy.authorize("bash", "{}", None),
|
||||
PermissionOutcome::Deny { reason } if reason.contains("requires danger-full-access permission")
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompts_for_workspace_write_to_danger_full_access_escalation() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||
let mut prompter = RecordingPrompter {
|
||||
seen: Vec::new(),
|
||||
allow: true,
|
||||
};
|
||||
|
||||
let outcome = policy.authorize("bash", "echo hi", Some(&mut prompter));
|
||||
|
||||
assert_eq!(outcome, PermissionOutcome::Allow);
|
||||
assert_eq!(prompter.seen.len(), 1);
|
||||
assert_eq!(prompter.seen[0].tool_name, "bash");
|
||||
assert_eq!(
|
||||
prompter.seen[0].current_mode,
|
||||
PermissionMode::WorkspaceWrite
|
||||
);
|
||||
assert_eq!(
|
||||
prompter.seen[0].required_mode,
|
||||
PermissionMode::DangerFullAccess
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn honors_prompt_rejection_reason() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||
let mut prompter = RecordingPrompter {
|
||||
seen: Vec::new(),
|
||||
allow: false,
|
||||
};
|
||||
|
||||
assert!(matches!(
|
||||
policy.authorize("bash", "echo hi", Some(&mut prompter)),
|
||||
PermissionOutcome::Deny { reason } if reason == "not now"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ pub struct ProjectContext {
|
||||
pub current_date: String,
|
||||
pub git_status: Option<String>,
|
||||
pub instruction_files: Vec<ContextFile>,
|
||||
pub memory_files: Vec<ContextFile>,
|
||||
}
|
||||
|
||||
impl ProjectContext {
|
||||
@@ -60,11 +61,13 @@ impl ProjectContext {
|
||||
) -> std::io::Result<Self> {
|
||||
let cwd = cwd.into();
|
||||
let instruction_files = discover_instruction_files(&cwd)?;
|
||||
let memory_files = discover_memory_files(&cwd)?;
|
||||
Ok(Self {
|
||||
cwd,
|
||||
current_date: current_date.into(),
|
||||
git_status: None,
|
||||
instruction_files,
|
||||
memory_files,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -144,6 +147,9 @@ impl SystemPromptBuilder {
|
||||
if !project_context.instruction_files.is_empty() {
|
||||
sections.push(render_instruction_files(&project_context.instruction_files));
|
||||
}
|
||||
if !project_context.memory_files.is_empty() {
|
||||
sections.push(render_memory_files(&project_context.memory_files));
|
||||
}
|
||||
}
|
||||
if let Some(config) = &self.config {
|
||||
sections.push(render_config_section(config));
|
||||
@@ -186,7 +192,7 @@ pub fn prepend_bullets(items: Vec<String>) -> Vec<String> {
|
||||
items.into_iter().map(|item| format!(" - {item}")).collect()
|
||||
}
|
||||
|
||||
fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
fn discover_context_directories(cwd: &Path) -> Vec<PathBuf> {
|
||||
let mut directories = Vec::new();
|
||||
let mut cursor = Some(cwd);
|
||||
while let Some(dir) = cursor {
|
||||
@@ -194,6 +200,11 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
cursor = dir.parent();
|
||||
}
|
||||
directories.reverse();
|
||||
directories
|
||||
}
|
||||
|
||||
fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
let directories = discover_context_directories(cwd);
|
||||
|
||||
let mut files = Vec::new();
|
||||
for dir in directories {
|
||||
@@ -201,6 +212,7 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
dir.join("CLAUDE.md"),
|
||||
dir.join("CLAUDE.local.md"),
|
||||
dir.join(".claude").join("CLAUDE.md"),
|
||||
dir.join(".claude").join("instructions.md"),
|
||||
] {
|
||||
push_context_file(&mut files, candidate)?;
|
||||
}
|
||||
@@ -208,6 +220,26 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
Ok(dedupe_instruction_files(files))
|
||||
}
|
||||
|
||||
fn discover_memory_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
let mut files = Vec::new();
|
||||
for dir in discover_context_directories(cwd) {
|
||||
let memory_dir = dir.join(".claude").join("memory");
|
||||
let Ok(entries) = fs::read_dir(&memory_dir) else {
|
||||
continue;
|
||||
};
|
||||
let mut paths = entries
|
||||
.flatten()
|
||||
.map(|entry| entry.path())
|
||||
.filter(|path| path.is_file())
|
||||
.collect::<Vec<_>>();
|
||||
paths.sort();
|
||||
for path in paths {
|
||||
push_context_file(&mut files, path)?;
|
||||
}
|
||||
}
|
||||
Ok(dedupe_instruction_files(files))
|
||||
}
|
||||
|
||||
fn push_context_file(files: &mut Vec<ContextFile>, path: PathBuf) -> std::io::Result<()> {
|
||||
match fs::read_to_string(&path) {
|
||||
Ok(content) if !content.trim().is_empty() => {
|
||||
@@ -250,6 +282,12 @@ fn render_project_context(project_context: &ProjectContext) -> String {
|
||||
project_context.instruction_files.len()
|
||||
));
|
||||
}
|
||||
if !project_context.memory_files.is_empty() {
|
||||
bullets.push(format!(
|
||||
"Project memory files discovered: {}.",
|
||||
project_context.memory_files.len()
|
||||
));
|
||||
}
|
||||
lines.extend(prepend_bullets(bullets));
|
||||
if let Some(status) = &project_context.git_status {
|
||||
lines.push(String::new());
|
||||
@@ -260,7 +298,15 @@ fn render_project_context(project_context: &ProjectContext) -> String {
|
||||
}
|
||||
|
||||
fn render_instruction_files(files: &[ContextFile]) -> String {
|
||||
let mut sections = vec!["# Claude instructions".to_string()];
|
||||
render_context_file_section("# Claude instructions", files)
|
||||
}
|
||||
|
||||
fn render_memory_files(files: &[ContextFile]) -> String {
|
||||
render_context_file_section("# Project memory", files)
|
||||
}
|
||||
|
||||
fn render_context_file_section(title: &str, files: &[ContextFile]) -> String {
|
||||
let mut sections = vec![title.to_string()];
|
||||
let mut remaining_chars = MAX_TOTAL_INSTRUCTION_CHARS;
|
||||
for file in files {
|
||||
if remaining_chars == 0 {
|
||||
@@ -452,8 +498,9 @@ fn get_actions_section() -> String {
|
||||
mod tests {
|
||||
use super::{
|
||||
collapse_blank_lines, display_context_path, normalize_instruction_content,
|
||||
render_instruction_content, render_instruction_files, truncate_instruction_content,
|
||||
ContextFile, ProjectContext, SystemPromptBuilder, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||
render_instruction_content, render_instruction_files, render_memory_files,
|
||||
truncate_instruction_content, ContextFile, ProjectContext, SystemPromptBuilder,
|
||||
SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||
};
|
||||
use crate::config::ConfigLoader;
|
||||
use std::fs;
|
||||
@@ -468,6 +515,10 @@ mod tests {
|
||||
std::env::temp_dir().join(format!("runtime-prompt-{nanos}"))
|
||||
}
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
crate::test_env_lock()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discovers_instruction_files_from_ancestor_chain() {
|
||||
let root = temp_dir();
|
||||
@@ -477,10 +528,21 @@ mod tests {
|
||||
fs::write(root.join("CLAUDE.local.md"), "local instructions")
|
||||
.expect("write local instructions");
|
||||
fs::create_dir_all(root.join("apps")).expect("apps dir");
|
||||
fs::create_dir_all(root.join("apps").join(".claude")).expect("apps claude dir");
|
||||
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
|
||||
.expect("write apps instructions");
|
||||
fs::write(
|
||||
root.join("apps").join(".claude").join("instructions.md"),
|
||||
"apps dot claude instructions",
|
||||
)
|
||||
.expect("write apps dot claude instructions");
|
||||
fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules")
|
||||
.expect("write nested rules");
|
||||
fs::write(
|
||||
nested.join(".claude").join("instructions.md"),
|
||||
"nested instructions",
|
||||
)
|
||||
.expect("write nested instructions");
|
||||
|
||||
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||
let contents = context
|
||||
@@ -495,12 +557,43 @@ mod tests {
|
||||
"root instructions",
|
||||
"local instructions",
|
||||
"apps instructions",
|
||||
"nested rules"
|
||||
"apps dot claude instructions",
|
||||
"nested rules",
|
||||
"nested instructions"
|
||||
]
|
||||
);
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discovers_project_memory_files_from_ancestor_chain() {
|
||||
let root = temp_dir();
|
||||
let nested = root.join("apps").join("api");
|
||||
fs::create_dir_all(root.join(".claude").join("memory")).expect("root memory dir");
|
||||
fs::create_dir_all(nested.join(".claude").join("memory")).expect("nested memory dir");
|
||||
fs::write(
|
||||
root.join(".claude").join("memory").join("2026-03-30.md"),
|
||||
"root memory",
|
||||
)
|
||||
.expect("write root memory");
|
||||
fs::write(
|
||||
nested.join(".claude").join("memory").join("2026-03-31.md"),
|
||||
"nested memory",
|
||||
)
|
||||
.expect("write nested memory");
|
||||
|
||||
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||
let contents = context
|
||||
.memory_files
|
||||
.iter()
|
||||
.map(|file| file.content.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(contents, vec!["root memory", "nested memory"]);
|
||||
assert!(render_memory_files(&context.memory_files).contains("# Project memory"));
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dedupes_identical_instruction_content_across_scopes() {
|
||||
let root = temp_dir();
|
||||
@@ -574,7 +667,12 @@ mod tests {
|
||||
)
|
||||
.expect("write settings");
|
||||
|
||||
let _guard = env_lock();
|
||||
let previous = std::env::current_dir().expect("cwd");
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok();
|
||||
std::env::set_var("HOME", &root);
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", root.join("missing-home"));
|
||||
std::env::set_current_dir(&root).expect("change cwd");
|
||||
let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8")
|
||||
.expect("system prompt should load")
|
||||
@@ -584,6 +682,16 @@ mod tests {
|
||||
",
|
||||
);
|
||||
std::env::set_current_dir(previous).expect("restore cwd");
|
||||
if let Some(value) = original_home {
|
||||
std::env::set_var("HOME", value);
|
||||
} else {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
if let Some(value) = original_claude_home {
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", value);
|
||||
} else {
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
assert!(prompt.contains("Project rules"));
|
||||
assert!(prompt.contains("permissionMode"));
|
||||
@@ -631,6 +739,29 @@ mod tests {
|
||||
assert!(rendered.chars().count() <= 4_000 + "\n\n[truncated]".chars().count());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discovers_dot_claude_instructions_markdown() {
|
||||
let root = temp_dir();
|
||||
let nested = root.join("apps").join("api");
|
||||
fs::create_dir_all(nested.join(".claude")).expect("nested claude dir");
|
||||
fs::write(
|
||||
nested.join(".claude").join("instructions.md"),
|
||||
"instruction markdown",
|
||||
)
|
||||
.expect("write instructions.md");
|
||||
|
||||
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||
assert!(context
|
||||
.instruction_files
|
||||
.iter()
|
||||
.any(|file| file.path.ends_with(".claude/instructions.md")));
|
||||
assert!(
|
||||
render_instruction_files(&context.instruction_files).contains("instruction markdown")
|
||||
);
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renders_instruction_file_metadata() {
|
||||
let rendered = render_instruction_files(&[ContextFile {
|
||||
|
||||
@@ -31,6 +31,10 @@ pub enum Command {
|
||||
DumpManifests,
|
||||
/// Print the current bootstrap phase skeleton
|
||||
BootstrapPlan,
|
||||
/// Start the OAuth login flow
|
||||
Login,
|
||||
/// Clear saved OAuth credentials
|
||||
Logout,
|
||||
/// Run a non-interactive prompt and exit
|
||||
Prompt { prompt: Vec<String> },
|
||||
}
|
||||
@@ -86,4 +90,13 @@ mod tests {
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_login_and_logout_commands() {
|
||||
let login = Cli::parse_from(["rusty-claude-cli", "login"]);
|
||||
assert_eq!(login.command, Some(Command::Login));
|
||||
|
||||
let logout = Cli::parse_from(["rusty-claude-cli", "logout"]);
|
||||
assert_eq!(logout.command, Some(Command::Logout));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,14 +4,16 @@ mod render;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::io::{self, Write};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::net::TcpListener;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use api::{
|
||||
AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest,
|
||||
MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition,
|
||||
ToolResultContentBlock,
|
||||
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
|
||||
InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
|
||||
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
|
||||
};
|
||||
|
||||
use commands::{
|
||||
@@ -20,17 +22,20 @@ use commands::{
|
||||
use compat_harness::{extract_manifest, UpstreamPaths};
|
||||
use render::{Spinner, TerminalRenderer};
|
||||
use runtime::{
|
||||
load_system_prompt, ApiClient, ApiRequest, AssistantEvent, CompactionConfig, ConfigLoader,
|
||||
ConfigSource, ContentBlock, ConversationMessage, ConversationRuntime, MessageRole,
|
||||
PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, Session, TokenUsage, ToolError,
|
||||
ToolExecutor, UsageTracker,
|
||||
clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt,
|
||||
parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest,
|
||||
AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock,
|
||||
ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest,
|
||||
OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError,
|
||||
Session, TokenUsage, ToolError, ToolExecutor, UsageTracker,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tools::{execute_tool, mvp_tool_specs};
|
||||
use tools::{execute_tool, mvp_tool_specs, ToolSpec};
|
||||
|
||||
const DEFAULT_MODEL: &str = "claude-sonnet-4-20250514";
|
||||
const DEFAULT_MAX_TOKENS: u32 = 32;
|
||||
const DEFAULT_DATE: &str = "2026-03-31";
|
||||
const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545;
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
const BUILD_TARGET: Option<&str> = option_env!("TARGET");
|
||||
const GIT_SHA: Option<&str> = option_env!("GIT_SHA");
|
||||
@@ -64,12 +69,16 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
|
||||
model,
|
||||
output_format,
|
||||
allowed_tools,
|
||||
} => LiveCli::new(model, false, allowed_tools)?
|
||||
permission_mode,
|
||||
} => LiveCli::new(model, false, allowed_tools, permission_mode)?
|
||||
.run_turn_with_output(&prompt, output_format)?,
|
||||
CliAction::Login => run_login()?,
|
||||
CliAction::Logout => run_logout()?,
|
||||
CliAction::Repl {
|
||||
model,
|
||||
allowed_tools,
|
||||
} => run_repl(model, allowed_tools)?,
|
||||
permission_mode,
|
||||
} => run_repl(model, allowed_tools, permission_mode)?,
|
||||
CliAction::Help => print_help(),
|
||||
}
|
||||
Ok(())
|
||||
@@ -93,10 +102,14 @@ enum CliAction {
|
||||
model: String,
|
||||
output_format: CliOutputFormat,
|
||||
allowed_tools: Option<AllowedToolSet>,
|
||||
permission_mode: PermissionMode,
|
||||
},
|
||||
Login,
|
||||
Logout,
|
||||
Repl {
|
||||
model: String,
|
||||
allowed_tools: Option<AllowedToolSet>,
|
||||
permission_mode: PermissionMode,
|
||||
},
|
||||
// prompt-mode formatting is only supported for non-interactive runs
|
||||
Help,
|
||||
@@ -120,9 +133,11 @@ impl CliOutputFormat {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
||||
let mut model = DEFAULT_MODEL.to_string();
|
||||
let mut output_format = CliOutputFormat::Text;
|
||||
let mut permission_mode = default_permission_mode();
|
||||
let mut wants_version = false;
|
||||
let mut allowed_tool_values = Vec::new();
|
||||
let mut rest = Vec::new();
|
||||
@@ -152,10 +167,21 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
||||
output_format = CliOutputFormat::parse(value)?;
|
||||
index += 2;
|
||||
}
|
||||
"--permission-mode" => {
|
||||
let value = args
|
||||
.get(index + 1)
|
||||
.ok_or_else(|| "missing value for --permission-mode".to_string())?;
|
||||
permission_mode = parse_permission_mode_arg(value)?;
|
||||
index += 2;
|
||||
}
|
||||
flag if flag.starts_with("--output-format=") => {
|
||||
output_format = CliOutputFormat::parse(&flag[16..])?;
|
||||
index += 1;
|
||||
}
|
||||
flag if flag.starts_with("--permission-mode=") => {
|
||||
permission_mode = parse_permission_mode_arg(&flag[18..])?;
|
||||
index += 1;
|
||||
}
|
||||
"--allowedTools" | "--allowed-tools" => {
|
||||
let value = args
|
||||
.get(index + 1)
|
||||
@@ -188,6 +214,7 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
||||
return Ok(CliAction::Repl {
|
||||
model,
|
||||
allowed_tools,
|
||||
permission_mode,
|
||||
});
|
||||
}
|
||||
if matches!(rest.first().map(String::as_str), Some("--help" | "-h")) {
|
||||
@@ -201,6 +228,8 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
||||
"dump-manifests" => Ok(CliAction::DumpManifests),
|
||||
"bootstrap-plan" => Ok(CliAction::BootstrapPlan),
|
||||
"system-prompt" => parse_system_prompt_args(&rest[1..]),
|
||||
"login" => Ok(CliAction::Login),
|
||||
"logout" => Ok(CliAction::Logout),
|
||||
"prompt" => {
|
||||
let prompt = rest[1..].join(" ");
|
||||
if prompt.trim().is_empty() {
|
||||
@@ -211,6 +240,7 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
||||
model,
|
||||
output_format,
|
||||
allowed_tools,
|
||||
permission_mode,
|
||||
})
|
||||
}
|
||||
other if !other.starts_with('/') => Ok(CliAction::Prompt {
|
||||
@@ -218,6 +248,7 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
|
||||
model,
|
||||
output_format,
|
||||
allowed_tools,
|
||||
permission_mode,
|
||||
}),
|
||||
other => Err(format!("unknown subcommand: {other}")),
|
||||
}
|
||||
@@ -271,6 +302,33 @@ fn normalize_tool_name(value: &str) -> String {
|
||||
value.trim().replace('-', "_").to_ascii_lowercase()
|
||||
}
|
||||
|
||||
fn parse_permission_mode_arg(value: &str) -> Result<PermissionMode, String> {
|
||||
normalize_permission_mode(value)
|
||||
.ok_or_else(|| {
|
||||
format!(
|
||||
"unsupported permission mode '{value}'. Use read-only, workspace-write, or danger-full-access."
|
||||
)
|
||||
})
|
||||
.map(permission_mode_from_label)
|
||||
}
|
||||
|
||||
fn permission_mode_from_label(mode: &str) -> PermissionMode {
|
||||
match mode {
|
||||
"read-only" => PermissionMode::ReadOnly,
|
||||
"workspace-write" => PermissionMode::WorkspaceWrite,
|
||||
"danger-full-access" => PermissionMode::DangerFullAccess,
|
||||
other => panic!("unsupported permission mode label: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn default_permission_mode() -> PermissionMode {
|
||||
env::var("RUSTY_CLAUDE_PERMISSION_MODE")
|
||||
.ok()
|
||||
.as_deref()
|
||||
.and_then(normalize_permission_mode)
|
||||
.map_or(PermissionMode::WorkspaceWrite, permission_mode_from_label)
|
||||
}
|
||||
|
||||
fn filter_tool_specs(allowed_tools: Option<&AllowedToolSet>) -> Vec<tools::ToolSpec> {
|
||||
mvp_tool_specs()
|
||||
.into_iter()
|
||||
@@ -346,6 +404,122 @@ fn print_bootstrap_plan() {
|
||||
}
|
||||
}
|
||||
|
||||
fn run_login() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let cwd = env::current_dir()?;
|
||||
let config = ConfigLoader::default_for(&cwd).load()?;
|
||||
let oauth = config.oauth().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"OAuth config is missing. Add settings.oauth.clientId/authorizeUrl/tokenUrl first.",
|
||||
)
|
||||
})?;
|
||||
let callback_port = oauth.callback_port.unwrap_or(DEFAULT_OAUTH_CALLBACK_PORT);
|
||||
let redirect_uri = runtime::loopback_redirect_uri(callback_port);
|
||||
let pkce = generate_pkce_pair()?;
|
||||
let state = generate_state()?;
|
||||
let authorize_url =
|
||||
OAuthAuthorizationRequest::from_config(oauth, redirect_uri.clone(), state.clone(), &pkce)
|
||||
.build_url();
|
||||
|
||||
println!("Starting Claude OAuth login...");
|
||||
println!("Listening for callback on {redirect_uri}");
|
||||
if let Err(error) = open_browser(&authorize_url) {
|
||||
eprintln!("warning: failed to open browser automatically: {error}");
|
||||
println!("Open this URL manually:\n{authorize_url}");
|
||||
}
|
||||
|
||||
let callback = wait_for_oauth_callback(callback_port)?;
|
||||
if let Some(error) = callback.error {
|
||||
let description = callback
|
||||
.error_description
|
||||
.unwrap_or_else(|| "authorization failed".to_string());
|
||||
return Err(io::Error::other(format!("{error}: {description}")).into());
|
||||
}
|
||||
let code = callback.code.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "callback did not include code")
|
||||
})?;
|
||||
let returned_state = callback.state.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "callback did not include state")
|
||||
})?;
|
||||
if returned_state != state {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into());
|
||||
}
|
||||
|
||||
let client = AnthropicClient::from_auth(AuthSource::None);
|
||||
let exchange_request =
|
||||
OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri);
|
||||
let runtime = tokio::runtime::Runtime::new()?;
|
||||
let token_set = runtime.block_on(client.exchange_oauth_code(oauth, &exchange_request))?;
|
||||
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||
access_token: token_set.access_token,
|
||||
refresh_token: token_set.refresh_token,
|
||||
expires_at: token_set.expires_at,
|
||||
scopes: token_set.scopes,
|
||||
})?;
|
||||
println!("Claude OAuth login complete.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_logout() -> Result<(), Box<dyn std::error::Error>> {
|
||||
clear_oauth_credentials()?;
|
||||
println!("Claude OAuth credentials cleared.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn open_browser(url: &str) -> io::Result<()> {
|
||||
let commands = if cfg!(target_os = "macos") {
|
||||
vec![("open", vec![url])]
|
||||
} else if cfg!(target_os = "windows") {
|
||||
vec![("cmd", vec!["/C", "start", "", url])]
|
||||
} else {
|
||||
vec![("xdg-open", vec![url])]
|
||||
};
|
||||
for (program, args) in commands {
|
||||
match Command::new(program).args(args).spawn() {
|
||||
Ok(_) => return Ok(()),
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => {}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
}
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"no supported browser opener command found",
|
||||
))
|
||||
}
|
||||
|
||||
fn wait_for_oauth_callback(
|
||||
port: u16,
|
||||
) -> Result<runtime::OAuthCallbackParams, Box<dyn std::error::Error>> {
|
||||
let listener = TcpListener::bind(("127.0.0.1", port))?;
|
||||
let (mut stream, _) = listener.accept()?;
|
||||
let mut buffer = [0_u8; 4096];
|
||||
let bytes_read = stream.read(&mut buffer)?;
|
||||
let request = String::from_utf8_lossy(&buffer[..bytes_read]);
|
||||
let request_line = request.lines().next().ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "missing callback request line")
|
||||
})?;
|
||||
let target = request_line.split_whitespace().nth(1).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"missing callback request target",
|
||||
)
|
||||
})?;
|
||||
let callback = parse_oauth_callback_request_target(target)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
let body = if callback.error.is_some() {
|
||||
"Claude OAuth login failed. You can close this window."
|
||||
} else {
|
||||
"Claude OAuth login succeeded. You can close this window."
|
||||
};
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\ncontent-type: text/plain; charset=utf-8\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
|
||||
body.len(),
|
||||
body
|
||||
);
|
||||
stream.write_all(response.as_bytes())?;
|
||||
Ok(callback)
|
||||
}
|
||||
|
||||
fn print_system_prompt(cwd: PathBuf, date: String) {
|
||||
match load_system_prompt(cwd, date, env::consts::OS, "unknown") {
|
||||
Ok(sections) => println!("{}", sections.join("\n\n")),
|
||||
@@ -661,7 +835,7 @@ fn run_resume_command(
|
||||
cumulative: usage,
|
||||
estimated_tokens: 0,
|
||||
},
|
||||
permission_mode_label(),
|
||||
default_permission_mode().as_str(),
|
||||
&status_context(Some(session_path))?,
|
||||
)),
|
||||
})
|
||||
@@ -716,8 +890,9 @@ fn run_resume_command(
|
||||
fn run_repl(
|
||||
model: String,
|
||||
allowed_tools: Option<AllowedToolSet>,
|
||||
permission_mode: PermissionMode,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut cli = LiveCli::new(model, true, allowed_tools)?;
|
||||
let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?;
|
||||
let mut editor = input::LineEditor::new("› ", slash_command_completion_candidates());
|
||||
println!("{}", cli.startup_banner());
|
||||
|
||||
@@ -769,6 +944,7 @@ struct ManagedSessionSummary {
|
||||
struct LiveCli {
|
||||
model: String,
|
||||
allowed_tools: Option<AllowedToolSet>,
|
||||
permission_mode: PermissionMode,
|
||||
system_prompt: Vec<String>,
|
||||
runtime: ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>,
|
||||
session: SessionHandle,
|
||||
@@ -779,6 +955,7 @@ impl LiveCli {
|
||||
model: String,
|
||||
enable_tools: bool,
|
||||
allowed_tools: Option<AllowedToolSet>,
|
||||
permission_mode: PermissionMode,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let system_prompt = build_system_prompt()?;
|
||||
let session = create_managed_session_handle()?;
|
||||
@@ -788,10 +965,12 @@ impl LiveCli {
|
||||
system_prompt.clone(),
|
||||
enable_tools,
|
||||
allowed_tools.clone(),
|
||||
permission_mode,
|
||||
)?;
|
||||
let cli = Self {
|
||||
model,
|
||||
allowed_tools,
|
||||
permission_mode,
|
||||
system_prompt,
|
||||
runtime,
|
||||
session,
|
||||
@@ -802,8 +981,9 @@ impl LiveCli {
|
||||
|
||||
fn startup_banner(&self) -> String {
|
||||
format!(
|
||||
"Rusty Claude CLI\n Model {}\n Working directory {}\n Session {}\n\nType /help for commands. Shift+Enter or Ctrl+J inserts a newline.",
|
||||
"Rusty Claude CLI\n Model {}\n Permission mode {}\n Working directory {}\n Session {}\n\nType /help for commands. Shift+Enter or Ctrl+J inserts a newline.",
|
||||
self.model,
|
||||
self.permission_mode.as_str(),
|
||||
env::current_dir().map_or_else(
|
||||
|_| "<unknown>".to_string(),
|
||||
|path| path.display().to_string(),
|
||||
@@ -820,7 +1000,8 @@ impl LiveCli {
|
||||
TerminalRenderer::new().color_theme(),
|
||||
&mut stdout,
|
||||
)?;
|
||||
let result = self.runtime.run_turn(input, None);
|
||||
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
|
||||
let result = self.runtime.run_turn(input, Some(&mut permission_prompter));
|
||||
match result {
|
||||
Ok(_) => {
|
||||
spinner.finish(
|
||||
@@ -855,7 +1036,7 @@ impl LiveCli {
|
||||
}
|
||||
|
||||
fn run_prompt_json(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = AnthropicClient::from_env()?;
|
||||
let client = AnthropicClient::from_auth(resolve_cli_auth_source()?);
|
||||
let request = MessageRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens: DEFAULT_MAX_TOKENS,
|
||||
@@ -975,7 +1156,7 @@ impl LiveCli {
|
||||
cumulative,
|
||||
estimated_tokens: self.runtime.estimated_tokens(),
|
||||
},
|
||||
permission_mode_label(),
|
||||
self.permission_mode.as_str(),
|
||||
&status_context(Some(&self.session.path)).expect("status context should load"),
|
||||
)
|
||||
);
|
||||
@@ -1015,6 +1196,7 @@ impl LiveCli {
|
||||
self.system_prompt.clone(),
|
||||
true,
|
||||
self.allowed_tools.clone(),
|
||||
self.permission_mode,
|
||||
)?;
|
||||
self.model.clone_from(&model);
|
||||
println!(
|
||||
@@ -1029,7 +1211,10 @@ impl LiveCli {
|
||||
mode: Option<String>,
|
||||
) -> Result<bool, Box<dyn std::error::Error>> {
|
||||
let Some(mode) = mode else {
|
||||
println!("{}", format_permissions_report(permission_mode_label()));
|
||||
println!(
|
||||
"{}",
|
||||
format_permissions_report(self.permission_mode.as_str())
|
||||
);
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
@@ -1039,20 +1224,21 @@ impl LiveCli {
|
||||
)
|
||||
})?;
|
||||
|
||||
if normalized == permission_mode_label() {
|
||||
if normalized == self.permission_mode.as_str() {
|
||||
println!("{}", format_permissions_report(normalized));
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let previous = permission_mode_label().to_string();
|
||||
let previous = self.permission_mode.as_str().to_string();
|
||||
let session = self.runtime.session().clone();
|
||||
self.runtime = build_runtime_with_permission_mode(
|
||||
self.permission_mode = permission_mode_from_label(normalized);
|
||||
self.runtime = build_runtime(
|
||||
session,
|
||||
self.model.clone(),
|
||||
self.system_prompt.clone(),
|
||||
true,
|
||||
self.allowed_tools.clone(),
|
||||
normalized,
|
||||
self.permission_mode,
|
||||
)?;
|
||||
println!(
|
||||
"{}",
|
||||
@@ -1070,18 +1256,18 @@ impl LiveCli {
|
||||
}
|
||||
|
||||
self.session = create_managed_session_handle()?;
|
||||
self.runtime = build_runtime_with_permission_mode(
|
||||
self.runtime = build_runtime(
|
||||
Session::new(),
|
||||
self.model.clone(),
|
||||
self.system_prompt.clone(),
|
||||
true,
|
||||
self.allowed_tools.clone(),
|
||||
permission_mode_label(),
|
||||
self.permission_mode,
|
||||
)?;
|
||||
println!(
|
||||
"Session cleared\n Mode fresh session\n Preserved model {}\n Permission mode {}\n Session {}",
|
||||
self.model,
|
||||
permission_mode_label(),
|
||||
self.permission_mode.as_str(),
|
||||
self.session.id,
|
||||
);
|
||||
Ok(true)
|
||||
@@ -1104,13 +1290,13 @@ impl LiveCli {
|
||||
let handle = resolve_session_reference(&session_ref)?;
|
||||
let session = Session::load_from_path(&handle.path)?;
|
||||
let message_count = session.messages.len();
|
||||
self.runtime = build_runtime_with_permission_mode(
|
||||
self.runtime = build_runtime(
|
||||
session,
|
||||
self.model.clone(),
|
||||
self.system_prompt.clone(),
|
||||
true,
|
||||
self.allowed_tools.clone(),
|
||||
permission_mode_label(),
|
||||
self.permission_mode,
|
||||
)?;
|
||||
self.session = handle;
|
||||
println!(
|
||||
@@ -1180,13 +1366,13 @@ impl LiveCli {
|
||||
let handle = resolve_session_reference(target)?;
|
||||
let session = Session::load_from_path(&handle.path)?;
|
||||
let message_count = session.messages.len();
|
||||
self.runtime = build_runtime_with_permission_mode(
|
||||
self.runtime = build_runtime(
|
||||
session,
|
||||
self.model.clone(),
|
||||
self.system_prompt.clone(),
|
||||
true,
|
||||
self.allowed_tools.clone(),
|
||||
permission_mode_label(),
|
||||
self.permission_mode,
|
||||
)?;
|
||||
self.session = handle;
|
||||
println!(
|
||||
@@ -1209,13 +1395,13 @@ impl LiveCli {
|
||||
let removed = result.removed_message_count;
|
||||
let kept = result.compacted_session.messages.len();
|
||||
let skipped = removed == 0;
|
||||
self.runtime = build_runtime_with_permission_mode(
|
||||
self.runtime = build_runtime(
|
||||
result.compacted_session,
|
||||
self.model.clone(),
|
||||
self.system_prompt.clone(),
|
||||
true,
|
||||
self.allowed_tools.clone(),
|
||||
permission_mode_label(),
|
||||
self.permission_mode,
|
||||
)?;
|
||||
self.persist_session()?;
|
||||
println!("{}", format_compact_report(removed, kept, skipped));
|
||||
@@ -1356,7 +1542,8 @@ fn status_context(
|
||||
session_path: session_path.map(Path::to_path_buf),
|
||||
loaded_config_files: runtime_config.loaded_entries().len(),
|
||||
discovered_config_files,
|
||||
memory_file_count: project_context.instruction_files.len(),
|
||||
memory_file_count: project_context.instruction_files.len()
|
||||
+ project_context.memory_files.len(),
|
||||
project_root,
|
||||
git_branch,
|
||||
})
|
||||
@@ -1501,39 +1688,58 @@ fn render_memory_report() -> Result<String, Box<dyn std::error::Error>> {
|
||||
let mut lines = vec![format!(
|
||||
"Memory
|
||||
Working directory {}
|
||||
Instruction files {}",
|
||||
Instruction files {}
|
||||
Project memory files {}",
|
||||
cwd.display(),
|
||||
project_context.instruction_files.len()
|
||||
project_context.instruction_files.len(),
|
||||
project_context.memory_files.len()
|
||||
)];
|
||||
if project_context.instruction_files.is_empty() {
|
||||
lines.push("Discovered files".to_string());
|
||||
lines.push(
|
||||
" No CLAUDE instruction files discovered in the current directory ancestry."
|
||||
.to_string(),
|
||||
);
|
||||
} else {
|
||||
lines.push("Discovered files".to_string());
|
||||
for (index, file) in project_context.instruction_files.iter().enumerate() {
|
||||
let preview = file.content.lines().next().unwrap_or("").trim();
|
||||
let preview = if preview.is_empty() {
|
||||
"<empty>"
|
||||
} else {
|
||||
preview
|
||||
};
|
||||
lines.push(format!(" {}. {}", index + 1, file.path.display(),));
|
||||
lines.push(format!(
|
||||
" lines={} preview={}",
|
||||
file.content.lines().count(),
|
||||
preview
|
||||
));
|
||||
}
|
||||
}
|
||||
append_memory_section(
|
||||
&mut lines,
|
||||
"Instruction files",
|
||||
&project_context.instruction_files,
|
||||
"No CLAUDE instruction files discovered in the current directory ancestry.",
|
||||
);
|
||||
append_memory_section(
|
||||
&mut lines,
|
||||
"Project memory files",
|
||||
&project_context.memory_files,
|
||||
"No persisted project memory files discovered in .claude/memory.",
|
||||
);
|
||||
Ok(lines.join(
|
||||
"
|
||||
",
|
||||
))
|
||||
}
|
||||
|
||||
fn append_memory_section(
|
||||
lines: &mut Vec<String>,
|
||||
title: &str,
|
||||
files: &[runtime::ContextFile],
|
||||
empty_message: &str,
|
||||
) {
|
||||
lines.push(title.to_string());
|
||||
if files.is_empty() {
|
||||
lines.push(format!(" {empty_message}"));
|
||||
return;
|
||||
}
|
||||
|
||||
for (index, file) in files.iter().enumerate() {
|
||||
let preview = file.content.lines().next().unwrap_or("").trim();
|
||||
let preview = if preview.is_empty() {
|
||||
"<empty>"
|
||||
} else {
|
||||
preview
|
||||
};
|
||||
lines.push(format!(" {}. {}", index + 1, file.path.display()));
|
||||
lines.push(format!(
|
||||
" lines={} preview={}",
|
||||
file.content.lines().count(),
|
||||
preview
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
fn init_claude_md() -> Result<String, Box<dyn std::error::Error>> {
|
||||
let cwd = env::current_dir()?;
|
||||
let claude_md = cwd.join("CLAUDE.md");
|
||||
@@ -1608,14 +1814,6 @@ fn normalize_permission_mode(mode: &str) -> Option<&'static str> {
|
||||
}
|
||||
}
|
||||
|
||||
fn permission_mode_label() -> &'static str {
|
||||
match env::var("RUSTY_CLAUDE_PERMISSION_MODE") {
|
||||
Ok(value) if value == "read-only" => "read-only",
|
||||
Ok(value) if value == "danger-full-access" => "danger-full-access",
|
||||
_ => "workspace-write",
|
||||
}
|
||||
}
|
||||
|
||||
fn render_diff_report() -> Result<String, Box<dyn std::error::Error>> {
|
||||
let output = std::process::Command::new("git")
|
||||
.args(["diff", "--", ":(exclude).omx"])
|
||||
@@ -1745,25 +1943,7 @@ fn build_runtime(
|
||||
system_prompt: Vec<String>,
|
||||
enable_tools: bool,
|
||||
allowed_tools: Option<AllowedToolSet>,
|
||||
) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
|
||||
{
|
||||
build_runtime_with_permission_mode(
|
||||
session,
|
||||
model,
|
||||
system_prompt,
|
||||
enable_tools,
|
||||
allowed_tools,
|
||||
permission_mode_label(),
|
||||
)
|
||||
}
|
||||
|
||||
fn build_runtime_with_permission_mode(
|
||||
session: Session,
|
||||
model: String,
|
||||
system_prompt: Vec<String>,
|
||||
enable_tools: bool,
|
||||
allowed_tools: Option<AllowedToolSet>,
|
||||
permission_mode: &str,
|
||||
permission_mode: PermissionMode,
|
||||
) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
|
||||
{
|
||||
Ok(ConversationRuntime::new(
|
||||
@@ -1775,6 +1955,52 @@ fn build_runtime_with_permission_mode(
|
||||
))
|
||||
}
|
||||
|
||||
struct CliPermissionPrompter {
|
||||
current_mode: PermissionMode,
|
||||
}
|
||||
|
||||
impl CliPermissionPrompter {
|
||||
fn new(current_mode: PermissionMode) -> Self {
|
||||
Self { current_mode }
|
||||
}
|
||||
}
|
||||
|
||||
impl runtime::PermissionPrompter for CliPermissionPrompter {
|
||||
fn decide(
|
||||
&mut self,
|
||||
request: &runtime::PermissionRequest,
|
||||
) -> runtime::PermissionPromptDecision {
|
||||
println!();
|
||||
println!("Permission approval required");
|
||||
println!(" Tool {}", request.tool_name);
|
||||
println!(" Current mode {}", self.current_mode.as_str());
|
||||
println!(" Required mode {}", request.required_mode.as_str());
|
||||
println!(" Input {}", request.input);
|
||||
print!("Approve this tool call? [y/N]: ");
|
||||
let _ = io::stdout().flush();
|
||||
|
||||
let mut response = String::new();
|
||||
match io::stdin().read_line(&mut response) {
|
||||
Ok(_) => {
|
||||
let normalized = response.trim().to_ascii_lowercase();
|
||||
if matches!(normalized.as_str(), "y" | "yes") {
|
||||
runtime::PermissionPromptDecision::Allow
|
||||
} else {
|
||||
runtime::PermissionPromptDecision::Deny {
|
||||
reason: format!(
|
||||
"tool '{}' denied by user approval prompt",
|
||||
request.tool_name
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => runtime::PermissionPromptDecision::Deny {
|
||||
reason: format!("permission approval failed: {error}"),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AnthropicRuntimeClient {
|
||||
runtime: tokio::runtime::Runtime,
|
||||
client: AnthropicClient,
|
||||
@@ -1791,7 +2017,7 @@ impl AnthropicRuntimeClient {
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
Ok(Self {
|
||||
runtime: tokio::runtime::Runtime::new()?,
|
||||
client: AnthropicClient::from_env()?,
|
||||
client: AnthropicClient::from_auth(resolve_cli_auth_source()?),
|
||||
model,
|
||||
enable_tools,
|
||||
allowed_tools,
|
||||
@@ -1799,6 +2025,16 @@ impl AnthropicRuntimeClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_cli_auth_source() -> Result<AuthSource, Box<dyn std::error::Error>> {
|
||||
Ok(resolve_startup_auth_source(|| {
|
||||
let cwd = env::current_dir().map_err(api::ApiError::from)?;
|
||||
let config = ConfigLoader::default_for(&cwd).load().map_err(|error| {
|
||||
api::ApiError::Auth(format!("failed to load runtime OAuth config: {error}"))
|
||||
})?;
|
||||
Ok(config.oauth().cloned())
|
||||
})?)
|
||||
}
|
||||
|
||||
impl ApiClient for AnthropicRuntimeClient {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
||||
@@ -2072,15 +2308,16 @@ impl ToolExecutor for CliToolExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
fn permission_policy(mode: &str) -> PermissionPolicy {
|
||||
if normalize_permission_mode(mode) == Some("read-only") {
|
||||
PermissionPolicy::new(PermissionMode::Deny)
|
||||
.with_tool_mode("read_file", PermissionMode::Allow)
|
||||
.with_tool_mode("glob_search", PermissionMode::Allow)
|
||||
.with_tool_mode("grep_search", PermissionMode::Allow)
|
||||
} else {
|
||||
PermissionPolicy::new(PermissionMode::Allow)
|
||||
}
|
||||
fn permission_policy(mode: PermissionMode) -> PermissionPolicy {
|
||||
tool_permission_specs()
|
||||
.into_iter()
|
||||
.fold(PermissionPolicy::new(mode), |policy, spec| {
|
||||
policy.with_tool_requirement(spec.name, spec.required_permission)
|
||||
})
|
||||
}
|
||||
|
||||
fn tool_permission_specs() -> Vec<ToolSpec> {
|
||||
mvp_tool_specs()
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
|
||||
@@ -2139,10 +2376,13 @@ fn print_help() {
|
||||
println!(" rusty-claude-cli dump-manifests");
|
||||
println!(" rusty-claude-cli bootstrap-plan");
|
||||
println!(" rusty-claude-cli system-prompt [--cwd PATH] [--date YYYY-MM-DD]");
|
||||
println!(" rusty-claude-cli login");
|
||||
println!(" rusty-claude-cli logout");
|
||||
println!();
|
||||
println!("Flags:");
|
||||
println!(" --model MODEL Override the active model");
|
||||
println!(" --output-format FORMAT Non-interactive output format: text or json");
|
||||
println!(" --permission-mode MODE Set read-only, workspace-write, or danger-full-access");
|
||||
println!(" --allowedTools TOOLS Restrict enabled tools (repeatable; comma-separated aliases supported)");
|
||||
println!(" --version, -V Print version and build information locally");
|
||||
println!();
|
||||
@@ -2163,6 +2403,7 @@ fn print_help() {
|
||||
println!(" rusty-claude-cli --output-format json prompt \"explain src/main.rs\"");
|
||||
println!(" rusty-claude-cli --allowedTools read,glob \"summarize Cargo.toml\"");
|
||||
println!(" rusty-claude-cli --resume session.json /status /diff /export notes.txt");
|
||||
println!(" rusty-claude-cli login");
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -2176,7 +2417,7 @@ mod tests {
|
||||
render_memory_report, render_repl_help, resume_supported_slash_commands, status_context,
|
||||
CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL,
|
||||
};
|
||||
use runtime::{ContentBlock, ConversationMessage, MessageRole};
|
||||
use runtime::{ContentBlock, ConversationMessage, MessageRole, PermissionMode};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[test]
|
||||
@@ -2186,6 +2427,7 @@ mod tests {
|
||||
CliAction::Repl {
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
allowed_tools: None,
|
||||
permission_mode: PermissionMode::WorkspaceWrite,
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -2204,6 +2446,7 @@ mod tests {
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
output_format: CliOutputFormat::Text,
|
||||
allowed_tools: None,
|
||||
permission_mode: PermissionMode::WorkspaceWrite,
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -2224,6 +2467,7 @@ mod tests {
|
||||
model: "claude-opus".to_string(),
|
||||
output_format: CliOutputFormat::Json,
|
||||
allowed_tools: None,
|
||||
permission_mode: PermissionMode::WorkspaceWrite,
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -2240,6 +2484,19 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_permission_mode_flag() {
|
||||
let args = vec!["--permission-mode=read-only".to_string()];
|
||||
assert_eq!(
|
||||
parse_args(&args).expect("args should parse"),
|
||||
CliAction::Repl {
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
allowed_tools: None,
|
||||
permission_mode: PermissionMode::ReadOnly,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_allowed_tools_flags_with_aliases_and_lists() {
|
||||
let args = vec![
|
||||
@@ -2257,6 +2514,7 @@ mod tests {
|
||||
.map(str::to_string)
|
||||
.collect()
|
||||
),
|
||||
permission_mode: PermissionMode::WorkspaceWrite,
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -2286,6 +2544,18 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_login_and_logout_subcommands() {
|
||||
assert_eq!(
|
||||
parse_args(&["login".to_string()]).expect("login should parse"),
|
||||
CliAction::Login
|
||||
);
|
||||
assert_eq!(
|
||||
parse_args(&["logout".to_string()]).expect("logout should parse"),
|
||||
CliAction::Logout
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_resume_flag_with_slash_command() {
|
||||
let args = vec![
|
||||
@@ -2522,7 +2792,7 @@ mod tests {
|
||||
assert!(report.contains("Memory"));
|
||||
assert!(report.contains("Working directory"));
|
||||
assert!(report.contains("Instruction files"));
|
||||
assert!(report.contains("Discovered files"));
|
||||
assert!(report.contains("Project memory files"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2547,7 +2817,7 @@ mod tests {
|
||||
fn status_context_reads_real_workspace_metadata() {
|
||||
let context = status_context(None).expect("status context should load");
|
||||
assert!(context.cwd.is_absolute());
|
||||
assert_eq!(context.discovered_config_files, 3);
|
||||
assert!(context.discovered_config_files >= 3);
|
||||
assert!(context.loaded_config_files <= context.discovered_config_files);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::time::{Duration, Instant};
|
||||
use reqwest::blocking::Client;
|
||||
use runtime::{
|
||||
edit_file, execute_bash, glob_search, grep_search, read_file, write_file, BashCommandInput,
|
||||
GrepSearchInput,
|
||||
GrepSearchInput, PermissionMode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
@@ -45,6 +45,7 @@ pub struct ToolSpec {
|
||||
pub name: &'static str,
|
||||
pub description: &'static str,
|
||||
pub input_schema: Value,
|
||||
pub required_permission: PermissionMode,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@@ -66,6 +67,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["command"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::DangerFullAccess,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "read_file",
|
||||
@@ -80,6 +82,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["path"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "write_file",
|
||||
@@ -93,6 +96,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["path", "content"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::WorkspaceWrite,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "edit_file",
|
||||
@@ -108,6 +112,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["path", "old_string", "new_string"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::WorkspaceWrite,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "glob_search",
|
||||
@@ -121,6 +126,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["pattern"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "grep_search",
|
||||
@@ -146,6 +152,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["pattern"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "WebFetch",
|
||||
@@ -160,6 +167,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["url", "prompt"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "WebSearch",
|
||||
@@ -180,6 +188,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["query"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "TodoWrite",
|
||||
@@ -207,6 +216,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["todos"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::WorkspaceWrite,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "Skill",
|
||||
@@ -220,6 +230,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["skill"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "Agent",
|
||||
@@ -236,6 +247,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["description", "prompt"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::DangerFullAccess,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "ToolSearch",
|
||||
@@ -249,6 +261,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["query"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "NotebookEdit",
|
||||
@@ -265,6 +278,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["notebook_path"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::WorkspaceWrite,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "Sleep",
|
||||
@@ -277,6 +291,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["duration_ms"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "SendUserMessage",
|
||||
@@ -297,6 +312,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["message", "status"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "Config",
|
||||
@@ -312,6 +328,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["setting"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::WorkspaceWrite,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "StructuredOutput",
|
||||
@@ -320,6 +337,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
}),
|
||||
required_permission: PermissionMode::ReadOnly,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "REPL",
|
||||
@@ -334,6 +352,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["code", "language"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::DangerFullAccess,
|
||||
},
|
||||
ToolSpec {
|
||||
name: "PowerShell",
|
||||
@@ -349,6 +368,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
||||
"required": ["command"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
required_permission: PermissionMode::DangerFullAccess,
|
||||
},
|
||||
]
|
||||
}
|
||||
@@ -1179,10 +1199,9 @@ fn execute_todo_write(input: TodoWriteInput) -> Result<TodoWriteOutput, String>
|
||||
validate_todos(&input.todos)?;
|
||||
let store_path = todo_store_path()?;
|
||||
let old_todos = if store_path.exists() {
|
||||
serde_json::from_str::<Vec<TodoItem>>(
|
||||
parse_todo_markdown(
|
||||
&std::fs::read_to_string(&store_path).map_err(|error| error.to_string())?,
|
||||
)
|
||||
.map_err(|error| error.to_string())?
|
||||
)?
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
@@ -1200,11 +1219,8 @@ fn execute_todo_write(input: TodoWriteInput) -> Result<TodoWriteOutput, String>
|
||||
if let Some(parent) = store_path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|error| error.to_string())?;
|
||||
}
|
||||
std::fs::write(
|
||||
&store_path,
|
||||
serde_json::to_string_pretty(&persisted).map_err(|error| error.to_string())?,
|
||||
)
|
||||
.map_err(|error| error.to_string())?;
|
||||
std::fs::write(&store_path, render_todo_markdown(&persisted))
|
||||
.map_err(|error| error.to_string())?;
|
||||
|
||||
let verification_nudge_needed = (all_done
|
||||
&& input.todos.len() >= 3
|
||||
@@ -1262,7 +1278,58 @@ fn todo_store_path() -> Result<std::path::PathBuf, String> {
|
||||
return Ok(std::path::PathBuf::from(path));
|
||||
}
|
||||
let cwd = std::env::current_dir().map_err(|error| error.to_string())?;
|
||||
Ok(cwd.join(".clawd-todos.json"))
|
||||
Ok(cwd.join(".claude").join("todos.md"))
|
||||
}
|
||||
|
||||
fn render_todo_markdown(todos: &[TodoItem]) -> String {
|
||||
let mut lines = vec!["# Todo list".to_string(), String::new()];
|
||||
for todo in todos {
|
||||
let marker = match todo.status {
|
||||
TodoStatus::Pending => "[ ]",
|
||||
TodoStatus::InProgress => "[~]",
|
||||
TodoStatus::Completed => "[x]",
|
||||
};
|
||||
lines.push(format!(
|
||||
"- {marker} {} :: {}",
|
||||
todo.content, todo.active_form
|
||||
));
|
||||
}
|
||||
lines.push(String::new());
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn parse_todo_markdown(content: &str) -> Result<Vec<TodoItem>, String> {
|
||||
let mut todos = Vec::new();
|
||||
for line in content.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() || trimmed.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
let Some(rest) = trimmed.strip_prefix("- [") else {
|
||||
continue;
|
||||
};
|
||||
let mut chars = rest.chars();
|
||||
let status = match chars.next() {
|
||||
Some(' ') => TodoStatus::Pending,
|
||||
Some('~') => TodoStatus::InProgress,
|
||||
Some('x' | 'X') => TodoStatus::Completed,
|
||||
Some(other) => return Err(format!("unsupported todo status marker: {other}")),
|
||||
None => return Err(String::from("malformed todo line")),
|
||||
};
|
||||
let remainder = chars.as_str();
|
||||
let Some(body) = remainder.strip_prefix("] ") else {
|
||||
return Err(String::from("malformed todo line"));
|
||||
};
|
||||
let Some((content, active_form)) = body.split_once(" :: ") else {
|
||||
return Err(String::from("todo line missing active form separator"));
|
||||
};
|
||||
todos.push(TodoItem {
|
||||
content: content.trim().to_string(),
|
||||
active_form: active_form.trim().to_string(),
|
||||
status,
|
||||
});
|
||||
}
|
||||
Ok(todos)
|
||||
}
|
||||
|
||||
fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> {
|
||||
@@ -2349,8 +2416,10 @@ fn parse_skill_description(contents: &str) -> Option<String> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use std::io::{Read, Write};
|
||||
use std::net::{SocketAddr, TcpListener};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
@@ -2363,6 +2432,14 @@ mod tests {
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
fn temp_path(name: &str) -> PathBuf {
|
||||
let unique = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("clawd-tools-{unique}-{name}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exposes_mvp_tools() {
|
||||
let names = mvp_tool_specs()
|
||||
@@ -2432,6 +2509,40 @@ mod tests {
|
||||
assert!(titled_summary.contains("Title: Ignored"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn web_fetch_supports_plain_text_and_rejects_invalid_url() {
|
||||
let server = TestServer::spawn(Arc::new(|request_line: &str| {
|
||||
assert!(request_line.starts_with("GET /plain "));
|
||||
HttpResponse::text(200, "OK", "plain text response")
|
||||
}));
|
||||
|
||||
let result = execute_tool(
|
||||
"WebFetch",
|
||||
&json!({
|
||||
"url": format!("http://{}/plain", server.addr()),
|
||||
"prompt": "Show me the content"
|
||||
}),
|
||||
)
|
||||
.expect("WebFetch should succeed for text content");
|
||||
|
||||
let output: serde_json::Value = serde_json::from_str(&result).expect("valid json");
|
||||
assert_eq!(output["url"], format!("http://{}/plain", server.addr()));
|
||||
assert!(output["result"]
|
||||
.as_str()
|
||||
.expect("result")
|
||||
.contains("plain text response"));
|
||||
|
||||
let error = execute_tool(
|
||||
"WebFetch",
|
||||
&json!({
|
||||
"url": "not a url",
|
||||
"prompt": "Summarize"
|
||||
}),
|
||||
)
|
||||
.expect_err("invalid URL should fail");
|
||||
assert!(error.contains("relative URL without a base") || error.contains("invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn web_search_extracts_and_filters_results() {
|
||||
let server = TestServer::spawn(Arc::new(|request_line: &str| {
|
||||
@@ -2476,15 +2587,63 @@ mod tests {
|
||||
assert_eq!(content[0]["url"], "https://docs.rs/reqwest");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn web_search_handles_generic_links_and_invalid_base_url() {
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let server = TestServer::spawn(Arc::new(|request_line: &str| {
|
||||
assert!(request_line.contains("GET /fallback?q=generic+links "));
|
||||
HttpResponse::html(
|
||||
200,
|
||||
"OK",
|
||||
r#"
|
||||
<html><body>
|
||||
<a href="https://example.com/one">Example One</a>
|
||||
<a href="https://example.com/one">Duplicate Example One</a>
|
||||
<a href="https://docs.rs/tokio">Tokio Docs</a>
|
||||
</body></html>
|
||||
"#,
|
||||
)
|
||||
}));
|
||||
|
||||
std::env::set_var(
|
||||
"CLAWD_WEB_SEARCH_BASE_URL",
|
||||
format!("http://{}/fallback", server.addr()),
|
||||
);
|
||||
let result = execute_tool(
|
||||
"WebSearch",
|
||||
&json!({
|
||||
"query": "generic links"
|
||||
}),
|
||||
)
|
||||
.expect("WebSearch fallback parsing should succeed");
|
||||
std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL");
|
||||
|
||||
let output: serde_json::Value = serde_json::from_str(&result).expect("valid json");
|
||||
let results = output["results"].as_array().expect("results array");
|
||||
let search_result = results
|
||||
.iter()
|
||||
.find(|item| item.get("content").is_some())
|
||||
.expect("search result block present");
|
||||
let content = search_result["content"].as_array().expect("content array");
|
||||
assert_eq!(content.len(), 2);
|
||||
assert_eq!(content[0]["url"], "https://example.com/one");
|
||||
assert_eq!(content[1]["url"], "https://docs.rs/tokio");
|
||||
|
||||
std::env::set_var("CLAWD_WEB_SEARCH_BASE_URL", "://bad-base-url");
|
||||
let error = execute_tool("WebSearch", &json!({ "query": "generic links" }))
|
||||
.expect_err("invalid base URL should fail");
|
||||
std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL");
|
||||
assert!(error.contains("relative URL without a base") || error.contains("empty host"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn todo_write_persists_and_returns_previous_state() {
|
||||
let path = std::env::temp_dir().join(format!(
|
||||
"clawd-tools-todos-{}.json",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let path = temp_path("todos.json");
|
||||
std::env::set_var("CLAWD_TODO_STORE", &path);
|
||||
|
||||
let first = execute_tool(
|
||||
@@ -2526,6 +2685,90 @@ mod tests {
|
||||
assert!(second_output["verificationNudgeNeeded"].is_null());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn todo_write_persists_markdown_in_claude_directory() {
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let temp = temp_path("todos-md-dir");
|
||||
std::fs::create_dir_all(&temp).expect("temp dir");
|
||||
let previous = std::env::current_dir().expect("cwd");
|
||||
std::env::set_current_dir(&temp).expect("set cwd");
|
||||
|
||||
execute_tool(
|
||||
"TodoWrite",
|
||||
&json!({
|
||||
"todos": [
|
||||
{"content": "Add tool", "activeForm": "Adding tool", "status": "in_progress"},
|
||||
{"content": "Run tests", "activeForm": "Running tests", "status": "pending"}
|
||||
]
|
||||
}),
|
||||
)
|
||||
.expect("TodoWrite should succeed");
|
||||
|
||||
let persisted = std::fs::read_to_string(temp.join(".claude").join("todos.md"))
|
||||
.expect("todo markdown exists");
|
||||
std::env::set_current_dir(previous).expect("restore cwd");
|
||||
let _ = std::fs::remove_dir_all(temp);
|
||||
|
||||
assert!(persisted.contains("# Todo list"));
|
||||
assert!(persisted.contains("- [~] Add tool :: Adding tool"));
|
||||
assert!(persisted.contains("- [ ] Run tests :: Running tests"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn todo_write_rejects_invalid_payloads_and_sets_verification_nudge() {
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let path = temp_path("todos-errors.json");
|
||||
std::env::set_var("CLAWD_TODO_STORE", &path);
|
||||
|
||||
let empty = execute_tool("TodoWrite", &json!({ "todos": [] }))
|
||||
.expect_err("empty todos should fail");
|
||||
assert!(empty.contains("todos must not be empty"));
|
||||
|
||||
let too_many_active = execute_tool(
|
||||
"TodoWrite",
|
||||
&json!({
|
||||
"todos": [
|
||||
{"content": "One", "activeForm": "Doing one", "status": "in_progress"},
|
||||
{"content": "Two", "activeForm": "Doing two", "status": "in_progress"}
|
||||
]
|
||||
}),
|
||||
)
|
||||
.expect_err("multiple in-progress todos should fail");
|
||||
assert!(too_many_active.contains("zero or one todo items may be in_progress"));
|
||||
|
||||
let blank_content = execute_tool(
|
||||
"TodoWrite",
|
||||
&json!({
|
||||
"todos": [
|
||||
{"content": " ", "activeForm": "Doing it", "status": "pending"}
|
||||
]
|
||||
}),
|
||||
)
|
||||
.expect_err("blank content should fail");
|
||||
assert!(blank_content.contains("todo content must not be empty"));
|
||||
|
||||
let nudge = execute_tool(
|
||||
"TodoWrite",
|
||||
&json!({
|
||||
"todos": [
|
||||
{"content": "Write tests", "activeForm": "Writing tests", "status": "completed"},
|
||||
{"content": "Fix errors", "activeForm": "Fixing errors", "status": "completed"},
|
||||
{"content": "Ship branch", "activeForm": "Shipping branch", "status": "completed"}
|
||||
]
|
||||
}),
|
||||
)
|
||||
.expect("completed todos should succeed");
|
||||
std::env::remove_var("CLAWD_TODO_STORE");
|
||||
let _ = fs::remove_file(path);
|
||||
|
||||
let output: serde_json::Value = serde_json::from_str(&nudge).expect("valid json");
|
||||
assert_eq!(output["verificationNudgeNeeded"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_loads_local_skill_prompt() {
|
||||
let result = execute_tool(
|
||||
@@ -2599,13 +2842,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn agent_persists_handoff_metadata() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
"clawd-agent-store-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let dir = temp_path("agent-store");
|
||||
std::env::set_var("CLAWD_AGENT_STORE", &dir);
|
||||
|
||||
let result = execute_tool(
|
||||
@@ -2661,15 +2901,32 @@ mod tests {
|
||||
let _ = std::fs::remove_dir_all(dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_rejects_blank_required_fields() {
|
||||
let missing_description = execute_tool(
|
||||
"Agent",
|
||||
&json!({
|
||||
"description": " ",
|
||||
"prompt": "Inspect"
|
||||
}),
|
||||
)
|
||||
.expect_err("blank description should fail");
|
||||
assert!(missing_description.contains("description must not be empty"));
|
||||
|
||||
let missing_prompt = execute_tool(
|
||||
"Agent",
|
||||
&json!({
|
||||
"description": "Inspect branch",
|
||||
"prompt": " "
|
||||
}),
|
||||
)
|
||||
.expect_err("blank prompt should fail");
|
||||
assert!(missing_prompt.contains("prompt must not be empty"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notebook_edit_replaces_inserts_and_deletes_cells() {
|
||||
let path = std::env::temp_dir().join(format!(
|
||||
"clawd-notebook-{}.ipynb",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
let path = temp_path("notebook.ipynb");
|
||||
std::fs::write(
|
||||
&path,
|
||||
r#"{
|
||||
@@ -2747,6 +3004,270 @@ mod tests {
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notebook_edit_rejects_invalid_inputs() {
|
||||
let text_path = temp_path("notebook.txt");
|
||||
fs::write(&text_path, "not a notebook").expect("write text file");
|
||||
let wrong_extension = execute_tool(
|
||||
"NotebookEdit",
|
||||
&json!({
|
||||
"notebook_path": text_path.display().to_string(),
|
||||
"new_source": "print(1)\n"
|
||||
}),
|
||||
)
|
||||
.expect_err("non-ipynb file should fail");
|
||||
assert!(wrong_extension.contains("Jupyter notebook"));
|
||||
let _ = fs::remove_file(&text_path);
|
||||
|
||||
let empty_notebook = temp_path("empty.ipynb");
|
||||
fs::write(
|
||||
&empty_notebook,
|
||||
r#"{"cells":[],"metadata":{"kernelspec":{"language":"python"}},"nbformat":4,"nbformat_minor":5}"#,
|
||||
)
|
||||
.expect("write empty notebook");
|
||||
|
||||
let missing_source = execute_tool(
|
||||
"NotebookEdit",
|
||||
&json!({
|
||||
"notebook_path": empty_notebook.display().to_string(),
|
||||
"edit_mode": "insert"
|
||||
}),
|
||||
)
|
||||
.expect_err("insert without source should fail");
|
||||
assert!(missing_source.contains("new_source is required"));
|
||||
|
||||
let missing_cell = execute_tool(
|
||||
"NotebookEdit",
|
||||
&json!({
|
||||
"notebook_path": empty_notebook.display().to_string(),
|
||||
"edit_mode": "delete"
|
||||
}),
|
||||
)
|
||||
.expect_err("delete on empty notebook should fail");
|
||||
assert!(missing_cell.contains("Notebook has no cells to edit"));
|
||||
let _ = fs::remove_file(empty_notebook);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_tool_reports_success_exit_failure_timeout_and_background() {
|
||||
let success = execute_tool("bash", &json!({ "command": "printf 'hello'" }))
|
||||
.expect("bash should succeed");
|
||||
let success_output: serde_json::Value = serde_json::from_str(&success).expect("json");
|
||||
assert_eq!(success_output["stdout"], "hello");
|
||||
assert_eq!(success_output["interrupted"], false);
|
||||
|
||||
let failure = execute_tool("bash", &json!({ "command": "printf 'oops' >&2; exit 7" }))
|
||||
.expect("bash failure should still return structured output");
|
||||
let failure_output: serde_json::Value = serde_json::from_str(&failure).expect("json");
|
||||
assert_eq!(failure_output["returnCodeInterpretation"], "exit_code:7");
|
||||
assert!(failure_output["stderr"]
|
||||
.as_str()
|
||||
.expect("stderr")
|
||||
.contains("oops"));
|
||||
|
||||
let timeout = execute_tool("bash", &json!({ "command": "sleep 1", "timeout": 10 }))
|
||||
.expect("bash timeout should return output");
|
||||
let timeout_output: serde_json::Value = serde_json::from_str(&timeout).expect("json");
|
||||
assert_eq!(timeout_output["interrupted"], true);
|
||||
assert_eq!(timeout_output["returnCodeInterpretation"], "timeout");
|
||||
assert!(timeout_output["stderr"]
|
||||
.as_str()
|
||||
.expect("stderr")
|
||||
.contains("Command exceeded timeout"));
|
||||
|
||||
let background = execute_tool(
|
||||
"bash",
|
||||
&json!({ "command": "sleep 1", "run_in_background": true }),
|
||||
)
|
||||
.expect("bash background should succeed");
|
||||
let background_output: serde_json::Value = serde_json::from_str(&background).expect("json");
|
||||
assert!(background_output["backgroundTaskId"].as_str().is_some());
|
||||
assert_eq!(background_output["noOutputExpected"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_tools_cover_read_write_and_edit_behaviors() {
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let root = temp_path("fs-suite");
|
||||
fs::create_dir_all(&root).expect("create root");
|
||||
let original_dir = std::env::current_dir().expect("cwd");
|
||||
std::env::set_current_dir(&root).expect("set cwd");
|
||||
|
||||
let write_create = execute_tool(
|
||||
"write_file",
|
||||
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\nalpha\n" }),
|
||||
)
|
||||
.expect("write create should succeed");
|
||||
let write_create_output: serde_json::Value =
|
||||
serde_json::from_str(&write_create).expect("json");
|
||||
assert_eq!(write_create_output["type"], "create");
|
||||
assert!(root.join("nested/demo.txt").exists());
|
||||
|
||||
let write_update = execute_tool(
|
||||
"write_file",
|
||||
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\ngamma\n" }),
|
||||
)
|
||||
.expect("write update should succeed");
|
||||
let write_update_output: serde_json::Value =
|
||||
serde_json::from_str(&write_update).expect("json");
|
||||
assert_eq!(write_update_output["type"], "update");
|
||||
assert_eq!(write_update_output["originalFile"], "alpha\nbeta\nalpha\n");
|
||||
|
||||
let read_full = execute_tool("read_file", &json!({ "path": "nested/demo.txt" }))
|
||||
.expect("read full should succeed");
|
||||
let read_full_output: serde_json::Value = serde_json::from_str(&read_full).expect("json");
|
||||
assert_eq!(read_full_output["file"]["content"], "alpha\nbeta\ngamma");
|
||||
assert_eq!(read_full_output["file"]["startLine"], 1);
|
||||
|
||||
let read_slice = execute_tool(
|
||||
"read_file",
|
||||
&json!({ "path": "nested/demo.txt", "offset": 1, "limit": 1 }),
|
||||
)
|
||||
.expect("read slice should succeed");
|
||||
let read_slice_output: serde_json::Value = serde_json::from_str(&read_slice).expect("json");
|
||||
assert_eq!(read_slice_output["file"]["content"], "beta");
|
||||
assert_eq!(read_slice_output["file"]["startLine"], 2);
|
||||
|
||||
let read_past_end = execute_tool(
|
||||
"read_file",
|
||||
&json!({ "path": "nested/demo.txt", "offset": 50 }),
|
||||
)
|
||||
.expect("read past EOF should succeed");
|
||||
let read_past_end_output: serde_json::Value =
|
||||
serde_json::from_str(&read_past_end).expect("json");
|
||||
assert_eq!(read_past_end_output["file"]["content"], "");
|
||||
assert_eq!(read_past_end_output["file"]["startLine"], 4);
|
||||
|
||||
let read_error = execute_tool("read_file", &json!({ "path": "missing.txt" }))
|
||||
.expect_err("missing file should fail");
|
||||
assert!(!read_error.is_empty());
|
||||
|
||||
let edit_once = execute_tool(
|
||||
"edit_file",
|
||||
&json!({ "path": "nested/demo.txt", "old_string": "alpha", "new_string": "omega" }),
|
||||
)
|
||||
.expect("single edit should succeed");
|
||||
let edit_once_output: serde_json::Value = serde_json::from_str(&edit_once).expect("json");
|
||||
assert_eq!(edit_once_output["replaceAll"], false);
|
||||
assert_eq!(
|
||||
fs::read_to_string(root.join("nested/demo.txt")).expect("read file"),
|
||||
"omega\nbeta\ngamma\n"
|
||||
);
|
||||
|
||||
execute_tool(
|
||||
"write_file",
|
||||
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\nalpha\n" }),
|
||||
)
|
||||
.expect("reset file");
|
||||
let edit_all = execute_tool(
|
||||
"edit_file",
|
||||
&json!({
|
||||
"path": "nested/demo.txt",
|
||||
"old_string": "alpha",
|
||||
"new_string": "omega",
|
||||
"replace_all": true
|
||||
}),
|
||||
)
|
||||
.expect("replace all should succeed");
|
||||
let edit_all_output: serde_json::Value = serde_json::from_str(&edit_all).expect("json");
|
||||
assert_eq!(edit_all_output["replaceAll"], true);
|
||||
assert_eq!(
|
||||
fs::read_to_string(root.join("nested/demo.txt")).expect("read file"),
|
||||
"omega\nbeta\nomega\n"
|
||||
);
|
||||
|
||||
let edit_same = execute_tool(
|
||||
"edit_file",
|
||||
&json!({ "path": "nested/demo.txt", "old_string": "omega", "new_string": "omega" }),
|
||||
)
|
||||
.expect_err("identical old/new should fail");
|
||||
assert!(edit_same.contains("must differ"));
|
||||
|
||||
let edit_missing = execute_tool(
|
||||
"edit_file",
|
||||
&json!({ "path": "nested/demo.txt", "old_string": "missing", "new_string": "omega" }),
|
||||
)
|
||||
.expect_err("missing substring should fail");
|
||||
assert!(edit_missing.contains("old_string not found"));
|
||||
|
||||
std::env::set_current_dir(&original_dir).expect("restore cwd");
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn glob_and_grep_tools_cover_success_and_errors() {
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let root = temp_path("search-suite");
|
||||
fs::create_dir_all(root.join("nested")).expect("create root");
|
||||
let original_dir = std::env::current_dir().expect("cwd");
|
||||
std::env::set_current_dir(&root).expect("set cwd");
|
||||
|
||||
fs::write(
|
||||
root.join("nested/lib.rs"),
|
||||
"fn main() {}\nlet alpha = 1;\nlet alpha = 2;\n",
|
||||
)
|
||||
.expect("write rust file");
|
||||
fs::write(root.join("nested/notes.txt"), "alpha\nbeta\n").expect("write txt file");
|
||||
|
||||
let globbed = execute_tool("glob_search", &json!({ "pattern": "nested/*.rs" }))
|
||||
.expect("glob should succeed");
|
||||
let globbed_output: serde_json::Value = serde_json::from_str(&globbed).expect("json");
|
||||
assert_eq!(globbed_output["numFiles"], 1);
|
||||
assert!(globbed_output["filenames"][0]
|
||||
.as_str()
|
||||
.expect("filename")
|
||||
.ends_with("nested/lib.rs"));
|
||||
|
||||
let glob_error = execute_tool("glob_search", &json!({ "pattern": "[" }))
|
||||
.expect_err("invalid glob should fail");
|
||||
assert!(!glob_error.is_empty());
|
||||
|
||||
let grep_content = execute_tool(
|
||||
"grep_search",
|
||||
&json!({
|
||||
"pattern": "alpha",
|
||||
"path": "nested",
|
||||
"glob": "*.rs",
|
||||
"output_mode": "content",
|
||||
"-n": true,
|
||||
"head_limit": 1,
|
||||
"offset": 1
|
||||
}),
|
||||
)
|
||||
.expect("grep content should succeed");
|
||||
let grep_content_output: serde_json::Value =
|
||||
serde_json::from_str(&grep_content).expect("json");
|
||||
assert_eq!(grep_content_output["numFiles"], 0);
|
||||
assert!(grep_content_output["appliedLimit"].is_null());
|
||||
assert_eq!(grep_content_output["appliedOffset"], 1);
|
||||
assert!(grep_content_output["content"]
|
||||
.as_str()
|
||||
.expect("content")
|
||||
.contains("let alpha = 2;"));
|
||||
|
||||
let grep_count = execute_tool(
|
||||
"grep_search",
|
||||
&json!({ "pattern": "alpha", "path": "nested", "output_mode": "count" }),
|
||||
)
|
||||
.expect("grep count should succeed");
|
||||
let grep_count_output: serde_json::Value = serde_json::from_str(&grep_count).expect("json");
|
||||
assert_eq!(grep_count_output["numMatches"], 3);
|
||||
|
||||
let grep_error = execute_tool(
|
||||
"grep_search",
|
||||
&json!({ "pattern": "(alpha", "path": "nested" }),
|
||||
)
|
||||
.expect_err("invalid regex should fail");
|
||||
assert!(!grep_error.is_empty());
|
||||
|
||||
std::env::set_current_dir(&original_dir).expect("restore cwd");
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sleep_waits_and_reports_duration() {
|
||||
let started = std::time::Instant::now();
|
||||
@@ -3038,6 +3559,15 @@ printf 'pwsh:%s' "$1"
|
||||
}
|
||||
}
|
||||
|
||||
fn text(status: u16, reason: &'static str, body: &str) -> Self {
|
||||
Self {
|
||||
status,
|
||||
reason,
|
||||
content_type: "text/plain; charset=utf-8",
|
||||
body: body.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
format!(
|
||||
"HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
|
||||
Reference in New Issue
Block a user