mirror of
https://github.com/instructkr/claw-code.git
synced 2026-07-02 14:46:43 -04:00
Compare commits
41 Commits
8d4a739c05
...
rcc/cli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2753f055a | ||
|
|
9a86aa6444 | ||
|
|
21b0887469 | ||
|
|
0d89231caa | ||
|
|
b445a3320f | ||
|
|
650a24b6e2 | ||
|
|
d018276fc1 | ||
|
|
387a8bb13f | ||
|
|
243a1ff74f | ||
|
|
583d191527 | ||
|
|
074bd5b7b7 | ||
|
|
bec07658b8 | ||
|
|
f403d3b107 | ||
|
|
bd494184fc | ||
|
|
a22700562d | ||
|
|
c14196c730 | ||
|
|
f544125c01 | ||
|
|
ccebabe605 | ||
|
|
cdf24b87b4 | ||
|
|
770fb8d0e7 | ||
|
|
e38e3ee4d7 | ||
|
|
331b8fc811 | ||
|
|
72b5f2fe80 | ||
|
|
b200198df7 | ||
|
|
2fd6241bd8 | ||
|
|
5b046836b9 | ||
|
|
549deb9a89 | ||
|
|
146260083c | ||
|
|
3ba60be514 | ||
|
|
d6341d54c1 | ||
|
|
cd01d0e387 | ||
|
|
863958b94c | ||
|
|
9455280f24 | ||
|
|
c92403994d | ||
|
|
e2f061fd08 | ||
|
|
c139fe9bee | ||
|
|
842abcfe85 | ||
|
|
807e29c8a1 | ||
|
|
32e89df631 | ||
|
|
1f8cfbce38 | ||
|
|
1e5002b521 |
1
.claude/sessions/session-1774998936453.json
Normal file
1
.claude/sessions/session-1774998936453.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
.claude/sessions/session-1774998994373.json
Normal file
1
.claude/sessions/session-1774998994373.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[{"blocks":[{"text":"Say hello in one sentence","type":"text"}],"role":"user"},{"blocks":[{"text":"Hello! I'm Claude, an AI assistant ready to help you with software engineering tasks, code analysis, debugging, or any other programming challenges you might have.","type":"text"}],"role":"assistant","usage":{"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"input_tokens":11,"output_tokens":32}}],"version":1}
|
||||||
1
.claude/sessions/session-1775007533836.json
Normal file
1
.claude/sessions/session-1775007533836.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
.claude/sessions/session-1775007622154.json
Normal file
1
.claude/sessions/session-1775007622154.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[{"blocks":[{"text":"What is 2+2? Reply with just the number.","type":"text"}],"role":"user"},{"blocks":[{"text":"4","type":"text"}],"role":"assistant","usage":{"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"input_tokens":19,"output_tokens":5}}],"version":1}
|
||||||
1
.claude/sessions/session-1775007632904.json
Normal file
1
.claude/sessions/session-1775007632904.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[{"blocks":[{"text":"Say hello in exactly 3 words","type":"text"}],"role":"user"},{"blocks":[{"text":"Hello there, friend!","type":"text"}],"role":"assistant","usage":{"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"input_tokens":14,"output_tokens":8}}],"version":1}
|
||||||
1
.claude/sessions/session-1775007846522.json
Normal file
1
.claude/sessions/session-1775007846522.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[{"blocks":[{"text":"Say hi in one sentence","type":"text"}],"role":"user"},{"blocks":[{"text":"Hi! I'm Claude, ready to help you with any software engineering tasks or questions you have.","type":"text"}],"role":"assistant","usage":{"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"input_tokens":11,"output_tokens":23}}],"version":1}
|
||||||
1
rust/.claude/sessions/session-1775007453382.json
Normal file
1
rust/.claude/sessions/session-1775007453382.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
rust/.claude/sessions/session-1775007484031.json
Normal file
1
rust/.claude/sessions/session-1775007484031.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
rust/.claude/sessions/session-1775007490104.json
Normal file
1
rust/.claude/sessions/session-1775007490104.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
rust/.claude/sessions/session-1775007981374.json
Normal file
1
rust/.claude/sessions/session-1775007981374.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
rust/.claude/sessions/session-1775008007069.json
Normal file
1
rust/.claude/sessions/session-1775008007069.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
1
rust/.claude/sessions/session-1775008071886.json
Normal file
1
rust/.claude/sessions/session-1775008071886.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"messages":[],"version":1}
|
||||||
140
rust/Cargo.lock
generated
140
rust/Cargo.lock
generated
@@ -22,6 +22,7 @@ name = "api"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
"runtime",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -97,6 +98,15 @@ version = "0.2.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clipboard-win"
|
||||||
|
version = "5.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4"
|
||||||
|
dependencies = [
|
||||||
|
"error-code",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "commands"
|
name = "commands"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -141,7 +151,7 @@ dependencies = [
|
|||||||
"crossterm_winapi",
|
"crossterm_winapi",
|
||||||
"mio",
|
"mio",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"rustix",
|
"rustix 0.38.44",
|
||||||
"signal-hook",
|
"signal-hook",
|
||||||
"signal-hook-mio",
|
"signal-hook-mio",
|
||||||
"winapi",
|
"winapi",
|
||||||
@@ -196,6 +206,12 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "endian-type"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "equivalent"
|
name = "equivalent"
|
||||||
version = "1.0.2"
|
version = "1.0.2"
|
||||||
@@ -212,6 +228,23 @@ dependencies = [
|
|||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "error-code"
|
||||||
|
version = "3.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fd-lock"
|
||||||
|
version = "4.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"rustix 1.1.4",
|
||||||
|
"windows-sys 0.52.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "find-msvc-tools"
|
name = "find-msvc-tools"
|
||||||
version = "0.1.9"
|
version = "0.1.9"
|
||||||
@@ -350,6 +383,15 @@ version = "0.16.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
|
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "home"
|
||||||
|
version = "0.5.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "http"
|
name = "http"
|
||||||
version = "1.4.0"
|
version = "1.4.0"
|
||||||
@@ -613,6 +655,12 @@ version = "0.4.15"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "linux-raw-sys"
|
||||||
|
version = "0.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "litemap"
|
name = "litemap"
|
||||||
version = "0.8.1"
|
version = "0.8.1"
|
||||||
@@ -668,6 +716,27 @@ dependencies = [
|
|||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nibble_vec"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43"
|
||||||
|
dependencies = [
|
||||||
|
"smallvec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nix"
|
||||||
|
version = "0.29.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"cfg-if",
|
||||||
|
"cfg_aliases",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-conv"
|
name = "num-conv"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@@ -887,6 +956,16 @@ version = "5.3.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "radix_trie"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd"
|
||||||
|
dependencies = [
|
||||||
|
"endian-type",
|
||||||
|
"nibble_vec",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand"
|
name = "rand"
|
||||||
version = "0.9.2"
|
version = "0.9.2"
|
||||||
@@ -1036,10 +1115,23 @@ dependencies = [
|
|||||||
"bitflags",
|
"bitflags",
|
||||||
"errno",
|
"errno",
|
||||||
"libc",
|
"libc",
|
||||||
"linux-raw-sys",
|
"linux-raw-sys 0.4.15",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustix"
|
||||||
|
version = "1.1.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"errno",
|
||||||
|
"libc",
|
||||||
|
"linux-raw-sys 0.12.1",
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls"
|
name = "rustls"
|
||||||
version = "0.23.37"
|
version = "0.23.37"
|
||||||
@@ -1091,12 +1183,35 @@ dependencies = [
|
|||||||
"crossterm",
|
"crossterm",
|
||||||
"pulldown-cmark",
|
"pulldown-cmark",
|
||||||
"runtime",
|
"runtime",
|
||||||
|
"rustyline",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"syntect",
|
"syntect",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tools",
|
"tools",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustyline"
|
||||||
|
version = "15.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"cfg-if",
|
||||||
|
"clipboard-win",
|
||||||
|
"fd-lock",
|
||||||
|
"home",
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
"memchr",
|
||||||
|
"nix",
|
||||||
|
"radix_trie",
|
||||||
|
"unicode-segmentation",
|
||||||
|
"unicode-width",
|
||||||
|
"utf8parse",
|
||||||
|
"windows-sys 0.59.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.23"
|
version = "1.0.23"
|
||||||
@@ -1524,6 +1639,12 @@ version = "1.0.24"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
|
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-segmentation"
|
||||||
|
version = "1.13.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-width"
|
name = "unicode-width"
|
||||||
version = "0.2.2"
|
version = "0.2.2"
|
||||||
@@ -1554,6 +1675,12 @@ version = "1.0.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utf8parse"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "version_check"
|
name = "version_check"
|
||||||
version = "0.9.5"
|
version = "0.9.5"
|
||||||
@@ -1724,6 +1851,15 @@ dependencies = [
|
|||||||
"windows-targets 0.52.6",
|
"windows-targets 0.52.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-sys"
|
||||||
|
version = "0.59.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
|
||||||
|
dependencies = [
|
||||||
|
"windows-targets 0.52.6",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-sys"
|
name = "windows-sys"
|
||||||
version = "0.60.2"
|
version = "0.60.2"
|
||||||
|
|||||||
@@ -64,6 +64,35 @@ cd rust
|
|||||||
cargo run -p rusty-claude-cli -- --version
|
cargo run -p rusty-claude-cli -- --version
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Login with OAuth
|
||||||
|
|
||||||
|
Configure `settings.json` with an `oauth` block containing `clientId`, `authorizeUrl`, `tokenUrl`, optional `callbackPort`, and optional `scopes`, then run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- login
|
||||||
|
```
|
||||||
|
|
||||||
|
This opens the browser, listens on the configured localhost callback, exchanges the auth code for tokens, and stores OAuth credentials in `~/.claude/credentials.json` (or `$CLAUDE_CONFIG_HOME/credentials.json`).
|
||||||
|
|
||||||
|
### Logout
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- logout
|
||||||
|
```
|
||||||
|
|
||||||
|
This removes only the stored OAuth credentials and preserves unrelated JSON fields in `credentials.json`.
|
||||||
|
|
||||||
|
### Self-update
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- self-update
|
||||||
|
```
|
||||||
|
|
||||||
|
The command checks the latest GitHub release for `instructkr/clawd-code`, compares it to the current binary version, downloads the matching binary asset plus checksum manifest, verifies SHA-256, replaces the current executable, and prints the release changelog. If no published release or matching asset exists, it exits safely with an explanatory message.
|
||||||
|
|
||||||
## Usage examples
|
## Usage examples
|
||||||
|
|
||||||
### 1) Prompt mode
|
### 1) Prompt mode
|
||||||
@@ -89,6 +118,13 @@ cd rust
|
|||||||
cargo run -p rusty-claude-cli -- --allowedTools read,glob
|
cargo run -p rusty-claude-cli -- --allowedTools read,glob
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Bootstrap Claude project files for the current repo:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd rust
|
||||||
|
cargo run -p rusty-claude-cli -- init
|
||||||
|
```
|
||||||
|
|
||||||
### 2) REPL mode
|
### 2) REPL mode
|
||||||
|
|
||||||
Start the interactive shell:
|
Start the interactive shell:
|
||||||
@@ -113,6 +149,7 @@ Inside the REPL, useful commands include:
|
|||||||
/diff
|
/diff
|
||||||
/version
|
/version
|
||||||
/export notes.txt
|
/export notes.txt
|
||||||
|
/sessions
|
||||||
/session list
|
/session list
|
||||||
/exit
|
/exit
|
||||||
```
|
```
|
||||||
@@ -123,14 +160,14 @@ Inspect or maintain a saved session file without entering the REPL:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd rust
|
cd rust
|
||||||
cargo run -p rusty-claude-cli -- --resume session.json /status /compact /cost
|
cargo run -p rusty-claude-cli -- --resume session-123456 /status /compact /cost
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also inspect memory/config state for a restored session:
|
You can also inspect memory/config state for a restored session:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd rust
|
cd rust
|
||||||
cargo run -p rusty-claude-cli -- --resume session.json /memory /config
|
cargo run -p rusty-claude-cli -- --resume ~/.claude/sessions/session-123456.json /memory /config
|
||||||
```
|
```
|
||||||
|
|
||||||
## Available commands
|
## Available commands
|
||||||
@@ -138,10 +175,11 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
|
|||||||
### Top-level CLI commands
|
### Top-level CLI commands
|
||||||
|
|
||||||
- `prompt <text...>` — run one prompt non-interactively
|
- `prompt <text...>` — run one prompt non-interactively
|
||||||
- `--resume <session.json> [/commands...]` — inspect or maintain a saved session
|
- `--resume <session-id-or-path> [/commands...]` — inspect or maintain a saved session stored under `~/.claude/sessions/`
|
||||||
- `dump-manifests` — print extracted upstream manifest counts
|
- `dump-manifests` — print extracted upstream manifest counts
|
||||||
- `bootstrap-plan` — print the current bootstrap skeleton
|
- `bootstrap-plan` — print the current bootstrap skeleton
|
||||||
- `system-prompt [--cwd PATH] [--date YYYY-MM-DD]` — render the synthesized system prompt
|
- `system-prompt [--cwd PATH] [--date YYYY-MM-DD]` — render the synthesized system prompt
|
||||||
|
- `self-update` — update the installed binary from the latest GitHub release when a matching asset is available
|
||||||
- `--help` / `-h` — show CLI help
|
- `--help` / `-h` — show CLI help
|
||||||
- `--version` / `-V` — print the CLI version and build info locally (no API call)
|
- `--version` / `-V` — print the CLI version and build info locally (no API call)
|
||||||
- `--output-format text|json` — choose non-interactive prompt output rendering
|
- `--output-format text|json` — choose non-interactive prompt output rendering
|
||||||
@@ -156,13 +194,14 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
|
|||||||
- `/permissions [read-only|workspace-write|danger-full-access]` — inspect or switch permissions
|
- `/permissions [read-only|workspace-write|danger-full-access]` — inspect or switch permissions
|
||||||
- `/clear [--confirm]` — clear the current local session
|
- `/clear [--confirm]` — clear the current local session
|
||||||
- `/cost` — show token usage totals
|
- `/cost` — show token usage totals
|
||||||
- `/resume <session-path>` — load a saved session into the REPL
|
- `/resume <session-id-or-path>` — load a saved session into the REPL
|
||||||
- `/config [env|hooks|model]` — inspect discovered Claude config
|
- `/config [env|hooks|model]` — inspect discovered Claude config
|
||||||
- `/memory` — inspect loaded instruction memory files
|
- `/memory` — inspect loaded instruction memory files
|
||||||
- `/init` — create a starter `CLAUDE.md`
|
- `/init` — bootstrap `.claude.json`, `.claude/`, `CLAUDE.md`, and local ignore rules
|
||||||
- `/diff` — show the current git diff for the workspace
|
- `/diff` — show the current git diff for the workspace
|
||||||
- `/version` — print version and build metadata locally
|
- `/version` — print version and build metadata locally
|
||||||
- `/export [file]` — export the current conversation transcript
|
- `/export [file]` — export the current conversation transcript
|
||||||
|
- `/sessions` — list recent managed local sessions from `~/.claude/sessions/`
|
||||||
- `/session [list|switch <session-id>]` — inspect or switch managed local sessions
|
- `/session [list|switch <session-id>]` — inspect or switch managed local sessions
|
||||||
- `/exit` — leave the REPL
|
- `/exit` — leave the REPL
|
||||||
|
|
||||||
@@ -170,8 +209,9 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
|
|||||||
|
|
||||||
### Anthropic/API
|
### Anthropic/API
|
||||||
|
|
||||||
- `ANTHROPIC_AUTH_TOKEN` — preferred bearer token for API auth
|
- `ANTHROPIC_API_KEY` — highest-precedence API credential
|
||||||
- `ANTHROPIC_API_KEY` — legacy API key fallback if auth token is unset
|
- `ANTHROPIC_AUTH_TOKEN` — bearer-token override used when no API key is set
|
||||||
|
- Persisted OAuth credentials in `~/.claude/credentials.json` — used when neither env var is set
|
||||||
- `ANTHROPIC_BASE_URL` — override the Anthropic API base URL
|
- `ANTHROPIC_BASE_URL` — override the Anthropic API base URL
|
||||||
- `ANTHROPIC_MODEL` — default model used by selected live integration tests
|
- `ANTHROPIC_MODEL` — default model used by selected live integration tests
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ publish.workspace = true
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||||
|
runtime = { path = "../runtime" }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
|
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::time::Duration;
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use runtime::{
|
||||||
|
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
|
||||||
|
OAuthTokenExchangeRequest,
|
||||||
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
@@ -81,11 +85,12 @@ impl AuthSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
|
||||||
pub struct OAuthTokenSet {
|
pub struct OAuthTokenSet {
|
||||||
pub access_token: String,
|
pub access_token: String,
|
||||||
pub refresh_token: Option<String>,
|
pub refresh_token: Option<String>,
|
||||||
pub expires_at: Option<u64>,
|
pub expires_at: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
pub scopes: Vec<String>,
|
pub scopes: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +136,7 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_env() -> Result<Self, ApiError> {
|
pub fn from_env() -> Result<Self, ApiError> {
|
||||||
Ok(Self::from_auth(AuthSource::from_env()?).with_base_url(read_base_url()))
|
Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
@@ -225,6 +230,46 @@ impl AnthropicClient {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn exchange_oauth_code(
|
||||||
|
&self,
|
||||||
|
config: &OAuthConfig,
|
||||||
|
request: &OAuthTokenExchangeRequest,
|
||||||
|
) -> Result<OAuthTokenSet, ApiError> {
|
||||||
|
let response = self
|
||||||
|
.http
|
||||||
|
.post(&config.token_url)
|
||||||
|
.header("content-type", "application/x-www-form-urlencoded")
|
||||||
|
.form(&request.form_params())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
let response = expect_success(response).await?;
|
||||||
|
response
|
||||||
|
.json::<OAuthTokenSet>()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn refresh_oauth_token(
|
||||||
|
&self,
|
||||||
|
config: &OAuthConfig,
|
||||||
|
request: &OAuthRefreshRequest,
|
||||||
|
) -> Result<OAuthTokenSet, ApiError> {
|
||||||
|
let response = self
|
||||||
|
.http
|
||||||
|
.post(&config.token_url)
|
||||||
|
.header("content-type", "application/x-www-form-urlencoded")
|
||||||
|
.form(&request.form_params())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
let response = expect_success(response).await?;
|
||||||
|
response
|
||||||
|
.json::<OAuthTokenSet>()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)
|
||||||
|
}
|
||||||
|
|
||||||
async fn send_with_retry(
|
async fn send_with_retry(
|
||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
@@ -304,6 +349,153 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AuthSource {
|
||||||
|
pub fn from_env_or_saved() -> Result<Self, ApiError> {
|
||||||
|
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
||||||
|
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
|
Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
}),
|
||||||
|
None => Ok(Self::ApiKey(api_key)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
|
return Ok(Self::BearerToken(bearer_token));
|
||||||
|
}
|
||||||
|
match load_saved_oauth_token() {
|
||||||
|
Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
|
||||||
|
if token_set.refresh_token.is_some() {
|
||||||
|
Err(ApiError::Auth(
|
||||||
|
"saved OAuth token is expired; load runtime OAuth config to refresh it"
|
||||||
|
.to_string(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Err(ApiError::ExpiredOAuthToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
|
||||||
|
Ok(None) => Err(ApiError::MissingApiKey),
|
||||||
|
Err(error) => Err(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
|
||||||
|
token_set
|
||||||
|
.expires_at
|
||||||
|
.is_some_and(|expires_at| expires_at <= now_unix_timestamp())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
|
||||||
|
let Some(token_set) = load_saved_oauth_token()? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
resolve_saved_oauth_token_set(config, token_set).map(Some)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
|
||||||
|
where
|
||||||
|
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
|
||||||
|
{
|
||||||
|
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
||||||
|
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
|
Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
|
||||||
|
api_key,
|
||||||
|
bearer_token,
|
||||||
|
}),
|
||||||
|
None => Ok(AuthSource::ApiKey(api_key)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
||||||
|
return Ok(AuthSource::BearerToken(bearer_token));
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(token_set) = load_saved_oauth_token()? else {
|
||||||
|
return Err(ApiError::MissingApiKey);
|
||||||
|
};
|
||||||
|
if !oauth_token_is_expired(&token_set) {
|
||||||
|
return Ok(AuthSource::BearerToken(token_set.access_token));
|
||||||
|
}
|
||||||
|
if token_set.refresh_token.is_none() {
|
||||||
|
return Err(ApiError::ExpiredOAuthToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(config) = load_oauth_config()? else {
|
||||||
|
return Err(ApiError::Auth(
|
||||||
|
"saved OAuth token is expired; runtime OAuth config is missing".to_string(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
Ok(AuthSource::from(resolve_saved_oauth_token_set(
|
||||||
|
&config, token_set,
|
||||||
|
)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_saved_oauth_token_set(
|
||||||
|
config: &OAuthConfig,
|
||||||
|
token_set: OAuthTokenSet,
|
||||||
|
) -> Result<OAuthTokenSet, ApiError> {
|
||||||
|
if !oauth_token_is_expired(&token_set) {
|
||||||
|
return Ok(token_set);
|
||||||
|
}
|
||||||
|
let Some(refresh_token) = token_set.refresh_token.clone() else {
|
||||||
|
return Err(ApiError::ExpiredOAuthToken);
|
||||||
|
};
|
||||||
|
let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
|
||||||
|
let refreshed = client_runtime_block_on(async {
|
||||||
|
client
|
||||||
|
.refresh_oauth_token(
|
||||||
|
config,
|
||||||
|
&OAuthRefreshRequest::from_config(
|
||||||
|
config,
|
||||||
|
refresh_token,
|
||||||
|
Some(token_set.scopes.clone()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
})?;
|
||||||
|
let resolved = OAuthTokenSet {
|
||||||
|
access_token: refreshed.access_token,
|
||||||
|
refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
|
||||||
|
expires_at: refreshed.expires_at,
|
||||||
|
scopes: refreshed.scopes,
|
||||||
|
};
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: resolved.access_token.clone(),
|
||||||
|
refresh_token: resolved.refresh_token.clone(),
|
||||||
|
expires_at: resolved.expires_at,
|
||||||
|
scopes: resolved.scopes.clone(),
|
||||||
|
})
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
Ok(resolved)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
|
||||||
|
where
|
||||||
|
F: std::future::Future<Output = Result<T, ApiError>>,
|
||||||
|
{
|
||||||
|
tokio::runtime::Runtime::new()
|
||||||
|
.map_err(ApiError::from)?
|
||||||
|
.block_on(future)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
|
||||||
|
let token_set = load_oauth_credentials().map_err(ApiError::from)?;
|
||||||
|
Ok(token_set.map(|token_set| OAuthTokenSet {
|
||||||
|
access_token: token_set.access_token,
|
||||||
|
refresh_token: token_set.refresh_token,
|
||||||
|
expires_at: token_set.expires_at,
|
||||||
|
scopes: token_set.scopes,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn now_unix_timestamp() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.map_or(0, |duration| duration.as_secs())
|
||||||
|
}
|
||||||
|
|
||||||
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
||||||
match std::env::var(key) {
|
match std::env::var(key) {
|
||||||
Ok(value) if !value.is_empty() => Ok(Some(value)),
|
Ok(value) if !value.is_empty() => Ok(Some(value)),
|
||||||
@@ -314,7 +506,7 @@ fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn read_api_key() -> Result<String, ApiError> {
|
fn read_api_key() -> Result<String, ApiError> {
|
||||||
let auth = AuthSource::from_env()?;
|
let auth = AuthSource::from_env_or_saved()?;
|
||||||
auth.api_key()
|
auth.api_key()
|
||||||
.or_else(|| auth.bearer_token())
|
.or_else(|| auth.bearer_token())
|
||||||
.map(ToOwned::to_owned)
|
.map(ToOwned::to_owned)
|
||||||
@@ -328,7 +520,8 @@ fn read_auth_token() -> Option<String> {
|
|||||||
.and_then(std::convert::identity)
|
.and_then(std::convert::identity)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_base_url() -> String {
|
#[must_use]
|
||||||
|
pub fn read_base_url() -> String {
|
||||||
std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
|
std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,10 +617,18 @@ struct AnthropicErrorBody {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
||||||
|
use std::io::{Read, Write};
|
||||||
|
use std::net::TcpListener;
|
||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
use std::time::Duration;
|
use std::thread;
|
||||||
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
use crate::client::{AuthSource, OAuthTokenSet};
|
use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
|
||||||
|
|
||||||
|
use crate::client::{
|
||||||
|
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
|
||||||
|
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
|
||||||
|
};
|
||||||
use crate::types::{ContentBlockDelta, MessageRequest};
|
use crate::types::{ContentBlockDelta, MessageRequest};
|
||||||
|
|
||||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
@@ -437,11 +638,53 @@ mod tests {
|
|||||||
.expect("env lock")
|
.expect("env lock")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn temp_config_home() -> std::path::PathBuf {
|
||||||
|
std::env::temp_dir().join(format!(
|
||||||
|
"api-oauth-test-{}-{}",
|
||||||
|
std::process::id(),
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time")
|
||||||
|
.as_nanos()
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_oauth_config(token_url: String) -> OAuthConfig {
|
||||||
|
OAuthConfig {
|
||||||
|
client_id: "runtime-client".to_string(),
|
||||||
|
authorize_url: "https://console.test/oauth/authorize".to_string(),
|
||||||
|
token_url,
|
||||||
|
callback_port: Some(4545),
|
||||||
|
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
|
||||||
|
scopes: vec!["org:read".to_string(), "user:write".to_string()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_token_server(response_body: &'static str) -> String {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
|
||||||
|
let address = listener.local_addr().expect("local addr");
|
||||||
|
thread::spawn(move || {
|
||||||
|
let (mut stream, _) = listener.accept().expect("accept connection");
|
||||||
|
let mut buffer = [0_u8; 4096];
|
||||||
|
let _ = stream.read(&mut buffer).expect("read request");
|
||||||
|
let response = format!(
|
||||||
|
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
|
||||||
|
response_body.len(),
|
||||||
|
response_body
|
||||||
|
);
|
||||||
|
stream
|
||||||
|
.write_all(response.as_bytes())
|
||||||
|
.expect("write response");
|
||||||
|
});
|
||||||
|
format!("http://{address}/oauth/token")
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_api_key_requires_presence() {
|
fn read_api_key_requires_presence() {
|
||||||
let _guard = env_lock();
|
let _guard = env_lock();
|
||||||
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
let error = super::read_api_key().expect_err("missing key should error");
|
let error = super::read_api_key().expect_err("missing key should error");
|
||||||
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
||||||
}
|
}
|
||||||
@@ -453,6 +696,7 @@ mod tests {
|
|||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
let error = super::read_api_key().expect_err("empty key should error");
|
let error = super::read_api_key().expect_err("empty key should error");
|
||||||
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -500,10 +744,170 @@ mod tests {
|
|||||||
std::env::remove_var("ANTHROPIC_API_KEY");
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_from_saved_oauth_when_env_absent() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "saved-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh".to_string()),
|
||||||
|
expires_at: Some(now_unix_timestamp() + 300),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save oauth credentials");
|
||||||
|
|
||||||
|
let auth = AuthSource::from_env_or_saved().expect("saved auth");
|
||||||
|
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_token_expiry_uses_expires_at_timestamp() {
|
||||||
|
assert!(oauth_token_is_expired(&OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: None,
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: Vec::new(),
|
||||||
|
}));
|
||||||
|
assert!(!oauth_token_is_expired(&OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: None,
|
||||||
|
expires_at: Some(now_unix_timestamp() + 60),
|
||||||
|
scopes: Vec::new(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_saved_oauth_token_refreshes_expired_credentials() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "expired-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save expired oauth credentials");
|
||||||
|
|
||||||
|
let token_url = spawn_token_server(
|
||||||
|
"{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
||||||
|
);
|
||||||
|
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
||||||
|
.expect("resolve refreshed token")
|
||||||
|
.expect("token set present");
|
||||||
|
assert_eq!(resolved.access_token, "refreshed-token");
|
||||||
|
let stored = runtime::load_oauth_credentials()
|
||||||
|
.expect("load stored credentials")
|
||||||
|
.expect("stored token set");
|
||||||
|
assert_eq!(stored.access_token, "refreshed-token");
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "saved-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh".to_string()),
|
||||||
|
expires_at: Some(now_unix_timestamp() + 300),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save oauth credentials");
|
||||||
|
|
||||||
|
let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
|
||||||
|
.expect("startup auth");
|
||||||
|
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "expired-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save expired oauth credentials");
|
||||||
|
|
||||||
|
let error =
|
||||||
|
resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
|
||||||
|
assert!(
|
||||||
|
matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
|
||||||
|
);
|
||||||
|
|
||||||
|
let stored = runtime::load_oauth_credentials()
|
||||||
|
.expect("load stored credentials")
|
||||||
|
.expect("stored token set");
|
||||||
|
assert_eq!(stored.access_token, "expired-access-token");
|
||||||
|
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
||||||
|
std::env::remove_var("ANTHROPIC_API_KEY");
|
||||||
|
save_oauth_credentials(&runtime::OAuthTokenSet {
|
||||||
|
access_token: "expired-access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(1),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
})
|
||||||
|
.expect("save expired oauth credentials");
|
||||||
|
|
||||||
|
let token_url = spawn_token_server(
|
||||||
|
"{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
||||||
|
);
|
||||||
|
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
||||||
|
.expect("resolve refreshed token")
|
||||||
|
.expect("token set present");
|
||||||
|
assert_eq!(resolved.access_token, "refreshed-token");
|
||||||
|
assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
|
||||||
|
let stored = runtime::load_oauth_credentials()
|
||||||
|
.expect("load stored credentials")
|
||||||
|
.expect("stored token set");
|
||||||
|
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn message_request_stream_helper_sets_stream_true() {
|
fn message_request_stream_helper_sets_stream_true() {
|
||||||
let request = MessageRequest {
|
let request = MessageRequest {
|
||||||
model: "claude-3-7-sonnet-latest".to_string(),
|
model: "claude-opus-4-6".to_string(),
|
||||||
max_tokens: 64,
|
max_tokens: 64,
|
||||||
messages: vec![],
|
messages: vec![],
|
||||||
system: None,
|
system: None,
|
||||||
@@ -517,7 +921,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn backoff_doubles_until_maximum() {
|
fn backoff_doubles_until_maximum() {
|
||||||
let client = super::AnthropicClient::new("test-key").with_retry_policy(
|
let client = AnthropicClient::new("test-key").with_retry_policy(
|
||||||
3,
|
3,
|
||||||
Duration::from_millis(10),
|
Duration::from_millis(10),
|
||||||
Duration::from_millis(25),
|
Duration::from_millis(25),
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ use std::time::Duration;
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum ApiError {
|
pub enum ApiError {
|
||||||
MissingApiKey,
|
MissingApiKey,
|
||||||
|
ExpiredOAuthToken,
|
||||||
|
Auth(String),
|
||||||
InvalidApiKeyEnv(VarError),
|
InvalidApiKeyEnv(VarError),
|
||||||
Http(reqwest::Error),
|
Http(reqwest::Error),
|
||||||
Io(std::io::Error),
|
Io(std::io::Error),
|
||||||
@@ -35,6 +37,8 @@ impl ApiError {
|
|||||||
Self::Api { retryable, .. } => *retryable,
|
Self::Api { retryable, .. } => *retryable,
|
||||||
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
||||||
Self::MissingApiKey
|
Self::MissingApiKey
|
||||||
|
| Self::ExpiredOAuthToken
|
||||||
|
| Self::Auth(_)
|
||||||
| Self::InvalidApiKeyEnv(_)
|
| Self::InvalidApiKeyEnv(_)
|
||||||
| Self::Io(_)
|
| Self::Io(_)
|
||||||
| Self::Json(_)
|
| Self::Json(_)
|
||||||
@@ -53,6 +57,13 @@ impl Display for ApiError {
|
|||||||
"ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API"
|
"ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
Self::ExpiredOAuthToken => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"saved OAuth token is expired and no refresh token is available"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Self::Auth(message) => write!(f, "auth error: {message}"),
|
||||||
Self::InvalidApiKeyEnv(error) => {
|
Self::InvalidApiKeyEnv(error) => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ mod error;
|
|||||||
mod sse;
|
mod sse;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
pub use client::{AnthropicClient, AuthSource, MessageStream, OAuthTokenSet};
|
pub use client::{
|
||||||
|
oauth_token_is_expired, read_base_url, resolve_saved_oauth_token,
|
||||||
|
resolve_startup_auth_source, AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
|
||||||
|
};
|
||||||
pub use error::ApiError;
|
pub use error::ApiError;
|
||||||
pub use sse::{parse_frame, SseParser};
|
pub use sse::{parse_frame, SseParser};
|
||||||
pub use types::{
|
pub use types::{
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use std::env;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::process::{Command, Stdio};
|
use std::process::{Command, Stdio};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
@@ -7,6 +8,12 @@ use tokio::process::Command as TokioCommand;
|
|||||||
use tokio::runtime::Builder;
|
use tokio::runtime::Builder;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
use crate::sandbox::{
|
||||||
|
build_linux_sandbox_command, resolve_sandbox_status_for_request, FilesystemIsolationMode,
|
||||||
|
SandboxConfig, SandboxStatus,
|
||||||
|
};
|
||||||
|
use crate::ConfigLoader;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
pub struct BashCommandInput {
|
pub struct BashCommandInput {
|
||||||
pub command: String,
|
pub command: String,
|
||||||
@@ -16,6 +23,14 @@ pub struct BashCommandInput {
|
|||||||
pub run_in_background: Option<bool>,
|
pub run_in_background: Option<bool>,
|
||||||
#[serde(rename = "dangerouslyDisableSandbox")]
|
#[serde(rename = "dangerouslyDisableSandbox")]
|
||||||
pub dangerously_disable_sandbox: Option<bool>,
|
pub dangerously_disable_sandbox: Option<bool>,
|
||||||
|
#[serde(rename = "namespaceRestrictions")]
|
||||||
|
pub namespace_restrictions: Option<bool>,
|
||||||
|
#[serde(rename = "isolateNetwork")]
|
||||||
|
pub isolate_network: Option<bool>,
|
||||||
|
#[serde(rename = "filesystemMode")]
|
||||||
|
pub filesystem_mode: Option<FilesystemIsolationMode>,
|
||||||
|
#[serde(rename = "allowedMounts")]
|
||||||
|
pub allowed_mounts: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
@@ -45,13 +60,17 @@ pub struct BashCommandOutput {
|
|||||||
pub persisted_output_path: Option<String>,
|
pub persisted_output_path: Option<String>,
|
||||||
#[serde(rename = "persistedOutputSize")]
|
#[serde(rename = "persistedOutputSize")]
|
||||||
pub persisted_output_size: Option<u64>,
|
pub persisted_output_size: Option<u64>,
|
||||||
|
#[serde(rename = "sandboxStatus")]
|
||||||
|
pub sandbox_status: Option<SandboxStatus>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execute_bash(input: BashCommandInput) -> io::Result<BashCommandOutput> {
|
pub fn execute_bash(input: BashCommandInput) -> io::Result<BashCommandOutput> {
|
||||||
|
let cwd = env::current_dir()?;
|
||||||
|
let sandbox_status = sandbox_status_for_input(&input, &cwd);
|
||||||
|
|
||||||
if input.run_in_background.unwrap_or(false) {
|
if input.run_in_background.unwrap_or(false) {
|
||||||
let child = Command::new("sh")
|
let mut child = prepare_command(&input.command, &cwd, &sandbox_status, false);
|
||||||
.arg("-lc")
|
let child = child
|
||||||
.arg(&input.command)
|
|
||||||
.stdin(Stdio::null())
|
.stdin(Stdio::null())
|
||||||
.stdout(Stdio::null())
|
.stdout(Stdio::null())
|
||||||
.stderr(Stdio::null())
|
.stderr(Stdio::null())
|
||||||
@@ -72,16 +91,20 @@ pub fn execute_bash(input: BashCommandInput) -> io::Result<BashCommandOutput> {
|
|||||||
structured_content: None,
|
structured_content: None,
|
||||||
persisted_output_path: None,
|
persisted_output_path: None,
|
||||||
persisted_output_size: None,
|
persisted_output_size: None,
|
||||||
|
sandbox_status: Some(sandbox_status),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let runtime = Builder::new_current_thread().enable_all().build()?;
|
let runtime = Builder::new_current_thread().enable_all().build()?;
|
||||||
runtime.block_on(execute_bash_async(input))
|
runtime.block_on(execute_bash_async(input, sandbox_status, cwd))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_bash_async(input: BashCommandInput) -> io::Result<BashCommandOutput> {
|
async fn execute_bash_async(
|
||||||
let mut command = TokioCommand::new("sh");
|
input: BashCommandInput,
|
||||||
command.arg("-lc").arg(&input.command);
|
sandbox_status: SandboxStatus,
|
||||||
|
cwd: std::path::PathBuf,
|
||||||
|
) -> io::Result<BashCommandOutput> {
|
||||||
|
let mut command = prepare_tokio_command(&input.command, &cwd, &sandbox_status, true);
|
||||||
|
|
||||||
let output_result = if let Some(timeout_ms) = input.timeout {
|
let output_result = if let Some(timeout_ms) = input.timeout {
|
||||||
match timeout(Duration::from_millis(timeout_ms), command.output()).await {
|
match timeout(Duration::from_millis(timeout_ms), command.output()).await {
|
||||||
@@ -102,6 +125,7 @@ async fn execute_bash_async(input: BashCommandInput) -> io::Result<BashCommandOu
|
|||||||
structured_content: None,
|
structured_content: None,
|
||||||
persisted_output_path: None,
|
persisted_output_path: None,
|
||||||
persisted_output_size: None,
|
persisted_output_size: None,
|
||||||
|
sandbox_status: Some(sandbox_status),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -136,12 +160,88 @@ async fn execute_bash_async(input: BashCommandInput) -> io::Result<BashCommandOu
|
|||||||
structured_content: None,
|
structured_content: None,
|
||||||
persisted_output_path: None,
|
persisted_output_path: None,
|
||||||
persisted_output_size: None,
|
persisted_output_size: None,
|
||||||
|
sandbox_status: Some(sandbox_status),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sandbox_status_for_input(input: &BashCommandInput, cwd: &std::path::Path) -> SandboxStatus {
|
||||||
|
let config = ConfigLoader::default_for(cwd).load().map_or_else(
|
||||||
|
|_| SandboxConfig::default(),
|
||||||
|
|runtime_config| runtime_config.sandbox().clone(),
|
||||||
|
);
|
||||||
|
let request = config.resolve_request(
|
||||||
|
input.dangerously_disable_sandbox.map(|disabled| !disabled),
|
||||||
|
input.namespace_restrictions,
|
||||||
|
input.isolate_network,
|
||||||
|
input.filesystem_mode,
|
||||||
|
input.allowed_mounts.clone(),
|
||||||
|
);
|
||||||
|
resolve_sandbox_status_for_request(&request, cwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_command(
|
||||||
|
command: &str,
|
||||||
|
cwd: &std::path::Path,
|
||||||
|
sandbox_status: &SandboxStatus,
|
||||||
|
create_dirs: bool,
|
||||||
|
) -> Command {
|
||||||
|
if create_dirs {
|
||||||
|
prepare_sandbox_dirs(cwd);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(launcher) = build_linux_sandbox_command(command, cwd, sandbox_status) {
|
||||||
|
let mut prepared = Command::new(launcher.program);
|
||||||
|
prepared.args(launcher.args);
|
||||||
|
prepared.current_dir(cwd);
|
||||||
|
prepared.envs(launcher.env);
|
||||||
|
return prepared;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut prepared = Command::new("sh");
|
||||||
|
prepared.arg("-lc").arg(command).current_dir(cwd);
|
||||||
|
if sandbox_status.filesystem_active {
|
||||||
|
prepared.env("HOME", cwd.join(".sandbox-home"));
|
||||||
|
prepared.env("TMPDIR", cwd.join(".sandbox-tmp"));
|
||||||
|
}
|
||||||
|
prepared
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_tokio_command(
|
||||||
|
command: &str,
|
||||||
|
cwd: &std::path::Path,
|
||||||
|
sandbox_status: &SandboxStatus,
|
||||||
|
create_dirs: bool,
|
||||||
|
) -> TokioCommand {
|
||||||
|
if create_dirs {
|
||||||
|
prepare_sandbox_dirs(cwd);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(launcher) = build_linux_sandbox_command(command, cwd, sandbox_status) {
|
||||||
|
let mut prepared = TokioCommand::new(launcher.program);
|
||||||
|
prepared.args(launcher.args);
|
||||||
|
prepared.current_dir(cwd);
|
||||||
|
prepared.envs(launcher.env);
|
||||||
|
return prepared;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut prepared = TokioCommand::new("sh");
|
||||||
|
prepared.arg("-lc").arg(command).current_dir(cwd);
|
||||||
|
if sandbox_status.filesystem_active {
|
||||||
|
prepared.env("HOME", cwd.join(".sandbox-home"));
|
||||||
|
prepared.env("TMPDIR", cwd.join(".sandbox-tmp"));
|
||||||
|
}
|
||||||
|
prepared
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_sandbox_dirs(cwd: &std::path::Path) {
|
||||||
|
let _ = std::fs::create_dir_all(cwd.join(".sandbox-home"));
|
||||||
|
let _ = std::fs::create_dir_all(cwd.join(".sandbox-tmp"));
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{execute_bash, BashCommandInput};
|
use super::{execute_bash, BashCommandInput};
|
||||||
|
use crate::sandbox::FilesystemIsolationMode;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn executes_simple_command() {
|
fn executes_simple_command() {
|
||||||
@@ -151,10 +251,33 @@ mod tests {
|
|||||||
description: None,
|
description: None,
|
||||||
run_in_background: Some(false),
|
run_in_background: Some(false),
|
||||||
dangerously_disable_sandbox: Some(false),
|
dangerously_disable_sandbox: Some(false),
|
||||||
|
namespace_restrictions: Some(false),
|
||||||
|
isolate_network: Some(false),
|
||||||
|
filesystem_mode: Some(FilesystemIsolationMode::WorkspaceOnly),
|
||||||
|
allowed_mounts: None,
|
||||||
})
|
})
|
||||||
.expect("bash command should execute");
|
.expect("bash command should execute");
|
||||||
|
|
||||||
assert_eq!(output.stdout, "hello");
|
assert_eq!(output.stdout, "hello");
|
||||||
assert!(!output.interrupted);
|
assert!(!output.interrupted);
|
||||||
|
assert!(output.sandbox_status.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn disables_sandbox_when_requested() {
|
||||||
|
let output = execute_bash(BashCommandInput {
|
||||||
|
command: String::from("printf 'hello'"),
|
||||||
|
timeout: Some(1_000),
|
||||||
|
description: None,
|
||||||
|
run_in_background: Some(false),
|
||||||
|
dangerously_disable_sandbox: Some(true),
|
||||||
|
namespace_restrictions: None,
|
||||||
|
isolate_network: None,
|
||||||
|
filesystem_mode: None,
|
||||||
|
allowed_mounts: None,
|
||||||
|
})
|
||||||
|
.expect("bash command should execute");
|
||||||
|
|
||||||
|
assert!(!output.sandbox_status.expect("sandbox status").enabled);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use std::fs;
|
|||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use crate::json::JsonValue;
|
use crate::json::JsonValue;
|
||||||
|
use crate::sandbox::{FilesystemIsolationMode, SandboxConfig};
|
||||||
|
|
||||||
pub const CLAUDE_CODE_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema";
|
pub const CLAUDE_CODE_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema";
|
||||||
|
|
||||||
@@ -14,6 +15,13 @@ pub enum ConfigSource {
|
|||||||
Local,
|
Local,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ResolvedPermissionMode {
|
||||||
|
ReadOnly,
|
||||||
|
WorkspaceWrite,
|
||||||
|
DangerFullAccess,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct ConfigEntry {
|
pub struct ConfigEntry {
|
||||||
pub source: ConfigSource,
|
pub source: ConfigSource,
|
||||||
@@ -31,6 +39,9 @@ pub struct RuntimeConfig {
|
|||||||
pub struct RuntimeFeatureConfig {
|
pub struct RuntimeFeatureConfig {
|
||||||
mcp: McpConfigCollection,
|
mcp: McpConfigCollection,
|
||||||
oauth: Option<OAuthConfig>,
|
oauth: Option<OAuthConfig>,
|
||||||
|
model: Option<String>,
|
||||||
|
permission_mode: Option<ResolvedPermissionMode>,
|
||||||
|
sandbox: SandboxConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||||
@@ -165,11 +176,23 @@ impl ConfigLoader {
|
|||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn discover(&self) -> Vec<ConfigEntry> {
|
pub fn discover(&self) -> Vec<ConfigEntry> {
|
||||||
|
let user_legacy_path = self.config_home.parent().map_or_else(
|
||||||
|
|| PathBuf::from(".claude.json"),
|
||||||
|
|parent| parent.join(".claude.json"),
|
||||||
|
);
|
||||||
vec![
|
vec![
|
||||||
|
ConfigEntry {
|
||||||
|
source: ConfigSource::User,
|
||||||
|
path: user_legacy_path,
|
||||||
|
},
|
||||||
ConfigEntry {
|
ConfigEntry {
|
||||||
source: ConfigSource::User,
|
source: ConfigSource::User,
|
||||||
path: self.config_home.join("settings.json"),
|
path: self.config_home.join("settings.json"),
|
||||||
},
|
},
|
||||||
|
ConfigEntry {
|
||||||
|
source: ConfigSource::Project,
|
||||||
|
path: self.cwd.join(".claude.json"),
|
||||||
|
},
|
||||||
ConfigEntry {
|
ConfigEntry {
|
||||||
source: ConfigSource::Project,
|
source: ConfigSource::Project,
|
||||||
path: self.cwd.join(".claude").join("settings.json"),
|
path: self.cwd.join(".claude").join("settings.json"),
|
||||||
@@ -195,14 +218,16 @@ impl ConfigLoader {
|
|||||||
loaded_entries.push(entry);
|
loaded_entries.push(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let merged_value = JsonValue::Object(merged.clone());
|
||||||
|
|
||||||
let feature_config = RuntimeFeatureConfig {
|
let feature_config = RuntimeFeatureConfig {
|
||||||
mcp: McpConfigCollection {
|
mcp: McpConfigCollection {
|
||||||
servers: mcp_servers,
|
servers: mcp_servers,
|
||||||
},
|
},
|
||||||
oauth: parse_optional_oauth_config(
|
oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?,
|
||||||
&JsonValue::Object(merged.clone()),
|
model: parse_optional_model(&merged_value),
|
||||||
"merged settings.oauth",
|
permission_mode: parse_optional_permission_mode(&merged_value)?,
|
||||||
)?,
|
sandbox: parse_optional_sandbox_config(&merged_value)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(RuntimeConfig {
|
Ok(RuntimeConfig {
|
||||||
@@ -257,6 +282,21 @@ impl RuntimeConfig {
|
|||||||
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||||
self.feature_config.oauth.as_ref()
|
self.feature_config.oauth.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn model(&self) -> Option<&str> {
|
||||||
|
self.feature_config.model.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
|
||||||
|
self.feature_config.permission_mode
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn sandbox(&self) -> &SandboxConfig {
|
||||||
|
&self.feature_config.sandbox
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RuntimeFeatureConfig {
|
impl RuntimeFeatureConfig {
|
||||||
@@ -269,6 +309,21 @@ impl RuntimeFeatureConfig {
|
|||||||
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
pub fn oauth(&self) -> Option<&OAuthConfig> {
|
||||||
self.oauth.as_ref()
|
self.oauth.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn model(&self) -> Option<&str> {
|
||||||
|
self.model.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn permission_mode(&self) -> Option<ResolvedPermissionMode> {
|
||||||
|
self.permission_mode
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn sandbox(&self) -> &SandboxConfig {
|
||||||
|
&self.sandbox
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl McpConfigCollection {
|
impl McpConfigCollection {
|
||||||
@@ -307,6 +362,7 @@ impl McpServerConfig {
|
|||||||
fn read_optional_json_object(
|
fn read_optional_json_object(
|
||||||
path: &Path,
|
path: &Path,
|
||||||
) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
|
) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> {
|
||||||
|
let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claude.json");
|
||||||
let contents = match fs::read_to_string(path) {
|
let contents = match fs::read_to_string(path) {
|
||||||
Ok(contents) => contents,
|
Ok(contents) => contents,
|
||||||
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
||||||
@@ -317,14 +373,20 @@ fn read_optional_json_object(
|
|||||||
return Ok(Some(BTreeMap::new()));
|
return Ok(Some(BTreeMap::new()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let parsed = JsonValue::parse(&contents)
|
let parsed = match JsonValue::parse(&contents) {
|
||||||
.map_err(|error| ConfigError::Parse(format!("{}: {error}", path.display())))?;
|
Ok(parsed) => parsed,
|
||||||
let object = parsed.as_object().ok_or_else(|| {
|
Err(error) if is_legacy_config => return Ok(None),
|
||||||
ConfigError::Parse(format!(
|
Err(error) => return Err(ConfigError::Parse(format!("{}: {error}", path.display()))),
|
||||||
|
};
|
||||||
|
let Some(object) = parsed.as_object() else {
|
||||||
|
if is_legacy_config {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
return Err(ConfigError::Parse(format!(
|
||||||
"{}: top-level settings value must be a JSON object",
|
"{}: top-level settings value must be a JSON object",
|
||||||
path.display()
|
path.display()
|
||||||
))
|
)));
|
||||||
})?;
|
};
|
||||||
Ok(Some(object.clone()))
|
Ok(Some(object.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,6 +417,83 @@ fn merge_mcp_servers(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_optional_model(root: &JsonValue) -> Option<String> {
|
||||||
|
root.as_object()
|
||||||
|
.and_then(|object| object.get("model"))
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_optional_permission_mode(
|
||||||
|
root: &JsonValue,
|
||||||
|
) -> Result<Option<ResolvedPermissionMode>, ConfigError> {
|
||||||
|
let Some(object) = root.as_object() else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
if let Some(mode) = object.get("permissionMode").and_then(JsonValue::as_str) {
|
||||||
|
return parse_permission_mode_label(mode, "merged settings.permissionMode").map(Some);
|
||||||
|
}
|
||||||
|
let Some(mode) = object
|
||||||
|
.get("permissions")
|
||||||
|
.and_then(JsonValue::as_object)
|
||||||
|
.and_then(|permissions| permissions.get("defaultMode"))
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
parse_permission_mode_label(mode, "merged settings.permissions.defaultMode").map(Some)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_permission_mode_label(
|
||||||
|
mode: &str,
|
||||||
|
context: &str,
|
||||||
|
) -> Result<ResolvedPermissionMode, ConfigError> {
|
||||||
|
match mode {
|
||||||
|
"default" | "plan" | "read-only" => Ok(ResolvedPermissionMode::ReadOnly),
|
||||||
|
"acceptEdits" | "auto" | "workspace-write" => Ok(ResolvedPermissionMode::WorkspaceWrite),
|
||||||
|
"dontAsk" | "danger-full-access" => Ok(ResolvedPermissionMode::DangerFullAccess),
|
||||||
|
other => Err(ConfigError::Parse(format!(
|
||||||
|
"{context}: unsupported permission mode {other}"
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_optional_sandbox_config(root: &JsonValue) -> Result<SandboxConfig, ConfigError> {
|
||||||
|
let Some(object) = root.as_object() else {
|
||||||
|
return Ok(SandboxConfig::default());
|
||||||
|
};
|
||||||
|
let Some(sandbox_value) = object.get("sandbox") else {
|
||||||
|
return Ok(SandboxConfig::default());
|
||||||
|
};
|
||||||
|
let sandbox = expect_object(sandbox_value, "merged settings.sandbox")?;
|
||||||
|
let filesystem_mode = optional_string(sandbox, "filesystemMode", "merged settings.sandbox")?
|
||||||
|
.map(parse_filesystem_mode_label)
|
||||||
|
.transpose()?;
|
||||||
|
Ok(SandboxConfig {
|
||||||
|
enabled: optional_bool(sandbox, "enabled", "merged settings.sandbox")?,
|
||||||
|
namespace_restrictions: optional_bool(
|
||||||
|
sandbox,
|
||||||
|
"namespaceRestrictions",
|
||||||
|
"merged settings.sandbox",
|
||||||
|
)?,
|
||||||
|
network_isolation: optional_bool(sandbox, "networkIsolation", "merged settings.sandbox")?,
|
||||||
|
filesystem_mode,
|
||||||
|
allowed_mounts: optional_string_array(sandbox, "allowedMounts", "merged settings.sandbox")?
|
||||||
|
.unwrap_or_default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_filesystem_mode_label(value: &str) -> Result<FilesystemIsolationMode, ConfigError> {
|
||||||
|
match value {
|
||||||
|
"off" => Ok(FilesystemIsolationMode::Off),
|
||||||
|
"workspace-only" => Ok(FilesystemIsolationMode::WorkspaceOnly),
|
||||||
|
"allow-list" => Ok(FilesystemIsolationMode::AllowList),
|
||||||
|
other => Err(ConfigError::Parse(format!(
|
||||||
|
"merged settings.sandbox.filesystemMode: unsupported filesystem mode {other}"
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_optional_oauth_config(
|
fn parse_optional_oauth_config(
|
||||||
root: &JsonValue,
|
root: &JsonValue,
|
||||||
context: &str,
|
context: &str,
|
||||||
@@ -594,9 +733,11 @@ fn deep_merge_objects(
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{
|
use super::{
|
||||||
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
ConfigLoader, ConfigSource, McpServerConfig, McpTransport, ResolvedPermissionMode,
|
||||||
|
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||||
};
|
};
|
||||||
use crate::json::JsonValue;
|
use crate::json::JsonValue;
|
||||||
|
use crate::sandbox::FilesystemIsolationMode;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
@@ -635,14 +776,24 @@ mod tests {
|
|||||||
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
||||||
fs::create_dir_all(&home).expect("home config dir");
|
fs::create_dir_all(&home).expect("home config dir");
|
||||||
|
|
||||||
|
fs::write(
|
||||||
|
home.parent().expect("home parent").join(".claude.json"),
|
||||||
|
r#"{"model":"haiku","env":{"A":"1"},"mcpServers":{"home":{"command":"uvx","args":["home"]}}}"#,
|
||||||
|
)
|
||||||
|
.expect("write user compat config");
|
||||||
fs::write(
|
fs::write(
|
||||||
home.join("settings.json"),
|
home.join("settings.json"),
|
||||||
r#"{"model":"sonnet","env":{"A":"1"},"hooks":{"PreToolUse":["base"]}}"#,
|
r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan"}}"#,
|
||||||
)
|
)
|
||||||
.expect("write user settings");
|
.expect("write user settings");
|
||||||
|
fs::write(
|
||||||
|
cwd.join(".claude.json"),
|
||||||
|
r#"{"model":"project-compat","env":{"B":"2"}}"#,
|
||||||
|
)
|
||||||
|
.expect("write project compat config");
|
||||||
fs::write(
|
fs::write(
|
||||||
cwd.join(".claude").join("settings.json"),
|
cwd.join(".claude").join("settings.json"),
|
||||||
r#"{"env":{"B":"2"},"hooks":{"PostToolUse":["project"]}}"#,
|
r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#,
|
||||||
)
|
)
|
||||||
.expect("write project settings");
|
.expect("write project settings");
|
||||||
fs::write(
|
fs::write(
|
||||||
@@ -656,25 +807,75 @@ mod tests {
|
|||||||
.expect("config should load");
|
.expect("config should load");
|
||||||
|
|
||||||
assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema");
|
assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema");
|
||||||
assert_eq!(loaded.loaded_entries().len(), 3);
|
assert_eq!(loaded.loaded_entries().len(), 5);
|
||||||
assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User);
|
assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
loaded.get("model"),
|
loaded.get("model"),
|
||||||
Some(&JsonValue::String("opus".to_string()))
|
Some(&JsonValue::String("opus".to_string()))
|
||||||
);
|
);
|
||||||
|
assert_eq!(loaded.model(), Some("opus"));
|
||||||
|
assert_eq!(
|
||||||
|
loaded.permission_mode(),
|
||||||
|
Some(ResolvedPermissionMode::WorkspaceWrite)
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
loaded
|
loaded
|
||||||
.get("env")
|
.get("env")
|
||||||
.and_then(JsonValue::as_object)
|
.and_then(JsonValue::as_object)
|
||||||
.expect("env object")
|
.expect("env object")
|
||||||
.len(),
|
.len(),
|
||||||
2
|
4
|
||||||
);
|
);
|
||||||
assert!(loaded
|
assert!(loaded
|
||||||
.get("hooks")
|
.get("hooks")
|
||||||
.and_then(JsonValue::as_object)
|
.and_then(JsonValue::as_object)
|
||||||
.expect("hooks object")
|
.expect("hooks object")
|
||||||
.contains_key("PreToolUse"));
|
.contains_key("PreToolUse"));
|
||||||
|
assert!(loaded
|
||||||
|
.get("hooks")
|
||||||
|
.and_then(JsonValue::as_object)
|
||||||
|
.expect("hooks object")
|
||||||
|
.contains_key("PostToolUse"));
|
||||||
|
assert!(loaded.mcp().get("home").is_some());
|
||||||
|
assert!(loaded.mcp().get("project").is_some());
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_sandbox_config() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let cwd = root.join("project");
|
||||||
|
let home = root.join("home").join(".claude");
|
||||||
|
fs::create_dir_all(cwd.join(".claude")).expect("project config dir");
|
||||||
|
fs::create_dir_all(&home).expect("home config dir");
|
||||||
|
|
||||||
|
fs::write(
|
||||||
|
cwd.join(".claude").join("settings.local.json"),
|
||||||
|
r#"{
|
||||||
|
"sandbox": {
|
||||||
|
"enabled": true,
|
||||||
|
"namespaceRestrictions": false,
|
||||||
|
"networkIsolation": true,
|
||||||
|
"filesystemMode": "allow-list",
|
||||||
|
"allowedMounts": ["logs", "tmp/cache"]
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.expect("write local settings");
|
||||||
|
|
||||||
|
let loaded = ConfigLoader::new(&cwd, &home)
|
||||||
|
.load()
|
||||||
|
.expect("config should load");
|
||||||
|
|
||||||
|
assert_eq!(loaded.sandbox().enabled, Some(true));
|
||||||
|
assert_eq!(loaded.sandbox().namespace_restrictions, Some(false));
|
||||||
|
assert_eq!(loaded.sandbox().network_isolation, Some(true));
|
||||||
|
assert_eq!(
|
||||||
|
loaded.sandbox().filesystem_mode,
|
||||||
|
Some(FilesystemIsolationMode::AllowList)
|
||||||
|
);
|
||||||
|
assert_eq!(loaded.sandbox().allowed_mounts, vec!["logs", "tmp/cache"]);
|
||||||
|
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -408,12 +408,13 @@ mod tests {
|
|||||||
.sum::<i32>();
|
.sum::<i32>();
|
||||||
Ok(total.to_string())
|
Ok(total.to_string())
|
||||||
});
|
});
|
||||||
let permission_policy = PermissionPolicy::new(PermissionMode::Prompt);
|
let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
|
||||||
let system_prompt = SystemPromptBuilder::new()
|
let system_prompt = SystemPromptBuilder::new()
|
||||||
.with_project_context(ProjectContext {
|
.with_project_context(ProjectContext {
|
||||||
cwd: PathBuf::from("/tmp/project"),
|
cwd: PathBuf::from("/tmp/project"),
|
||||||
current_date: "2026-03-31".to_string(),
|
current_date: "2026-03-31".to_string(),
|
||||||
git_status: None,
|
git_status: None,
|
||||||
|
git_diff: None,
|
||||||
instruction_files: Vec::new(),
|
instruction_files: Vec::new(),
|
||||||
})
|
})
|
||||||
.with_os("linux", "6.8")
|
.with_os("linux", "6.8")
|
||||||
@@ -487,7 +488,7 @@ mod tests {
|
|||||||
Session::new(),
|
Session::new(),
|
||||||
SingleCallApiClient,
|
SingleCallApiClient,
|
||||||
StaticToolExecutor::new(),
|
StaticToolExecutor::new(),
|
||||||
PermissionPolicy::new(PermissionMode::Prompt),
|
PermissionPolicy::new(PermissionMode::WorkspaceWrite),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -536,7 +537,7 @@ mod tests {
|
|||||||
session,
|
session,
|
||||||
SimpleApi,
|
SimpleApi,
|
||||||
StaticToolExecutor::new(),
|
StaticToolExecutor::new(),
|
||||||
PermissionPolicy::new(PermissionMode::Allow),
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -563,7 +564,7 @@ mod tests {
|
|||||||
Session::new(),
|
Session::new(),
|
||||||
SimpleApi,
|
SimpleApi,
|
||||||
StaticToolExecutor::new(),
|
StaticToolExecutor::new(),
|
||||||
PermissionPolicy::new(PermissionMode::Allow),
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
||||||
vec!["system".to_string()],
|
vec!["system".to_string()],
|
||||||
);
|
);
|
||||||
runtime.run_turn("a", None).expect("turn a");
|
runtime.run_turn("a", None).expect("turn a");
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ mod oauth;
|
|||||||
mod permissions;
|
mod permissions;
|
||||||
mod prompt;
|
mod prompt;
|
||||||
mod remote;
|
mod remote;
|
||||||
|
pub mod sandbox;
|
||||||
mod session;
|
mod session;
|
||||||
mod usage;
|
mod usage;
|
||||||
|
|
||||||
@@ -25,7 +26,8 @@ pub use config::{
|
|||||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig,
|
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig,
|
||||||
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
||||||
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
||||||
RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig,
|
||||||
|
CLAUDE_CODE_SETTINGS_SCHEMA_NAME,
|
||||||
};
|
};
|
||||||
pub use conversation::{
|
pub use conversation::{
|
||||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
||||||
@@ -46,14 +48,17 @@ pub use mcp_client::{
|
|||||||
};
|
};
|
||||||
pub use mcp_stdio::{
|
pub use mcp_stdio::{
|
||||||
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
||||||
McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
|
ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult,
|
||||||
McpListResourcesParams, McpListResourcesResult, McpListToolsParams, McpListToolsResult,
|
McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
|
||||||
McpReadResourceParams, McpReadResourceResult, McpResource, McpResourceContents,
|
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource,
|
||||||
McpStdioProcess, McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult,
|
McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool,
|
||||||
|
McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer,
|
||||||
};
|
};
|
||||||
pub use oauth::{
|
pub use oauth::{
|
||||||
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||||
OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||||
|
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||||
|
OAuthCallbackParams, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||||
PkceChallengeMethod, PkceCodePair,
|
PkceChallengeMethod, PkceCodePair,
|
||||||
};
|
};
|
||||||
pub use permissions::{
|
pub use permissions::{
|
||||||
@@ -73,3 +78,11 @@ pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, Sessi
|
|||||||
pub use usage::{
|
pub use usage::{
|
||||||
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
|
||||||
|
LOCK.get_or_init(|| std::sync::Mutex::new(()))
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ use serde_json::Value as JsonValue;
|
|||||||
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
|
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
|
||||||
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||||
|
|
||||||
|
use crate::config::{McpTransport, RuntimeConfig, ScopedMcpServerConfig};
|
||||||
|
use crate::mcp::mcp_tool_name;
|
||||||
use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
|
use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
@@ -200,6 +202,374 @@ pub struct McpReadResourceResult {
|
|||||||
pub contents: Vec<McpResourceContents>,
|
pub contents: Vec<McpResourceContents>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct ManagedMcpTool {
|
||||||
|
pub server_name: String,
|
||||||
|
pub qualified_name: String,
|
||||||
|
pub raw_name: String,
|
||||||
|
pub tool: McpTool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct UnsupportedMcpServer {
|
||||||
|
pub server_name: String,
|
||||||
|
pub transport: McpTransport,
|
||||||
|
pub reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum McpServerManagerError {
|
||||||
|
Io(io::Error),
|
||||||
|
JsonRpc {
|
||||||
|
server_name: String,
|
||||||
|
method: &'static str,
|
||||||
|
error: JsonRpcError,
|
||||||
|
},
|
||||||
|
InvalidResponse {
|
||||||
|
server_name: String,
|
||||||
|
method: &'static str,
|
||||||
|
details: String,
|
||||||
|
},
|
||||||
|
UnknownTool {
|
||||||
|
qualified_name: String,
|
||||||
|
},
|
||||||
|
UnknownServer {
|
||||||
|
server_name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for McpServerManagerError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Io(error) => write!(f, "{error}"),
|
||||||
|
Self::JsonRpc {
|
||||||
|
server_name,
|
||||||
|
method,
|
||||||
|
error,
|
||||||
|
} => write!(
|
||||||
|
f,
|
||||||
|
"MCP server `{server_name}` returned JSON-RPC error for {method}: {} ({})",
|
||||||
|
error.message, error.code
|
||||||
|
),
|
||||||
|
Self::InvalidResponse {
|
||||||
|
server_name,
|
||||||
|
method,
|
||||||
|
details,
|
||||||
|
} => write!(
|
||||||
|
f,
|
||||||
|
"MCP server `{server_name}` returned invalid response for {method}: {details}"
|
||||||
|
),
|
||||||
|
Self::UnknownTool { qualified_name } => {
|
||||||
|
write!(f, "unknown MCP tool `{qualified_name}`")
|
||||||
|
}
|
||||||
|
Self::UnknownServer { server_name } => write!(f, "unknown MCP server `{server_name}`"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for McpServerManagerError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::Io(error) => Some(error),
|
||||||
|
Self::JsonRpc { .. }
|
||||||
|
| Self::InvalidResponse { .. }
|
||||||
|
| Self::UnknownTool { .. }
|
||||||
|
| Self::UnknownServer { .. } => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<io::Error> for McpServerManagerError {
|
||||||
|
fn from(value: io::Error) -> Self {
|
||||||
|
Self::Io(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
struct ToolRoute {
|
||||||
|
server_name: String,
|
||||||
|
raw_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ManagedMcpServer {
|
||||||
|
bootstrap: McpClientBootstrap,
|
||||||
|
process: Option<McpStdioProcess>,
|
||||||
|
initialized: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ManagedMcpServer {
|
||||||
|
fn new(bootstrap: McpClientBootstrap) -> Self {
|
||||||
|
Self {
|
||||||
|
bootstrap,
|
||||||
|
process: None,
|
||||||
|
initialized: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct McpServerManager {
|
||||||
|
servers: BTreeMap<String, ManagedMcpServer>,
|
||||||
|
unsupported_servers: Vec<UnsupportedMcpServer>,
|
||||||
|
tool_index: BTreeMap<String, ToolRoute>,
|
||||||
|
next_request_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpServerManager {
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_runtime_config(config: &RuntimeConfig) -> Self {
|
||||||
|
Self::from_servers(config.mcp().servers())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_servers(servers: &BTreeMap<String, ScopedMcpServerConfig>) -> Self {
|
||||||
|
let mut managed_servers = BTreeMap::new();
|
||||||
|
let mut unsupported_servers = Vec::new();
|
||||||
|
|
||||||
|
for (server_name, server_config) in servers {
|
||||||
|
if server_config.transport() == McpTransport::Stdio {
|
||||||
|
let bootstrap = McpClientBootstrap::from_scoped_config(server_name, server_config);
|
||||||
|
managed_servers.insert(server_name.clone(), ManagedMcpServer::new(bootstrap));
|
||||||
|
} else {
|
||||||
|
unsupported_servers.push(UnsupportedMcpServer {
|
||||||
|
server_name: server_name.clone(),
|
||||||
|
transport: server_config.transport(),
|
||||||
|
reason: format!(
|
||||||
|
"transport {:?} is not supported by McpServerManager",
|
||||||
|
server_config.transport()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
servers: managed_servers,
|
||||||
|
unsupported_servers,
|
||||||
|
tool_index: BTreeMap::new(),
|
||||||
|
next_request_id: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn unsupported_servers(&self) -> &[UnsupportedMcpServer] {
|
||||||
|
&self.unsupported_servers
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn discover_tools(&mut self) -> Result<Vec<ManagedMcpTool>, McpServerManagerError> {
|
||||||
|
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
||||||
|
let mut discovered_tools = Vec::new();
|
||||||
|
|
||||||
|
for server_name in server_names {
|
||||||
|
self.ensure_server_ready(&server_name).await?;
|
||||||
|
self.clear_routes_for_server(&server_name);
|
||||||
|
|
||||||
|
let mut cursor = None;
|
||||||
|
loop {
|
||||||
|
let request_id = self.take_request_id();
|
||||||
|
let response = {
|
||||||
|
let server = self.server_mut(&server_name)?;
|
||||||
|
let process = server.process.as_mut().ok_or_else(|| {
|
||||||
|
McpServerManagerError::InvalidResponse {
|
||||||
|
server_name: server_name.clone(),
|
||||||
|
method: "tools/list",
|
||||||
|
details: "server process missing after initialization".to_string(),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
process
|
||||||
|
.list_tools(
|
||||||
|
request_id,
|
||||||
|
Some(McpListToolsParams {
|
||||||
|
cursor: cursor.clone(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(error) = response.error {
|
||||||
|
return Err(McpServerManagerError::JsonRpc {
|
||||||
|
server_name: server_name.clone(),
|
||||||
|
method: "tools/list",
|
||||||
|
error,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let result =
|
||||||
|
response
|
||||||
|
.result
|
||||||
|
.ok_or_else(|| McpServerManagerError::InvalidResponse {
|
||||||
|
server_name: server_name.clone(),
|
||||||
|
method: "tools/list",
|
||||||
|
details: "missing result payload".to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
for tool in result.tools {
|
||||||
|
let qualified_name = mcp_tool_name(&server_name, &tool.name);
|
||||||
|
self.tool_index.insert(
|
||||||
|
qualified_name.clone(),
|
||||||
|
ToolRoute {
|
||||||
|
server_name: server_name.clone(),
|
||||||
|
raw_name: tool.name.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
discovered_tools.push(ManagedMcpTool {
|
||||||
|
server_name: server_name.clone(),
|
||||||
|
qualified_name,
|
||||||
|
raw_name: tool.name.clone(),
|
||||||
|
tool,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match result.next_cursor {
|
||||||
|
Some(next_cursor) => cursor = Some(next_cursor),
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(discovered_tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn call_tool(
|
||||||
|
&mut self,
|
||||||
|
qualified_tool_name: &str,
|
||||||
|
arguments: Option<JsonValue>,
|
||||||
|
) -> Result<JsonRpcResponse<McpToolCallResult>, McpServerManagerError> {
|
||||||
|
let route = self
|
||||||
|
.tool_index
|
||||||
|
.get(qualified_tool_name)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| McpServerManagerError::UnknownTool {
|
||||||
|
qualified_name: qualified_tool_name.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
self.ensure_server_ready(&route.server_name).await?;
|
||||||
|
let request_id = self.take_request_id();
|
||||||
|
let response =
|
||||||
|
{
|
||||||
|
let server = self.server_mut(&route.server_name)?;
|
||||||
|
let process = server.process.as_mut().ok_or_else(|| {
|
||||||
|
McpServerManagerError::InvalidResponse {
|
||||||
|
server_name: route.server_name.clone(),
|
||||||
|
method: "tools/call",
|
||||||
|
details: "server process missing after initialization".to_string(),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
process
|
||||||
|
.call_tool(
|
||||||
|
request_id,
|
||||||
|
McpToolCallParams {
|
||||||
|
name: route.raw_name,
|
||||||
|
arguments,
|
||||||
|
meta: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> {
|
||||||
|
let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
|
||||||
|
for server_name in server_names {
|
||||||
|
let server = self.server_mut(&server_name)?;
|
||||||
|
if let Some(process) = server.process.as_mut() {
|
||||||
|
process.shutdown().await?;
|
||||||
|
}
|
||||||
|
server.process = None;
|
||||||
|
server.initialized = false;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_routes_for_server(&mut self, server_name: &str) {
|
||||||
|
self.tool_index
|
||||||
|
.retain(|_, route| route.server_name != server_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn server_mut(
|
||||||
|
&mut self,
|
||||||
|
server_name: &str,
|
||||||
|
) -> Result<&mut ManagedMcpServer, McpServerManagerError> {
|
||||||
|
self.servers
|
||||||
|
.get_mut(server_name)
|
||||||
|
.ok_or_else(|| McpServerManagerError::UnknownServer {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_request_id(&mut self) -> JsonRpcId {
|
||||||
|
let id = self.next_request_id;
|
||||||
|
self.next_request_id = self.next_request_id.saturating_add(1);
|
||||||
|
JsonRpcId::Number(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ensure_server_ready(
|
||||||
|
&mut self,
|
||||||
|
server_name: &str,
|
||||||
|
) -> Result<(), McpServerManagerError> {
|
||||||
|
let needs_spawn = self
|
||||||
|
.servers
|
||||||
|
.get(server_name)
|
||||||
|
.map(|server| server.process.is_none())
|
||||||
|
.ok_or_else(|| McpServerManagerError::UnknownServer {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if needs_spawn {
|
||||||
|
let server = self.server_mut(server_name)?;
|
||||||
|
server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?);
|
||||||
|
server.initialized = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let needs_initialize = self
|
||||||
|
.servers
|
||||||
|
.get(server_name)
|
||||||
|
.map(|server| !server.initialized)
|
||||||
|
.ok_or_else(|| McpServerManagerError::UnknownServer {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if needs_initialize {
|
||||||
|
let request_id = self.take_request_id();
|
||||||
|
let response = {
|
||||||
|
let server = self.server_mut(server_name)?;
|
||||||
|
let process = server.process.as_mut().ok_or_else(|| {
|
||||||
|
McpServerManagerError::InvalidResponse {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
method: "initialize",
|
||||||
|
details: "server process missing before initialize".to_string(),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
process
|
||||||
|
.initialize(request_id, default_initialize_params())
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(error) = response.error {
|
||||||
|
return Err(McpServerManagerError::JsonRpc {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
method: "initialize",
|
||||||
|
error,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.result.is_none() {
|
||||||
|
return Err(McpServerManagerError::InvalidResponse {
|
||||||
|
server_name: server_name.to_string(),
|
||||||
|
method: "initialize",
|
||||||
|
details: "missing result payload".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let server = self.server_mut(server_name)?;
|
||||||
|
server.initialized = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct McpStdioProcess {
|
pub struct McpStdioProcess {
|
||||||
child: Child,
|
child: Child,
|
||||||
@@ -385,6 +755,14 @@ impl McpStdioProcess {
|
|||||||
pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
|
pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
|
||||||
self.child.wait().await
|
self.child.wait().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn shutdown(&mut self) -> io::Result<()> {
|
||||||
|
if self.child.try_wait()?.is_none() {
|
||||||
|
self.child.kill().await?;
|
||||||
|
}
|
||||||
|
let _ = self.child.wait().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
|
pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
|
||||||
@@ -413,6 +791,17 @@ fn encode_frame(payload: &[u8]) -> Vec<u8> {
|
|||||||
framed
|
framed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_initialize_params() -> McpInitializeParams {
|
||||||
|
McpInitializeParams {
|
||||||
|
protocol_version: "2025-03-26".to_string(),
|
||||||
|
capabilities: JsonValue::Object(serde_json::Map::new()),
|
||||||
|
client_info: McpInitializeClientInfo {
|
||||||
|
name: "runtime".to_string(),
|
||||||
|
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
@@ -426,15 +815,17 @@ mod tests {
|
|||||||
use tokio::runtime::Builder;
|
use tokio::runtime::Builder;
|
||||||
|
|
||||||
use crate::config::{
|
use crate::config::{
|
||||||
ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
|
ConfigSource, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
|
||||||
|
McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
|
||||||
};
|
};
|
||||||
|
use crate::mcp::mcp_tool_name;
|
||||||
use crate::mcp_client::McpClientBootstrap;
|
use crate::mcp_client::McpClientBootstrap;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
||||||
McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
|
McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
|
||||||
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpStdioProcess, McpTool,
|
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpServerManager,
|
||||||
McpToolCallParams,
|
McpServerManagerError, McpStdioProcess, McpTool, McpToolCallParams,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn temp_dir() -> PathBuf {
|
fn temp_dir() -> PathBuf {
|
||||||
@@ -628,6 +1019,110 @@ mod tests {
|
|||||||
script_path
|
script_path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
fn write_manager_mcp_server_script() -> PathBuf {
|
||||||
|
let root = temp_dir();
|
||||||
|
fs::create_dir_all(&root).expect("temp dir");
|
||||||
|
let script_path = root.join("manager-mcp-server.py");
|
||||||
|
let script = [
|
||||||
|
"#!/usr/bin/env python3",
|
||||||
|
"import json, os, sys",
|
||||||
|
"",
|
||||||
|
"LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
|
||||||
|
"LOG_PATH = os.environ.get('MCP_LOG_PATH')",
|
||||||
|
"initialize_count = 0",
|
||||||
|
"",
|
||||||
|
"def log(method):",
|
||||||
|
" if LOG_PATH:",
|
||||||
|
" with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
|
||||||
|
" handle.write(f'{method}\\n')",
|
||||||
|
"",
|
||||||
|
"def read_message():",
|
||||||
|
" header = b''",
|
||||||
|
r" while not header.endswith(b'\r\n\r\n'):",
|
||||||
|
" chunk = sys.stdin.buffer.read(1)",
|
||||||
|
" if not chunk:",
|
||||||
|
" return None",
|
||||||
|
" header += chunk",
|
||||||
|
" length = 0",
|
||||||
|
r" for line in header.decode().split('\r\n'):",
|
||||||
|
r" if line.lower().startswith('content-length:'):",
|
||||||
|
r" length = int(line.split(':', 1)[1].strip())",
|
||||||
|
" payload = sys.stdin.buffer.read(length)",
|
||||||
|
" return json.loads(payload.decode())",
|
||||||
|
"",
|
||||||
|
"def send_message(message):",
|
||||||
|
" payload = json.dumps(message).encode()",
|
||||||
|
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
|
||||||
|
" sys.stdout.buffer.flush()",
|
||||||
|
"",
|
||||||
|
"while True:",
|
||||||
|
" request = read_message()",
|
||||||
|
" if request is None:",
|
||||||
|
" break",
|
||||||
|
" method = request['method']",
|
||||||
|
" log(method)",
|
||||||
|
" if method == 'initialize':",
|
||||||
|
" initialize_count += 1",
|
||||||
|
" send_message({",
|
||||||
|
" 'jsonrpc': '2.0',",
|
||||||
|
" 'id': request['id'],",
|
||||||
|
" 'result': {",
|
||||||
|
" 'protocolVersion': request['params']['protocolVersion'],",
|
||||||
|
" 'capabilities': {'tools': {}},",
|
||||||
|
" 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
|
||||||
|
" }",
|
||||||
|
" })",
|
||||||
|
" elif method == 'tools/list':",
|
||||||
|
" send_message({",
|
||||||
|
" 'jsonrpc': '2.0',",
|
||||||
|
" 'id': request['id'],",
|
||||||
|
" 'result': {",
|
||||||
|
" 'tools': [",
|
||||||
|
" {",
|
||||||
|
" 'name': 'echo',",
|
||||||
|
" 'description': f'Echo tool for {LABEL}',",
|
||||||
|
" 'inputSchema': {",
|
||||||
|
" 'type': 'object',",
|
||||||
|
" 'properties': {'text': {'type': 'string'}},",
|
||||||
|
" 'required': ['text']",
|
||||||
|
" }",
|
||||||
|
" }",
|
||||||
|
" ]",
|
||||||
|
" }",
|
||||||
|
" })",
|
||||||
|
" elif method == 'tools/call':",
|
||||||
|
" args = request['params'].get('arguments') or {}",
|
||||||
|
" text = args.get('text', '')",
|
||||||
|
" send_message({",
|
||||||
|
" 'jsonrpc': '2.0',",
|
||||||
|
" 'id': request['id'],",
|
||||||
|
" 'result': {",
|
||||||
|
" 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
|
||||||
|
" 'structuredContent': {",
|
||||||
|
" 'server': LABEL,",
|
||||||
|
" 'echoed': text,",
|
||||||
|
" 'initializeCount': initialize_count",
|
||||||
|
" },",
|
||||||
|
" 'isError': False",
|
||||||
|
" }",
|
||||||
|
" })",
|
||||||
|
" else:",
|
||||||
|
" send_message({",
|
||||||
|
" 'jsonrpc': '2.0',",
|
||||||
|
" 'id': request['id'],",
|
||||||
|
" 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
|
||||||
|
" })",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
.join("\n");
|
||||||
|
fs::write(&script_path, script).expect("write script");
|
||||||
|
let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
|
||||||
|
permissions.set_mode(0o755);
|
||||||
|
fs::set_permissions(&script_path, permissions).expect("chmod");
|
||||||
|
script_path
|
||||||
|
}
|
||||||
|
|
||||||
fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
|
fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
|
||||||
let config = ScopedMcpServerConfig {
|
let config = ScopedMcpServerConfig {
|
||||||
scope: ConfigSource::Local,
|
scope: ConfigSource::Local,
|
||||||
@@ -653,6 +1148,27 @@ mod tests {
|
|||||||
fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
|
fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn manager_server_config(
|
||||||
|
script_path: &Path,
|
||||||
|
label: &str,
|
||||||
|
log_path: &Path,
|
||||||
|
) -> ScopedMcpServerConfig {
|
||||||
|
ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
||||||
|
command: "python3".to_string(),
|
||||||
|
args: vec![script_path.to_string_lossy().into_owned()],
|
||||||
|
env: BTreeMap::from([
|
||||||
|
("MCP_SERVER_LABEL".to_string(), label.to_string()),
|
||||||
|
(
|
||||||
|
"MCP_LOG_PATH".to_string(),
|
||||||
|
log_path.to_string_lossy().into_owned(),
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spawns_stdio_process_and_round_trips_io() {
|
fn spawns_stdio_process_and_round_trips_io() {
|
||||||
let runtime = Builder::new_current_thread()
|
let runtime = Builder::new_current_thread()
|
||||||
@@ -935,4 +1451,247 @@ mod tests {
|
|||||||
cleanup_script(&script_path);
|
cleanup_script(&script_path);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn manager_discovers_tools_from_stdio_config() {
|
||||||
|
let runtime = Builder::new_current_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.expect("runtime");
|
||||||
|
runtime.block_on(async {
|
||||||
|
let script_path = write_manager_mcp_server_script();
|
||||||
|
let root = script_path.parent().expect("script parent");
|
||||||
|
let log_path = root.join("alpha.log");
|
||||||
|
let servers = BTreeMap::from([(
|
||||||
|
"alpha".to_string(),
|
||||||
|
manager_server_config(&script_path, "alpha", &log_path),
|
||||||
|
)]);
|
||||||
|
let mut manager = McpServerManager::from_servers(&servers);
|
||||||
|
|
||||||
|
let tools = manager.discover_tools().await.expect("discover tools");
|
||||||
|
|
||||||
|
assert_eq!(tools.len(), 1);
|
||||||
|
assert_eq!(tools[0].server_name, "alpha");
|
||||||
|
assert_eq!(tools[0].raw_name, "echo");
|
||||||
|
assert_eq!(tools[0].qualified_name, mcp_tool_name("alpha", "echo"));
|
||||||
|
assert_eq!(tools[0].tool.name, "echo");
|
||||||
|
assert!(manager.unsupported_servers().is_empty());
|
||||||
|
|
||||||
|
manager.shutdown().await.expect("shutdown");
|
||||||
|
cleanup_script(&script_path);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn manager_routes_tool_calls_to_correct_server() {
|
||||||
|
let runtime = Builder::new_current_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.expect("runtime");
|
||||||
|
runtime.block_on(async {
|
||||||
|
let script_path = write_manager_mcp_server_script();
|
||||||
|
let root = script_path.parent().expect("script parent");
|
||||||
|
let alpha_log = root.join("alpha.log");
|
||||||
|
let beta_log = root.join("beta.log");
|
||||||
|
let servers = BTreeMap::from([
|
||||||
|
(
|
||||||
|
"alpha".to_string(),
|
||||||
|
manager_server_config(&script_path, "alpha", &alpha_log),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"beta".to_string(),
|
||||||
|
manager_server_config(&script_path, "beta", &beta_log),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
let mut manager = McpServerManager::from_servers(&servers);
|
||||||
|
|
||||||
|
let tools = manager.discover_tools().await.expect("discover tools");
|
||||||
|
assert_eq!(tools.len(), 2);
|
||||||
|
|
||||||
|
let alpha = manager
|
||||||
|
.call_tool(
|
||||||
|
&mcp_tool_name("alpha", "echo"),
|
||||||
|
Some(json!({"text": "hello"})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("call alpha tool");
|
||||||
|
let beta = manager
|
||||||
|
.call_tool(
|
||||||
|
&mcp_tool_name("beta", "echo"),
|
||||||
|
Some(json!({"text": "world"})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("call beta tool");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
alpha
|
||||||
|
.result
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|result| result.structured_content.as_ref())
|
||||||
|
.and_then(|value| value.get("server")),
|
||||||
|
Some(&json!("alpha"))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
beta.result
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|result| result.structured_content.as_ref())
|
||||||
|
.and_then(|value| value.get("server")),
|
||||||
|
Some(&json!("beta"))
|
||||||
|
);
|
||||||
|
|
||||||
|
manager.shutdown().await.expect("shutdown");
|
||||||
|
cleanup_script(&script_path);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn manager_records_unsupported_non_stdio_servers_without_panicking() {
|
||||||
|
let servers = BTreeMap::from([
|
||||||
|
(
|
||||||
|
"http".to_string(),
|
||||||
|
ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Http(McpRemoteServerConfig {
|
||||||
|
url: "https://example.test/mcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
oauth: None,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"sdk".to_string(),
|
||||||
|
ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Sdk(McpSdkServerConfig {
|
||||||
|
name: "sdk-server".to_string(),
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ws".to_string(),
|
||||||
|
ScopedMcpServerConfig {
|
||||||
|
scope: ConfigSource::Local,
|
||||||
|
config: McpServerConfig::Ws(McpWebSocketServerConfig {
|
||||||
|
url: "wss://example.test/mcp".to_string(),
|
||||||
|
headers: BTreeMap::new(),
|
||||||
|
headers_helper: None,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let manager = McpServerManager::from_servers(&servers);
|
||||||
|
let unsupported = manager.unsupported_servers();
|
||||||
|
|
||||||
|
assert_eq!(unsupported.len(), 3);
|
||||||
|
assert_eq!(unsupported[0].server_name, "http");
|
||||||
|
assert_eq!(unsupported[1].server_name, "sdk");
|
||||||
|
assert_eq!(unsupported[2].server_name, "ws");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn manager_shutdown_terminates_spawned_children_and_is_idempotent() {
|
||||||
|
let runtime = Builder::new_current_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.expect("runtime");
|
||||||
|
runtime.block_on(async {
|
||||||
|
let script_path = write_manager_mcp_server_script();
|
||||||
|
let root = script_path.parent().expect("script parent");
|
||||||
|
let log_path = root.join("alpha.log");
|
||||||
|
let servers = BTreeMap::from([(
|
||||||
|
"alpha".to_string(),
|
||||||
|
manager_server_config(&script_path, "alpha", &log_path),
|
||||||
|
)]);
|
||||||
|
let mut manager = McpServerManager::from_servers(&servers);
|
||||||
|
|
||||||
|
manager.discover_tools().await.expect("discover tools");
|
||||||
|
manager.shutdown().await.expect("first shutdown");
|
||||||
|
manager.shutdown().await.expect("second shutdown");
|
||||||
|
|
||||||
|
cleanup_script(&script_path);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn manager_reuses_spawned_server_between_discovery_and_call() {
|
||||||
|
let runtime = Builder::new_current_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.expect("runtime");
|
||||||
|
runtime.block_on(async {
|
||||||
|
let script_path = write_manager_mcp_server_script();
|
||||||
|
let root = script_path.parent().expect("script parent");
|
||||||
|
let log_path = root.join("alpha.log");
|
||||||
|
let servers = BTreeMap::from([(
|
||||||
|
"alpha".to_string(),
|
||||||
|
manager_server_config(&script_path, "alpha", &log_path),
|
||||||
|
)]);
|
||||||
|
let mut manager = McpServerManager::from_servers(&servers);
|
||||||
|
|
||||||
|
manager.discover_tools().await.expect("discover tools");
|
||||||
|
let response = manager
|
||||||
|
.call_tool(
|
||||||
|
&mcp_tool_name("alpha", "echo"),
|
||||||
|
Some(json!({"text": "reuse"})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("call tool");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
response
|
||||||
|
.result
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|result| result.structured_content.as_ref())
|
||||||
|
.and_then(|value| value.get("initializeCount")),
|
||||||
|
Some(&json!(1))
|
||||||
|
);
|
||||||
|
|
||||||
|
let log = fs::read_to_string(&log_path).expect("read log");
|
||||||
|
assert_eq!(log.lines().filter(|line| *line == "initialize").count(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
log.lines().collect::<Vec<_>>(),
|
||||||
|
vec!["initialize", "tools/list", "tools/call"]
|
||||||
|
);
|
||||||
|
|
||||||
|
manager.shutdown().await.expect("shutdown");
|
||||||
|
cleanup_script(&script_path);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn manager_reports_unknown_qualified_tool_name() {
|
||||||
|
let runtime = Builder::new_current_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.expect("runtime");
|
||||||
|
runtime.block_on(async {
|
||||||
|
let script_path = write_manager_mcp_server_script();
|
||||||
|
let root = script_path.parent().expect("script parent");
|
||||||
|
let log_path = root.join("alpha.log");
|
||||||
|
let servers = BTreeMap::from([(
|
||||||
|
"alpha".to_string(),
|
||||||
|
manager_server_config(&script_path, "alpha", &log_path),
|
||||||
|
)]);
|
||||||
|
let mut manager = McpServerManager::from_servers(&servers);
|
||||||
|
|
||||||
|
let error = manager
|
||||||
|
.call_tool(
|
||||||
|
&mcp_tool_name("alpha", "missing"),
|
||||||
|
Some(json!({"text": "nope"})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect_err("unknown qualified tool should fail");
|
||||||
|
|
||||||
|
match error {
|
||||||
|
McpServerManagerError::UnknownTool { qualified_name } => {
|
||||||
|
assert_eq!(qualified_name, mcp_tool_name("alpha", "missing"));
|
||||||
|
}
|
||||||
|
other => panic!("expected unknown tool error, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup_script(&script_path);
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::fs::File;
|
use std::fs::{self, File};
|
||||||
use std::io::{self, Read};
|
use std::io::{self, Read};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{Map, Value};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
use crate::config::OAuthConfig;
|
use crate::config::OAuthConfig;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct OAuthTokenSet {
|
pub struct OAuthTokenSet {
|
||||||
pub access_token: String,
|
pub access_token: String,
|
||||||
pub refresh_token: Option<String>,
|
pub refresh_token: Option<String>,
|
||||||
@@ -65,6 +68,48 @@ pub struct OAuthRefreshRequest {
|
|||||||
pub scopes: Vec<String>,
|
pub scopes: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct OAuthCallbackParams {
|
||||||
|
pub code: Option<String>,
|
||||||
|
pub state: Option<String>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
pub error_description: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
struct StoredOAuthCredentials {
|
||||||
|
access_token: String,
|
||||||
|
#[serde(default)]
|
||||||
|
refresh_token: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
expires_at: Option<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
scopes: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OAuthTokenSet> for StoredOAuthCredentials {
|
||||||
|
fn from(value: OAuthTokenSet) -> Self {
|
||||||
|
Self {
|
||||||
|
access_token: value.access_token,
|
||||||
|
refresh_token: value.refresh_token,
|
||||||
|
expires_at: value.expires_at,
|
||||||
|
scopes: value.scopes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StoredOAuthCredentials> for OAuthTokenSet {
|
||||||
|
fn from(value: StoredOAuthCredentials) -> Self {
|
||||||
|
Self {
|
||||||
|
access_token: value.access_token,
|
||||||
|
refresh_token: value.refresh_token,
|
||||||
|
expires_at: value.expires_at,
|
||||||
|
scopes: value.scopes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl OAuthAuthorizationRequest {
|
impl OAuthAuthorizationRequest {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn from_config(
|
pub fn from_config(
|
||||||
@@ -137,7 +182,6 @@ impl OAuthTokenExchangeRequest {
|
|||||||
verifier: impl Into<String>,
|
verifier: impl Into<String>,
|
||||||
redirect_uri: impl Into<String>,
|
redirect_uri: impl Into<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let _ = config;
|
|
||||||
Self {
|
Self {
|
||||||
grant_type: "authorization_code",
|
grant_type: "authorization_code",
|
||||||
code: code.into(),
|
code: code.into(),
|
||||||
@@ -211,12 +255,116 @@ pub fn loopback_redirect_uri(port: u16) -> String {
|
|||||||
format!("http://localhost:{port}/callback")
|
format!("http://localhost:{port}/callback")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn credentials_path() -> io::Result<PathBuf> {
|
||||||
|
Ok(credentials_home_dir()?.join("credentials.json"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
|
||||||
|
let path = credentials_path()?;
|
||||||
|
let root = read_credentials_root(&path)?;
|
||||||
|
let Some(oauth) = root.get("oauth") else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
if oauth.is_null() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||||
|
Ok(Some(stored.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
|
||||||
|
let path = credentials_path()?;
|
||||||
|
let mut root = read_credentials_root(&path)?;
|
||||||
|
root.insert(
|
||||||
|
"oauth".to_string(),
|
||||||
|
serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
|
||||||
|
);
|
||||||
|
write_credentials_root(&path, &root)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_oauth_credentials() -> io::Result<()> {
|
||||||
|
let path = credentials_path()?;
|
||||||
|
let mut root = read_credentials_root(&path)?;
|
||||||
|
root.remove("oauth");
|
||||||
|
write_credentials_root(&path, &root)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
|
||||||
|
let (path, query) = target
|
||||||
|
.split_once('?')
|
||||||
|
.map_or((target, ""), |(path, query)| (path, query));
|
||||||
|
if path != "/callback" {
|
||||||
|
return Err(format!("unexpected callback path: {path}"));
|
||||||
|
}
|
||||||
|
parse_oauth_callback_query(query)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
|
||||||
|
let mut params = BTreeMap::new();
|
||||||
|
for pair in query.split('&').filter(|pair| !pair.is_empty()) {
|
||||||
|
let (key, value) = pair
|
||||||
|
.split_once('=')
|
||||||
|
.map_or((pair, ""), |(key, value)| (key, value));
|
||||||
|
params.insert(percent_decode(key)?, percent_decode(value)?);
|
||||||
|
}
|
||||||
|
Ok(OAuthCallbackParams {
|
||||||
|
code: params.get("code").cloned(),
|
||||||
|
state: params.get("state").cloned(),
|
||||||
|
error: params.get("error").cloned(),
|
||||||
|
error_description: params.get("error_description").cloned(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
||||||
let mut buffer = vec![0_u8; bytes];
|
let mut buffer = vec![0_u8; bytes];
|
||||||
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
||||||
Ok(base64url_encode(&buffer))
|
Ok(base64url_encode(&buffer))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn credentials_home_dir() -> io::Result<PathBuf> {
|
||||||
|
if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") {
|
||||||
|
return Ok(PathBuf::from(path));
|
||||||
|
}
|
||||||
|
let home = std::env::var_os("HOME")
|
||||||
|
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set"))?;
|
||||||
|
Ok(PathBuf::from(home).join(".claude"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
|
||||||
|
match fs::read_to_string(path) {
|
||||||
|
Ok(contents) => {
|
||||||
|
if contents.trim().is_empty() {
|
||||||
|
return Ok(Map::new());
|
||||||
|
}
|
||||||
|
serde_json::from_str::<Value>(&contents)
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
|
||||||
|
.as_object()
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidData,
|
||||||
|
"credentials file must contain a JSON object",
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
|
||||||
|
Err(error) => Err(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
|
||||||
|
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||||
|
let temp_path = path.with_extension("json.tmp");
|
||||||
|
fs::write(&temp_path, format!("{rendered}\n"))?;
|
||||||
|
fs::rename(temp_path, path)
|
||||||
|
}
|
||||||
|
|
||||||
fn base64url_encode(bytes: &[u8]) -> String {
|
fn base64url_encode(bytes: &[u8]) -> String {
|
||||||
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
@@ -264,11 +412,49 @@ fn percent_encode(value: &str) -> String {
|
|||||||
encoded
|
encoded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn percent_decode(value: &str) -> Result<String, String> {
|
||||||
|
let mut decoded = Vec::with_capacity(value.len());
|
||||||
|
let bytes = value.as_bytes();
|
||||||
|
let mut index = 0;
|
||||||
|
while index < bytes.len() {
|
||||||
|
match bytes[index] {
|
||||||
|
b'%' if index + 2 < bytes.len() => {
|
||||||
|
let hi = decode_hex(bytes[index + 1])?;
|
||||||
|
let lo = decode_hex(bytes[index + 2])?;
|
||||||
|
decoded.push((hi << 4) | lo);
|
||||||
|
index += 3;
|
||||||
|
}
|
||||||
|
b'+' => {
|
||||||
|
decoded.push(b' ');
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
byte => {
|
||||||
|
decoded.push(byte);
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String::from_utf8(decoded).map_err(|error| error.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_hex(byte: u8) -> Result<u8, String> {
|
||||||
|
match byte {
|
||||||
|
b'0'..=b'9' => Ok(byte - b'0'),
|
||||||
|
b'a'..=b'f' => Ok(byte - b'a' + 10),
|
||||||
|
b'A'..=b'F' => Ok(byte - b'A' + 10),
|
||||||
|
_ => Err(format!("invalid percent-encoding byte: {byte}")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
|
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||||
OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
|
generate_state, load_oauth_credentials, loopback_redirect_uri, parse_oauth_callback_query,
|
||||||
|
parse_oauth_callback_request_target, save_oauth_credentials, OAuthAuthorizationRequest,
|
||||||
|
OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn sample_config() -> OAuthConfig {
|
fn sample_config() -> OAuthConfig {
|
||||||
@@ -282,6 +468,21 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
crate::test_env_lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn temp_config_home() -> std::path::PathBuf {
|
||||||
|
std::env::temp_dir().join(format!(
|
||||||
|
"runtime-oauth-test-{}-{}",
|
||||||
|
std::process::id(),
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time")
|
||||||
|
.as_nanos()
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn s256_challenge_matches_expected_vector() {
|
fn s256_challenge_matches_expected_vector() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -335,4 +536,54 @@ mod tests {
|
|||||||
Some("org:read user:write")
|
Some("org:read user:write")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
|
||||||
|
let _guard = env_lock();
|
||||||
|
let config_home = temp_config_home();
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
||||||
|
let path = credentials_path().expect("credentials path");
|
||||||
|
std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
|
||||||
|
std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");
|
||||||
|
|
||||||
|
let token_set = OAuthTokenSet {
|
||||||
|
access_token: "access-token".to_string(),
|
||||||
|
refresh_token: Some("refresh-token".to_string()),
|
||||||
|
expires_at: Some(123),
|
||||||
|
scopes: vec!["scope:a".to_string()],
|
||||||
|
};
|
||||||
|
save_oauth_credentials(&token_set).expect("save credentials");
|
||||||
|
assert_eq!(
|
||||||
|
load_oauth_credentials().expect("load credentials"),
|
||||||
|
Some(token_set)
|
||||||
|
);
|
||||||
|
let saved = std::fs::read_to_string(&path).expect("read saved file");
|
||||||
|
assert!(saved.contains("\"other\": \"value\""));
|
||||||
|
assert!(saved.contains("\"oauth\""));
|
||||||
|
|
||||||
|
clear_oauth_credentials().expect("clear credentials");
|
||||||
|
assert_eq!(load_oauth_credentials().expect("load cleared"), None);
|
||||||
|
let cleared = std::fs::read_to_string(&path).expect("read cleared file");
|
||||||
|
assert!(cleared.contains("\"other\": \"value\""));
|
||||||
|
assert!(!cleared.contains("\"oauth\""));
|
||||||
|
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_callback_query_and_target() {
|
||||||
|
let params =
|
||||||
|
parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
|
||||||
|
.expect("parse query");
|
||||||
|
assert_eq!(params.code.as_deref(), Some("abc123"));
|
||||||
|
assert_eq!(params.state.as_deref(), Some("state-1"));
|
||||||
|
assert_eq!(params.error_description.as_deref(), Some("needs login"));
|
||||||
|
|
||||||
|
let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
|
||||||
|
.expect("parse callback target");
|
||||||
|
assert_eq!(params.code.as_deref(), Some("abc"));
|
||||||
|
assert_eq!(params.state.as_deref(), Some("xyz"));
|
||||||
|
assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,33 @@
|
|||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
pub enum PermissionMode {
|
pub enum PermissionMode {
|
||||||
Allow,
|
ReadOnly,
|
||||||
Deny,
|
WorkspaceWrite,
|
||||||
|
DangerFullAccess,
|
||||||
Prompt,
|
Prompt,
|
||||||
|
Allow,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PermissionMode {
|
||||||
|
#[must_use]
|
||||||
|
pub fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::ReadOnly => "read-only",
|
||||||
|
Self::WorkspaceWrite => "workspace-write",
|
||||||
|
Self::DangerFullAccess => "danger-full-access",
|
||||||
|
Self::Prompt => "prompt",
|
||||||
|
Self::Allow => "allow",
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct PermissionRequest {
|
pub struct PermissionRequest {
|
||||||
pub tool_name: String,
|
pub tool_name: String,
|
||||||
pub input: String,
|
pub input: String,
|
||||||
|
pub current_mode: PermissionMode,
|
||||||
|
pub required_mode: PermissionMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
@@ -31,31 +48,41 @@ pub enum PermissionOutcome {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct PermissionPolicy {
|
pub struct PermissionPolicy {
|
||||||
default_mode: PermissionMode,
|
active_mode: PermissionMode,
|
||||||
tool_modes: BTreeMap<String, PermissionMode>,
|
tool_requirements: BTreeMap<String, PermissionMode>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PermissionPolicy {
|
impl PermissionPolicy {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(default_mode: PermissionMode) -> Self {
|
pub fn new(active_mode: PermissionMode) -> Self {
|
||||||
Self {
|
Self {
|
||||||
default_mode,
|
active_mode,
|
||||||
tool_modes: BTreeMap::new(),
|
tool_requirements: BTreeMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_tool_mode(mut self, tool_name: impl Into<String>, mode: PermissionMode) -> Self {
|
pub fn with_tool_requirement(
|
||||||
self.tool_modes.insert(tool_name.into(), mode);
|
mut self,
|
||||||
|
tool_name: impl Into<String>,
|
||||||
|
required_mode: PermissionMode,
|
||||||
|
) -> Self {
|
||||||
|
self.tool_requirements
|
||||||
|
.insert(tool_name.into(), required_mode);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn mode_for(&self, tool_name: &str) -> PermissionMode {
|
pub fn active_mode(&self) -> PermissionMode {
|
||||||
self.tool_modes
|
self.active_mode
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn required_mode_for(&self, tool_name: &str) -> PermissionMode {
|
||||||
|
self.tool_requirements
|
||||||
.get(tool_name)
|
.get(tool_name)
|
||||||
.copied()
|
.copied()
|
||||||
.unwrap_or(self.default_mode)
|
.unwrap_or(PermissionMode::DangerFullAccess)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
@@ -65,23 +92,44 @@ impl PermissionPolicy {
|
|||||||
input: &str,
|
input: &str,
|
||||||
mut prompter: Option<&mut dyn PermissionPrompter>,
|
mut prompter: Option<&mut dyn PermissionPrompter>,
|
||||||
) -> PermissionOutcome {
|
) -> PermissionOutcome {
|
||||||
match self.mode_for(tool_name) {
|
let current_mode = self.active_mode();
|
||||||
PermissionMode::Allow => PermissionOutcome::Allow,
|
let required_mode = self.required_mode_for(tool_name);
|
||||||
PermissionMode::Deny => PermissionOutcome::Deny {
|
if current_mode == PermissionMode::Allow || current_mode >= required_mode {
|
||||||
reason: format!("tool '{tool_name}' denied by permission policy"),
|
return PermissionOutcome::Allow;
|
||||||
},
|
}
|
||||||
PermissionMode::Prompt => match prompter.as_mut() {
|
|
||||||
Some(prompter) => match prompter.decide(&PermissionRequest {
|
let request = PermissionRequest {
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: tool_name.to_string(),
|
||||||
input: input.to_string(),
|
input: input.to_string(),
|
||||||
}) {
|
current_mode,
|
||||||
|
required_mode,
|
||||||
|
};
|
||||||
|
|
||||||
|
if current_mode == PermissionMode::Prompt
|
||||||
|
|| (current_mode == PermissionMode::WorkspaceWrite
|
||||||
|
&& required_mode == PermissionMode::DangerFullAccess)
|
||||||
|
{
|
||||||
|
return match prompter.as_mut() {
|
||||||
|
Some(prompter) => match prompter.decide(&request) {
|
||||||
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
|
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
|
||||||
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
|
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
|
||||||
},
|
},
|
||||||
None => PermissionOutcome::Deny {
|
None => PermissionOutcome::Deny {
|
||||||
reason: format!("tool '{tool_name}' requires interactive approval"),
|
reason: format!(
|
||||||
|
"tool '{tool_name}' requires approval to escalate from {} to {}",
|
||||||
|
current_mode.as_str(),
|
||||||
|
required_mode.as_str()
|
||||||
|
),
|
||||||
},
|
},
|
||||||
},
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
PermissionOutcome::Deny {
|
||||||
|
reason: format!(
|
||||||
|
"tool '{tool_name}' requires {} permission; current mode is {}",
|
||||||
|
required_mode.as_str(),
|
||||||
|
current_mode.as_str()
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,25 +141,92 @@ mod tests {
|
|||||||
PermissionPrompter, PermissionRequest,
|
PermissionPrompter, PermissionRequest,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AllowPrompter;
|
struct RecordingPrompter {
|
||||||
|
seen: Vec<PermissionRequest>,
|
||||||
|
allow: bool,
|
||||||
|
}
|
||||||
|
|
||||||
impl PermissionPrompter for AllowPrompter {
|
impl PermissionPrompter for RecordingPrompter {
|
||||||
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
||||||
assert_eq!(request.tool_name, "bash");
|
self.seen.push(request.clone());
|
||||||
PermissionPromptDecision::Allow
|
if self.allow {
|
||||||
|
PermissionPromptDecision::Allow
|
||||||
|
} else {
|
||||||
|
PermissionPromptDecision::Deny {
|
||||||
|
reason: "not now".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn uses_tool_specific_overrides() {
|
fn allows_tools_when_active_mode_meets_requirement() {
|
||||||
let policy = PermissionPolicy::new(PermissionMode::Deny)
|
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||||
.with_tool_mode("bash", PermissionMode::Prompt);
|
.with_tool_requirement("read_file", PermissionMode::ReadOnly)
|
||||||
|
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
policy.authorize("read_file", "{}", None),
|
||||||
|
PermissionOutcome::Allow
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
policy.authorize("write_file", "{}", None),
|
||||||
|
PermissionOutcome::Allow
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn denies_read_only_escalations_without_prompt() {
|
||||||
|
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||||
|
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite)
|
||||||
|
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||||
|
|
||||||
let outcome = policy.authorize("bash", "echo hi", Some(&mut AllowPrompter));
|
|
||||||
assert_eq!(outcome, PermissionOutcome::Allow);
|
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
policy.authorize("edit", "x", None),
|
policy.authorize("write_file", "{}", None),
|
||||||
PermissionOutcome::Deny { .. }
|
PermissionOutcome::Deny { reason } if reason.contains("requires workspace-write permission")
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
policy.authorize("bash", "{}", None),
|
||||||
|
PermissionOutcome::Deny { reason } if reason.contains("requires danger-full-access permission")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn prompts_for_workspace_write_to_danger_full_access_escalation() {
|
||||||
|
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||||
|
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||||
|
let mut prompter = RecordingPrompter {
|
||||||
|
seen: Vec::new(),
|
||||||
|
allow: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let outcome = policy.authorize("bash", "echo hi", Some(&mut prompter));
|
||||||
|
|
||||||
|
assert_eq!(outcome, PermissionOutcome::Allow);
|
||||||
|
assert_eq!(prompter.seen.len(), 1);
|
||||||
|
assert_eq!(prompter.seen[0].tool_name, "bash");
|
||||||
|
assert_eq!(
|
||||||
|
prompter.seen[0].current_mode,
|
||||||
|
PermissionMode::WorkspaceWrite
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
prompter.seen[0].required_mode,
|
||||||
|
PermissionMode::DangerFullAccess
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn honors_prompt_rejection_reason() {
|
||||||
|
let policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite)
|
||||||
|
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||||
|
let mut prompter = RecordingPrompter {
|
||||||
|
seen: Vec::new(),
|
||||||
|
allow: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
policy.authorize("bash", "echo hi", Some(&mut prompter)),
|
||||||
|
PermissionOutcome::Deny { reason } if reason == "not now"
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ pub struct ProjectContext {
|
|||||||
pub cwd: PathBuf,
|
pub cwd: PathBuf,
|
||||||
pub current_date: String,
|
pub current_date: String,
|
||||||
pub git_status: Option<String>,
|
pub git_status: Option<String>,
|
||||||
|
pub git_diff: Option<String>,
|
||||||
pub instruction_files: Vec<ContextFile>,
|
pub instruction_files: Vec<ContextFile>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,6 +65,7 @@ impl ProjectContext {
|
|||||||
cwd,
|
cwd,
|
||||||
current_date: current_date.into(),
|
current_date: current_date.into(),
|
||||||
git_status: None,
|
git_status: None,
|
||||||
|
git_diff: None,
|
||||||
instruction_files,
|
instruction_files,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -74,6 +76,7 @@ impl ProjectContext {
|
|||||||
) -> std::io::Result<Self> {
|
) -> std::io::Result<Self> {
|
||||||
let mut context = Self::discover(cwd, current_date)?;
|
let mut context = Self::discover(cwd, current_date)?;
|
||||||
context.git_status = read_git_status(&context.cwd);
|
context.git_status = read_git_status(&context.cwd);
|
||||||
|
context.git_diff = read_git_diff(&context.cwd);
|
||||||
Ok(context)
|
Ok(context)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -201,6 +204,7 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
|||||||
dir.join("CLAUDE.md"),
|
dir.join("CLAUDE.md"),
|
||||||
dir.join("CLAUDE.local.md"),
|
dir.join("CLAUDE.local.md"),
|
||||||
dir.join(".claude").join("CLAUDE.md"),
|
dir.join(".claude").join("CLAUDE.md"),
|
||||||
|
dir.join(".claude").join("instructions.md"),
|
||||||
] {
|
] {
|
||||||
push_context_file(&mut files, candidate)?;
|
push_context_file(&mut files, candidate)?;
|
||||||
}
|
}
|
||||||
@@ -238,6 +242,38 @@ fn read_git_status(cwd: &Path) -> Option<String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn read_git_diff(cwd: &Path) -> Option<String> {
|
||||||
|
let mut sections = Vec::new();
|
||||||
|
|
||||||
|
let staged = read_git_output(cwd, &["diff", "--cached"])?;
|
||||||
|
if !staged.trim().is_empty() {
|
||||||
|
sections.push(format!("Staged changes:\n{}", staged.trim_end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let unstaged = read_git_output(cwd, &["diff"])?;
|
||||||
|
if !unstaged.trim().is_empty() {
|
||||||
|
sections.push(format!("Unstaged changes:\n{}", unstaged.trim_end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if sections.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(sections.join("\n\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_git_output(cwd: &Path, args: &[&str]) -> Option<String> {
|
||||||
|
let output = Command::new("git")
|
||||||
|
.args(args)
|
||||||
|
.current_dir(cwd)
|
||||||
|
.output()
|
||||||
|
.ok()?;
|
||||||
|
if !output.status.success() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
String::from_utf8(output.stdout).ok()
|
||||||
|
}
|
||||||
|
|
||||||
fn render_project_context(project_context: &ProjectContext) -> String {
|
fn render_project_context(project_context: &ProjectContext) -> String {
|
||||||
let mut lines = vec!["# Project context".to_string()];
|
let mut lines = vec!["# Project context".to_string()];
|
||||||
let mut bullets = vec![
|
let mut bullets = vec![
|
||||||
@@ -256,6 +292,11 @@ fn render_project_context(project_context: &ProjectContext) -> String {
|
|||||||
lines.push("Git status snapshot:".to_string());
|
lines.push("Git status snapshot:".to_string());
|
||||||
lines.push(status.clone());
|
lines.push(status.clone());
|
||||||
}
|
}
|
||||||
|
if let Some(diff) = &project_context.git_diff {
|
||||||
|
lines.push(String::new());
|
||||||
|
lines.push("Git diff snapshot:".to_string());
|
||||||
|
lines.push(diff.clone());
|
||||||
|
}
|
||||||
lines.join("\n")
|
lines.join("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -468,6 +509,10 @@ mod tests {
|
|||||||
std::env::temp_dir().join(format!("runtime-prompt-{nanos}"))
|
std::env::temp_dir().join(format!("runtime-prompt-{nanos}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||||
|
crate::test_env_lock()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn discovers_instruction_files_from_ancestor_chain() {
|
fn discovers_instruction_files_from_ancestor_chain() {
|
||||||
let root = temp_dir();
|
let root = temp_dir();
|
||||||
@@ -477,10 +522,21 @@ mod tests {
|
|||||||
fs::write(root.join("CLAUDE.local.md"), "local instructions")
|
fs::write(root.join("CLAUDE.local.md"), "local instructions")
|
||||||
.expect("write local instructions");
|
.expect("write local instructions");
|
||||||
fs::create_dir_all(root.join("apps")).expect("apps dir");
|
fs::create_dir_all(root.join("apps")).expect("apps dir");
|
||||||
|
fs::create_dir_all(root.join("apps").join(".claude")).expect("apps claude dir");
|
||||||
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
|
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
|
||||||
.expect("write apps instructions");
|
.expect("write apps instructions");
|
||||||
|
fs::write(
|
||||||
|
root.join("apps").join(".claude").join("instructions.md"),
|
||||||
|
"apps dot claude instructions",
|
||||||
|
)
|
||||||
|
.expect("write apps dot claude instructions");
|
||||||
fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules")
|
fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules")
|
||||||
.expect("write nested rules");
|
.expect("write nested rules");
|
||||||
|
fs::write(
|
||||||
|
nested.join(".claude").join("instructions.md"),
|
||||||
|
"nested instructions",
|
||||||
|
)
|
||||||
|
.expect("write nested instructions");
|
||||||
|
|
||||||
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||||
let contents = context
|
let contents = context
|
||||||
@@ -495,7 +551,9 @@ mod tests {
|
|||||||
"root instructions",
|
"root instructions",
|
||||||
"local instructions",
|
"local instructions",
|
||||||
"apps instructions",
|
"apps instructions",
|
||||||
"nested rules"
|
"apps dot claude instructions",
|
||||||
|
"nested rules",
|
||||||
|
"nested instructions"
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
@@ -559,6 +617,49 @@ mod tests {
|
|||||||
assert!(status.contains("## No commits yet on") || status.contains("## "));
|
assert!(status.contains("## No commits yet on") || status.contains("## "));
|
||||||
assert!(status.contains("?? CLAUDE.md"));
|
assert!(status.contains("?? CLAUDE.md"));
|
||||||
assert!(status.contains("?? tracked.txt"));
|
assert!(status.contains("?? tracked.txt"));
|
||||||
|
assert!(context.git_diff.is_none());
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discover_with_git_includes_diff_snapshot_for_tracked_changes() {
|
||||||
|
let root = temp_dir();
|
||||||
|
fs::create_dir_all(&root).expect("root dir");
|
||||||
|
std::process::Command::new("git")
|
||||||
|
.args(["init", "--quiet"])
|
||||||
|
.current_dir(&root)
|
||||||
|
.status()
|
||||||
|
.expect("git init should run");
|
||||||
|
std::process::Command::new("git")
|
||||||
|
.args(["config", "user.email", "tests@example.com"])
|
||||||
|
.current_dir(&root)
|
||||||
|
.status()
|
||||||
|
.expect("git config email should run");
|
||||||
|
std::process::Command::new("git")
|
||||||
|
.args(["config", "user.name", "Runtime Prompt Tests"])
|
||||||
|
.current_dir(&root)
|
||||||
|
.status()
|
||||||
|
.expect("git config name should run");
|
||||||
|
fs::write(root.join("tracked.txt"), "hello\n").expect("write tracked file");
|
||||||
|
std::process::Command::new("git")
|
||||||
|
.args(["add", "tracked.txt"])
|
||||||
|
.current_dir(&root)
|
||||||
|
.status()
|
||||||
|
.expect("git add should run");
|
||||||
|
std::process::Command::new("git")
|
||||||
|
.args(["commit", "-m", "init", "--quiet"])
|
||||||
|
.current_dir(&root)
|
||||||
|
.status()
|
||||||
|
.expect("git commit should run");
|
||||||
|
fs::write(root.join("tracked.txt"), "hello\nworld\n").expect("rewrite tracked file");
|
||||||
|
|
||||||
|
let context =
|
||||||
|
ProjectContext::discover_with_git(&root, "2026-03-31").expect("context should load");
|
||||||
|
|
||||||
|
let diff = context.git_diff.expect("git diff should be present");
|
||||||
|
assert!(diff.contains("Unstaged changes:"));
|
||||||
|
assert!(diff.contains("tracked.txt"));
|
||||||
|
|
||||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
}
|
}
|
||||||
@@ -574,7 +675,12 @@ mod tests {
|
|||||||
)
|
)
|
||||||
.expect("write settings");
|
.expect("write settings");
|
||||||
|
|
||||||
|
let _guard = env_lock();
|
||||||
let previous = std::env::current_dir().expect("cwd");
|
let previous = std::env::current_dir().expect("cwd");
|
||||||
|
let original_home = std::env::var("HOME").ok();
|
||||||
|
let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok();
|
||||||
|
std::env::set_var("HOME", &root);
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", root.join("missing-home"));
|
||||||
std::env::set_current_dir(&root).expect("change cwd");
|
std::env::set_current_dir(&root).expect("change cwd");
|
||||||
let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8")
|
let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8")
|
||||||
.expect("system prompt should load")
|
.expect("system prompt should load")
|
||||||
@@ -584,6 +690,16 @@ mod tests {
|
|||||||
",
|
",
|
||||||
);
|
);
|
||||||
std::env::set_current_dir(previous).expect("restore cwd");
|
std::env::set_current_dir(previous).expect("restore cwd");
|
||||||
|
if let Some(value) = original_home {
|
||||||
|
std::env::set_var("HOME", value);
|
||||||
|
} else {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
}
|
||||||
|
if let Some(value) = original_claude_home {
|
||||||
|
std::env::set_var("CLAUDE_CONFIG_HOME", value);
|
||||||
|
} else {
|
||||||
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||||
|
}
|
||||||
|
|
||||||
assert!(prompt.contains("Project rules"));
|
assert!(prompt.contains("Project rules"));
|
||||||
assert!(prompt.contains("permissionMode"));
|
assert!(prompt.contains("permissionMode"));
|
||||||
@@ -631,6 +747,29 @@ mod tests {
|
|||||||
assert!(rendered.chars().count() <= 4_000 + "\n\n[truncated]".chars().count());
|
assert!(rendered.chars().count() <= 4_000 + "\n\n[truncated]".chars().count());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn discovers_dot_claude_instructions_markdown() {
|
||||||
|
let root = temp_dir();
|
||||||
|
let nested = root.join("apps").join("api");
|
||||||
|
fs::create_dir_all(nested.join(".claude")).expect("nested claude dir");
|
||||||
|
fs::write(
|
||||||
|
nested.join(".claude").join("instructions.md"),
|
||||||
|
"instruction markdown",
|
||||||
|
)
|
||||||
|
.expect("write instructions.md");
|
||||||
|
|
||||||
|
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||||
|
assert!(context
|
||||||
|
.instruction_files
|
||||||
|
.iter()
|
||||||
|
.any(|file| file.path.ends_with(".claude/instructions.md")));
|
||||||
|
assert!(
|
||||||
|
render_instruction_files(&context.instruction_files).contains("instruction markdown")
|
||||||
|
);
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn renders_instruction_file_metadata() {
|
fn renders_instruction_file_metadata() {
|
||||||
let rendered = render_instruction_files(&[ContextFile {
|
let rendered = render_instruction_files(&[ContextFile {
|
||||||
|
|||||||
364
rust/crates/runtime/src/sandbox.rs
Normal file
364
rust/crates/runtime/src/sandbox.rs
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
#[serde(rename_all = "kebab-case")]
|
||||||
|
pub enum FilesystemIsolationMode {
|
||||||
|
Off,
|
||||||
|
#[default]
|
||||||
|
WorkspaceOnly,
|
||||||
|
AllowList,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FilesystemIsolationMode {
|
||||||
|
#[must_use]
|
||||||
|
pub fn as_str(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Off => "off",
|
||||||
|
Self::WorkspaceOnly => "workspace-only",
|
||||||
|
Self::AllowList => "allow-list",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
pub struct SandboxConfig {
|
||||||
|
pub enabled: Option<bool>,
|
||||||
|
pub namespace_restrictions: Option<bool>,
|
||||||
|
pub network_isolation: Option<bool>,
|
||||||
|
pub filesystem_mode: Option<FilesystemIsolationMode>,
|
||||||
|
pub allowed_mounts: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
pub struct SandboxRequest {
|
||||||
|
pub enabled: bool,
|
||||||
|
pub namespace_restrictions: bool,
|
||||||
|
pub network_isolation: bool,
|
||||||
|
pub filesystem_mode: FilesystemIsolationMode,
|
||||||
|
pub allowed_mounts: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
pub struct ContainerEnvironment {
|
||||||
|
pub in_container: bool,
|
||||||
|
pub markers: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||||
|
pub struct SandboxStatus {
|
||||||
|
pub enabled: bool,
|
||||||
|
pub requested: SandboxRequest,
|
||||||
|
pub supported: bool,
|
||||||
|
pub active: bool,
|
||||||
|
pub namespace_supported: bool,
|
||||||
|
pub namespace_active: bool,
|
||||||
|
pub network_supported: bool,
|
||||||
|
pub network_active: bool,
|
||||||
|
pub filesystem_mode: FilesystemIsolationMode,
|
||||||
|
pub filesystem_active: bool,
|
||||||
|
pub allowed_mounts: Vec<String>,
|
||||||
|
pub in_container: bool,
|
||||||
|
pub container_markers: Vec<String>,
|
||||||
|
pub fallback_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct SandboxDetectionInputs<'a> {
|
||||||
|
pub env_pairs: Vec<(String, String)>,
|
||||||
|
pub dockerenv_exists: bool,
|
||||||
|
pub containerenv_exists: bool,
|
||||||
|
pub proc_1_cgroup: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct LinuxSandboxCommand {
|
||||||
|
pub program: String,
|
||||||
|
pub args: Vec<String>,
|
||||||
|
pub env: Vec<(String, String)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SandboxConfig {
|
||||||
|
#[must_use]
|
||||||
|
pub fn resolve_request(
|
||||||
|
&self,
|
||||||
|
enabled_override: Option<bool>,
|
||||||
|
namespace_override: Option<bool>,
|
||||||
|
network_override: Option<bool>,
|
||||||
|
filesystem_mode_override: Option<FilesystemIsolationMode>,
|
||||||
|
allowed_mounts_override: Option<Vec<String>>,
|
||||||
|
) -> SandboxRequest {
|
||||||
|
SandboxRequest {
|
||||||
|
enabled: enabled_override.unwrap_or(self.enabled.unwrap_or(true)),
|
||||||
|
namespace_restrictions: namespace_override
|
||||||
|
.unwrap_or(self.namespace_restrictions.unwrap_or(true)),
|
||||||
|
network_isolation: network_override.unwrap_or(self.network_isolation.unwrap_or(false)),
|
||||||
|
filesystem_mode: filesystem_mode_override
|
||||||
|
.or(self.filesystem_mode)
|
||||||
|
.unwrap_or_default(),
|
||||||
|
allowed_mounts: allowed_mounts_override.unwrap_or_else(|| self.allowed_mounts.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn detect_container_environment() -> ContainerEnvironment {
|
||||||
|
let proc_1_cgroup = fs::read_to_string("/proc/1/cgroup").ok();
|
||||||
|
detect_container_environment_from(SandboxDetectionInputs {
|
||||||
|
env_pairs: env::vars().collect(),
|
||||||
|
dockerenv_exists: Path::new("/.dockerenv").exists(),
|
||||||
|
containerenv_exists: Path::new("/run/.containerenv").exists(),
|
||||||
|
proc_1_cgroup: proc_1_cgroup.as_deref(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn detect_container_environment_from(
|
||||||
|
inputs: SandboxDetectionInputs<'_>,
|
||||||
|
) -> ContainerEnvironment {
|
||||||
|
let mut markers = Vec::new();
|
||||||
|
if inputs.dockerenv_exists {
|
||||||
|
markers.push("/.dockerenv".to_string());
|
||||||
|
}
|
||||||
|
if inputs.containerenv_exists {
|
||||||
|
markers.push("/run/.containerenv".to_string());
|
||||||
|
}
|
||||||
|
for (key, value) in inputs.env_pairs {
|
||||||
|
let normalized = key.to_ascii_lowercase();
|
||||||
|
if matches!(
|
||||||
|
normalized.as_str(),
|
||||||
|
"container" | "docker" | "podman" | "kubernetes_service_host"
|
||||||
|
) && !value.is_empty()
|
||||||
|
{
|
||||||
|
markers.push(format!("env:{key}={value}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(cgroup) = inputs.proc_1_cgroup {
|
||||||
|
for needle in ["docker", "containerd", "kubepods", "podman", "libpod"] {
|
||||||
|
if cgroup.contains(needle) {
|
||||||
|
markers.push(format!("/proc/1/cgroup:{needle}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
markers.sort();
|
||||||
|
markers.dedup();
|
||||||
|
ContainerEnvironment {
|
||||||
|
in_container: !markers.is_empty(),
|
||||||
|
markers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn resolve_sandbox_status(config: &SandboxConfig, cwd: &Path) -> SandboxStatus {
|
||||||
|
let request = config.resolve_request(None, None, None, None, None);
|
||||||
|
resolve_sandbox_status_for_request(&request, cwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn resolve_sandbox_status_for_request(request: &SandboxRequest, cwd: &Path) -> SandboxStatus {
|
||||||
|
let container = detect_container_environment();
|
||||||
|
let namespace_supported = cfg!(target_os = "linux") && command_exists("unshare");
|
||||||
|
let network_supported = namespace_supported;
|
||||||
|
let filesystem_active =
|
||||||
|
request.enabled && request.filesystem_mode != FilesystemIsolationMode::Off;
|
||||||
|
let mut fallback_reasons = Vec::new();
|
||||||
|
|
||||||
|
if request.enabled && request.namespace_restrictions && !namespace_supported {
|
||||||
|
fallback_reasons
|
||||||
|
.push("namespace isolation unavailable (requires Linux with `unshare`)".to_string());
|
||||||
|
}
|
||||||
|
if request.enabled && request.network_isolation && !network_supported {
|
||||||
|
fallback_reasons
|
||||||
|
.push("network isolation unavailable (requires Linux with `unshare`)".to_string());
|
||||||
|
}
|
||||||
|
if request.enabled
|
||||||
|
&& request.filesystem_mode == FilesystemIsolationMode::AllowList
|
||||||
|
&& request.allowed_mounts.is_empty()
|
||||||
|
{
|
||||||
|
fallback_reasons
|
||||||
|
.push("filesystem allow-list requested without configured mounts".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let active = request.enabled
|
||||||
|
&& (!request.namespace_restrictions || namespace_supported)
|
||||||
|
&& (!request.network_isolation || network_supported);
|
||||||
|
|
||||||
|
let allowed_mounts = normalize_mounts(&request.allowed_mounts, cwd);
|
||||||
|
|
||||||
|
SandboxStatus {
|
||||||
|
enabled: request.enabled,
|
||||||
|
requested: request.clone(),
|
||||||
|
supported: namespace_supported,
|
||||||
|
active,
|
||||||
|
namespace_supported,
|
||||||
|
namespace_active: request.enabled && request.namespace_restrictions && namespace_supported,
|
||||||
|
network_supported,
|
||||||
|
network_active: request.enabled && request.network_isolation && network_supported,
|
||||||
|
filesystem_mode: request.filesystem_mode,
|
||||||
|
filesystem_active,
|
||||||
|
allowed_mounts,
|
||||||
|
in_container: container.in_container,
|
||||||
|
container_markers: container.markers,
|
||||||
|
fallback_reason: (!fallback_reasons.is_empty()).then(|| fallback_reasons.join("; ")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn build_linux_sandbox_command(
|
||||||
|
command: &str,
|
||||||
|
cwd: &Path,
|
||||||
|
status: &SandboxStatus,
|
||||||
|
) -> Option<LinuxSandboxCommand> {
|
||||||
|
if !cfg!(target_os = "linux")
|
||||||
|
|| !status.enabled
|
||||||
|
|| (!status.namespace_active && !status.network_active)
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut args = vec![
|
||||||
|
"--user".to_string(),
|
||||||
|
"--map-root-user".to_string(),
|
||||||
|
"--mount".to_string(),
|
||||||
|
"--ipc".to_string(),
|
||||||
|
"--pid".to_string(),
|
||||||
|
"--uts".to_string(),
|
||||||
|
"--fork".to_string(),
|
||||||
|
];
|
||||||
|
if status.network_active {
|
||||||
|
args.push("--net".to_string());
|
||||||
|
}
|
||||||
|
args.push("sh".to_string());
|
||||||
|
args.push("-lc".to_string());
|
||||||
|
args.push(command.to_string());
|
||||||
|
|
||||||
|
let sandbox_home = cwd.join(".sandbox-home");
|
||||||
|
let sandbox_tmp = cwd.join(".sandbox-tmp");
|
||||||
|
let mut env = vec![
|
||||||
|
("HOME".to_string(), sandbox_home.display().to_string()),
|
||||||
|
("TMPDIR".to_string(), sandbox_tmp.display().to_string()),
|
||||||
|
(
|
||||||
|
"CLAWD_SANDBOX_FILESYSTEM_MODE".to_string(),
|
||||||
|
status.filesystem_mode.as_str().to_string(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"CLAWD_SANDBOX_ALLOWED_MOUNTS".to_string(),
|
||||||
|
status.allowed_mounts.join(":"),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
if let Ok(path) = env::var("PATH") {
|
||||||
|
env.push(("PATH".to_string(), path));
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(LinuxSandboxCommand {
|
||||||
|
program: "unshare".to_string(),
|
||||||
|
args,
|
||||||
|
env,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_mounts(mounts: &[String], cwd: &Path) -> Vec<String> {
|
||||||
|
let cwd = cwd.to_path_buf();
|
||||||
|
mounts
|
||||||
|
.iter()
|
||||||
|
.map(|mount| {
|
||||||
|
let path = PathBuf::from(mount);
|
||||||
|
if path.is_absolute() {
|
||||||
|
path
|
||||||
|
} else {
|
||||||
|
cwd.join(path)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map(|path| path.display().to_string())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn command_exists(command: &str) -> bool {
|
||||||
|
env::var_os("PATH")
|
||||||
|
.is_some_and(|paths| env::split_paths(&paths).any(|path| path.join(command).exists()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
build_linux_sandbox_command, detect_container_environment_from, FilesystemIsolationMode,
|
||||||
|
SandboxConfig, SandboxDetectionInputs,
|
||||||
|
};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detects_container_markers_from_multiple_sources() {
|
||||||
|
let detected = detect_container_environment_from(SandboxDetectionInputs {
|
||||||
|
env_pairs: vec![("container".to_string(), "docker".to_string())],
|
||||||
|
dockerenv_exists: true,
|
||||||
|
containerenv_exists: false,
|
||||||
|
proc_1_cgroup: Some("12:memory:/docker/abc"),
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(detected.in_container);
|
||||||
|
assert!(detected
|
||||||
|
.markers
|
||||||
|
.iter()
|
||||||
|
.any(|marker| marker == "/.dockerenv"));
|
||||||
|
assert!(detected
|
||||||
|
.markers
|
||||||
|
.iter()
|
||||||
|
.any(|marker| marker == "env:container=docker"));
|
||||||
|
assert!(detected
|
||||||
|
.markers
|
||||||
|
.iter()
|
||||||
|
.any(|marker| marker == "/proc/1/cgroup:docker"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolves_request_with_overrides() {
|
||||||
|
let config = SandboxConfig {
|
||||||
|
enabled: Some(true),
|
||||||
|
namespace_restrictions: Some(true),
|
||||||
|
network_isolation: Some(false),
|
||||||
|
filesystem_mode: Some(FilesystemIsolationMode::WorkspaceOnly),
|
||||||
|
allowed_mounts: vec!["logs".to_string()],
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = config.resolve_request(
|
||||||
|
Some(true),
|
||||||
|
Some(false),
|
||||||
|
Some(true),
|
||||||
|
Some(FilesystemIsolationMode::AllowList),
|
||||||
|
Some(vec!["tmp".to_string()]),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(request.enabled);
|
||||||
|
assert!(!request.namespace_restrictions);
|
||||||
|
assert!(request.network_isolation);
|
||||||
|
assert_eq!(request.filesystem_mode, FilesystemIsolationMode::AllowList);
|
||||||
|
assert_eq!(request.allowed_mounts, vec!["tmp"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn builds_linux_launcher_with_network_flag_when_requested() {
|
||||||
|
let config = SandboxConfig::default();
|
||||||
|
let status = super::resolve_sandbox_status_for_request(
|
||||||
|
&config.resolve_request(
|
||||||
|
Some(true),
|
||||||
|
Some(true),
|
||||||
|
Some(true),
|
||||||
|
Some(FilesystemIsolationMode::WorkspaceOnly),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
Path::new("/workspace"),
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(launcher) =
|
||||||
|
build_linux_sandbox_command("printf hi", Path::new("/workspace"), &status)
|
||||||
|
{
|
||||||
|
assert_eq!(launcher.program, "unshare");
|
||||||
|
assert!(launcher.args.iter().any(|arg| arg == "--mount"));
|
||||||
|
assert!(launcher.args.iter().any(|arg| arg == "--net") == status.network_active);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,12 +5,17 @@ edition.workspace = true
|
|||||||
license.workspace = true
|
license.workspace = true
|
||||||
publish.workspace = true
|
publish.workspace = true
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "claw"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
api = { path = "../api" }
|
api = { path = "../api" }
|
||||||
commands = { path = "../commands" }
|
commands = { path = "../commands" }
|
||||||
compat-harness = { path = "../compat-harness" }
|
compat-harness = { path = "../compat-harness" }
|
||||||
crossterm = "0.28"
|
crossterm = "0.28"
|
||||||
pulldown-cmark = "0.13"
|
pulldown-cmark = "0.13"
|
||||||
|
rustyline = "15"
|
||||||
runtime = { path = "../runtime" }
|
runtime = { path = "../runtime" }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
syntect = "5"
|
syntect = "5"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use clap::{Parser, Subcommand, ValueEnum};
|
|||||||
about = "Rust Claude CLI prototype"
|
about = "Rust Claude CLI prototype"
|
||||||
)]
|
)]
|
||||||
pub struct Cli {
|
pub struct Cli {
|
||||||
#[arg(long, default_value = "claude-3-7-sonnet")]
|
#[arg(long, default_value = "claude-opus-4-6")]
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
|
||||||
#[arg(long, value_enum, default_value_t = PermissionMode::WorkspaceWrite)]
|
#[arg(long, value_enum, default_value_t = PermissionMode::WorkspaceWrite)]
|
||||||
@@ -31,6 +31,10 @@ pub enum Command {
|
|||||||
DumpManifests,
|
DumpManifests,
|
||||||
/// Print the current bootstrap phase skeleton
|
/// Print the current bootstrap phase skeleton
|
||||||
BootstrapPlan,
|
BootstrapPlan,
|
||||||
|
/// Start the OAuth login flow
|
||||||
|
Login,
|
||||||
|
/// Clear saved OAuth credentials
|
||||||
|
Logout,
|
||||||
/// Run a non-interactive prompt and exit
|
/// Run a non-interactive prompt and exit
|
||||||
Prompt { prompt: Vec<String> },
|
Prompt { prompt: Vec<String> },
|
||||||
}
|
}
|
||||||
@@ -86,4 +90,13 @@ mod tests {
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_login_and_logout_commands() {
|
||||||
|
let login = Cli::parse_from(["rusty-claude-cli", "login"]);
|
||||||
|
assert_eq!(login.command, Some(Command::Login));
|
||||||
|
|
||||||
|
let logout = Cli::parse_from(["rusty-claude-cli", "logout"]);
|
||||||
|
assert_eq!(logout.command, Some(Command::Logout));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
433
rust/crates/rusty-claude-cli/src/init.rs
Normal file
433
rust/crates/rusty-claude-cli/src/init.rs
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
const STARTER_CLAUDE_JSON: &str = concat!(
|
||||||
|
"{\n",
|
||||||
|
" \"permissions\": {\n",
|
||||||
|
" \"defaultMode\": \"acceptEdits\"\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
);
|
||||||
|
const GITIGNORE_COMMENT: &str = "# Claude Code local artifacts";
|
||||||
|
const GITIGNORE_ENTRIES: [&str; 2] = [".claude/settings.local.json", ".claude/sessions/"];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub(crate) enum InitStatus {
|
||||||
|
Created,
|
||||||
|
Updated,
|
||||||
|
Skipped,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InitStatus {
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) fn label(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Created => "created",
|
||||||
|
Self::Updated => "updated",
|
||||||
|
Self::Skipped => "skipped (already exists)",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) struct InitArtifact {
|
||||||
|
pub(crate) name: &'static str,
|
||||||
|
pub(crate) status: InitStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub(crate) struct InitReport {
|
||||||
|
pub(crate) project_root: PathBuf,
|
||||||
|
pub(crate) artifacts: Vec<InitArtifact>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InitReport {
|
||||||
|
#[must_use]
|
||||||
|
pub(crate) fn render(&self) -> String {
|
||||||
|
let mut lines = vec![
|
||||||
|
"Init".to_string(),
|
||||||
|
format!(" Project {}", self.project_root.display()),
|
||||||
|
];
|
||||||
|
for artifact in &self.artifacts {
|
||||||
|
lines.push(format!(
|
||||||
|
" {:<16} {}",
|
||||||
|
artifact.name,
|
||||||
|
artifact.status.label()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
lines.push(" Next step Review and tailor the generated guidance".to_string());
|
||||||
|
lines.join("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||||
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
|
struct RepoDetection {
|
||||||
|
rust_workspace: bool,
|
||||||
|
rust_root: bool,
|
||||||
|
python: bool,
|
||||||
|
package_json: bool,
|
||||||
|
typescript: bool,
|
||||||
|
nextjs: bool,
|
||||||
|
react: bool,
|
||||||
|
vite: bool,
|
||||||
|
nest: bool,
|
||||||
|
src_dir: bool,
|
||||||
|
tests_dir: bool,
|
||||||
|
rust_dir: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn initialize_repo(cwd: &Path) -> Result<InitReport, Box<dyn std::error::Error>> {
|
||||||
|
let mut artifacts = Vec::new();
|
||||||
|
|
||||||
|
let claude_dir = cwd.join(".claude");
|
||||||
|
artifacts.push(InitArtifact {
|
||||||
|
name: ".claude/",
|
||||||
|
status: ensure_dir(&claude_dir)?,
|
||||||
|
});
|
||||||
|
|
||||||
|
let claude_json = cwd.join(".claude.json");
|
||||||
|
artifacts.push(InitArtifact {
|
||||||
|
name: ".claude.json",
|
||||||
|
status: write_file_if_missing(&claude_json, STARTER_CLAUDE_JSON)?,
|
||||||
|
});
|
||||||
|
|
||||||
|
let gitignore = cwd.join(".gitignore");
|
||||||
|
artifacts.push(InitArtifact {
|
||||||
|
name: ".gitignore",
|
||||||
|
status: ensure_gitignore_entries(&gitignore)?,
|
||||||
|
});
|
||||||
|
|
||||||
|
let claude_md = cwd.join("CLAUDE.md");
|
||||||
|
let content = render_init_claude_md(cwd);
|
||||||
|
artifacts.push(InitArtifact {
|
||||||
|
name: "CLAUDE.md",
|
||||||
|
status: write_file_if_missing(&claude_md, &content)?,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(InitReport {
|
||||||
|
project_root: cwd.to_path_buf(),
|
||||||
|
artifacts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_dir(path: &Path) -> Result<InitStatus, std::io::Error> {
|
||||||
|
if path.is_dir() {
|
||||||
|
return Ok(InitStatus::Skipped);
|
||||||
|
}
|
||||||
|
fs::create_dir_all(path)?;
|
||||||
|
Ok(InitStatus::Created)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_file_if_missing(path: &Path, content: &str) -> Result<InitStatus, std::io::Error> {
|
||||||
|
if path.exists() {
|
||||||
|
return Ok(InitStatus::Skipped);
|
||||||
|
}
|
||||||
|
fs::write(path, content)?;
|
||||||
|
Ok(InitStatus::Created)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_gitignore_entries(path: &Path) -> Result<InitStatus, std::io::Error> {
|
||||||
|
if !path.exists() {
|
||||||
|
let mut lines = vec![GITIGNORE_COMMENT.to_string()];
|
||||||
|
lines.extend(GITIGNORE_ENTRIES.iter().map(|entry| (*entry).to_string()));
|
||||||
|
fs::write(path, format!("{}\n", lines.join("\n")))?;
|
||||||
|
return Ok(InitStatus::Created);
|
||||||
|
}
|
||||||
|
|
||||||
|
let existing = fs::read_to_string(path)?;
|
||||||
|
let mut lines = existing.lines().map(ToOwned::to_owned).collect::<Vec<_>>();
|
||||||
|
let mut changed = false;
|
||||||
|
|
||||||
|
if !lines.iter().any(|line| line == GITIGNORE_COMMENT) {
|
||||||
|
lines.push(GITIGNORE_COMMENT.to_string());
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for entry in GITIGNORE_ENTRIES {
|
||||||
|
if !lines.iter().any(|line| line == entry) {
|
||||||
|
lines.push(entry.to_string());
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !changed {
|
||||||
|
return Ok(InitStatus::Skipped);
|
||||||
|
}
|
||||||
|
|
||||||
|
fs::write(path, format!("{}\n", lines.join("\n")))?;
|
||||||
|
Ok(InitStatus::Updated)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn render_init_claude_md(cwd: &Path) -> String {
|
||||||
|
let detection = detect_repo(cwd);
|
||||||
|
let mut lines = vec![
|
||||||
|
"# CLAUDE.md".to_string(),
|
||||||
|
String::new(),
|
||||||
|
"This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.".to_string(),
|
||||||
|
String::new(),
|
||||||
|
];
|
||||||
|
|
||||||
|
let detected_languages = detected_languages(&detection);
|
||||||
|
let detected_frameworks = detected_frameworks(&detection);
|
||||||
|
lines.push("## Detected stack".to_string());
|
||||||
|
if detected_languages.is_empty() {
|
||||||
|
lines.push("- No specific language markers were detected yet; document the primary language and verification commands once the project structure settles.".to_string());
|
||||||
|
} else {
|
||||||
|
lines.push(format!("- Languages: {}.", detected_languages.join(", ")));
|
||||||
|
}
|
||||||
|
if detected_frameworks.is_empty() {
|
||||||
|
lines.push("- Frameworks: none detected from the supported starter markers.".to_string());
|
||||||
|
} else {
|
||||||
|
lines.push(format!(
|
||||||
|
"- Frameworks/tooling markers: {}.",
|
||||||
|
detected_frameworks.join(", ")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
let verification_lines = verification_lines(cwd, &detection);
|
||||||
|
if !verification_lines.is_empty() {
|
||||||
|
lines.push("## Verification".to_string());
|
||||||
|
lines.extend(verification_lines);
|
||||||
|
lines.push(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let structure_lines = repository_shape_lines(&detection);
|
||||||
|
if !structure_lines.is_empty() {
|
||||||
|
lines.push("## Repository shape".to_string());
|
||||||
|
lines.extend(structure_lines);
|
||||||
|
lines.push(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let framework_lines = framework_notes(&detection);
|
||||||
|
if !framework_lines.is_empty() {
|
||||||
|
lines.push("## Framework notes".to_string());
|
||||||
|
lines.extend(framework_lines);
|
||||||
|
lines.push(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.push("## Working agreement".to_string());
|
||||||
|
lines.push("- Prefer small, reviewable changes and keep generated bootstrap files aligned with actual repo workflows.".to_string());
|
||||||
|
lines.push("- Keep shared defaults in `.claude.json`; reserve `.claude/settings.local.json` for machine-local overrides.".to_string());
|
||||||
|
lines.push("- Do not overwrite existing `CLAUDE.md` content automatically; update it intentionally when repo workflows change.".to_string());
|
||||||
|
lines.push(String::new());
|
||||||
|
|
||||||
|
lines.join("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detect_repo(cwd: &Path) -> RepoDetection {
|
||||||
|
let package_json_contents = fs::read_to_string(cwd.join("package.json"))
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_ascii_lowercase();
|
||||||
|
RepoDetection {
|
||||||
|
rust_workspace: cwd.join("rust").join("Cargo.toml").is_file(),
|
||||||
|
rust_root: cwd.join("Cargo.toml").is_file(),
|
||||||
|
python: cwd.join("pyproject.toml").is_file()
|
||||||
|
|| cwd.join("requirements.txt").is_file()
|
||||||
|
|| cwd.join("setup.py").is_file(),
|
||||||
|
package_json: cwd.join("package.json").is_file(),
|
||||||
|
typescript: cwd.join("tsconfig.json").is_file()
|
||||||
|
|| package_json_contents.contains("typescript"),
|
||||||
|
nextjs: package_json_contents.contains("\"next\""),
|
||||||
|
react: package_json_contents.contains("\"react\""),
|
||||||
|
vite: package_json_contents.contains("\"vite\""),
|
||||||
|
nest: package_json_contents.contains("@nestjs"),
|
||||||
|
src_dir: cwd.join("src").is_dir(),
|
||||||
|
tests_dir: cwd.join("tests").is_dir(),
|
||||||
|
rust_dir: cwd.join("rust").is_dir(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detected_languages(detection: &RepoDetection) -> Vec<&'static str> {
|
||||||
|
let mut languages = Vec::new();
|
||||||
|
if detection.rust_workspace || detection.rust_root {
|
||||||
|
languages.push("Rust");
|
||||||
|
}
|
||||||
|
if detection.python {
|
||||||
|
languages.push("Python");
|
||||||
|
}
|
||||||
|
if detection.typescript {
|
||||||
|
languages.push("TypeScript");
|
||||||
|
} else if detection.package_json {
|
||||||
|
languages.push("JavaScript/Node.js");
|
||||||
|
}
|
||||||
|
languages
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detected_frameworks(detection: &RepoDetection) -> Vec<&'static str> {
|
||||||
|
let mut frameworks = Vec::new();
|
||||||
|
if detection.nextjs {
|
||||||
|
frameworks.push("Next.js");
|
||||||
|
}
|
||||||
|
if detection.react {
|
||||||
|
frameworks.push("React");
|
||||||
|
}
|
||||||
|
if detection.vite {
|
||||||
|
frameworks.push("Vite");
|
||||||
|
}
|
||||||
|
if detection.nest {
|
||||||
|
frameworks.push("NestJS");
|
||||||
|
}
|
||||||
|
frameworks
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verification_lines(cwd: &Path, detection: &RepoDetection) -> Vec<String> {
|
||||||
|
let mut lines = Vec::new();
|
||||||
|
if detection.rust_workspace {
|
||||||
|
lines.push("- Run Rust verification from `rust/`: `cargo fmt`, `cargo clippy --workspace --all-targets -- -D warnings`, `cargo test --workspace`".to_string());
|
||||||
|
} else if detection.rust_root {
|
||||||
|
lines.push("- Run Rust verification from the repo root: `cargo fmt`, `cargo clippy --workspace --all-targets -- -D warnings`, `cargo test --workspace`".to_string());
|
||||||
|
}
|
||||||
|
if detection.python {
|
||||||
|
if cwd.join("pyproject.toml").is_file() {
|
||||||
|
lines.push("- Run the Python project checks declared in `pyproject.toml` (for example: `pytest`, `ruff check`, and `mypy` when configured).".to_string());
|
||||||
|
} else {
|
||||||
|
lines.push(
|
||||||
|
"- Run the repo's Python test/lint commands before shipping changes.".to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if detection.package_json {
|
||||||
|
lines.push("- Run the JavaScript/TypeScript checks from `package.json` before shipping changes (`npm test`, `npm run lint`, `npm run build`, or the repo equivalent).".to_string());
|
||||||
|
}
|
||||||
|
if detection.tests_dir && detection.src_dir {
|
||||||
|
lines.push("- `src/` and `tests/` are both present; update both surfaces together when behavior changes.".to_string());
|
||||||
|
}
|
||||||
|
lines
|
||||||
|
}
|
||||||
|
|
||||||
|
fn repository_shape_lines(detection: &RepoDetection) -> Vec<String> {
|
||||||
|
let mut lines = Vec::new();
|
||||||
|
if detection.rust_dir {
|
||||||
|
lines.push(
|
||||||
|
"- `rust/` contains the Rust workspace and active CLI/runtime implementation."
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if detection.src_dir {
|
||||||
|
lines.push("- `src/` contains source files that should stay consistent with generated guidance and tests.".to_string());
|
||||||
|
}
|
||||||
|
if detection.tests_dir {
|
||||||
|
lines.push("- `tests/` contains validation surfaces that should be reviewed alongside code changes.".to_string());
|
||||||
|
}
|
||||||
|
lines
|
||||||
|
}
|
||||||
|
|
||||||
|
fn framework_notes(detection: &RepoDetection) -> Vec<String> {
|
||||||
|
let mut lines = Vec::new();
|
||||||
|
if detection.nextjs {
|
||||||
|
lines.push("- Next.js detected: preserve routing/data-fetching conventions and verify production builds after changing app structure.".to_string());
|
||||||
|
}
|
||||||
|
if detection.react && !detection.nextjs {
|
||||||
|
lines.push("- React detected: keep component behavior covered with focused tests and avoid unnecessary prop/API churn.".to_string());
|
||||||
|
}
|
||||||
|
if detection.vite {
|
||||||
|
lines.push("- Vite detected: validate the production bundle after changing build-sensitive configuration or imports.".to_string());
|
||||||
|
}
|
||||||
|
if detection.nest {
|
||||||
|
lines.push("- NestJS detected: keep module/provider boundaries explicit and verify controller/service wiring after refactors.".to_string());
|
||||||
|
}
|
||||||
|
lines
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{initialize_repo, render_init_claude_md};
|
||||||
|
use std::fs;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
fn temp_dir() -> std::path::PathBuf {
|
||||||
|
let nanos = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("time should be after epoch")
|
||||||
|
.as_nanos();
|
||||||
|
std::env::temp_dir().join(format!("rusty-claude-init-{nanos}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn initialize_repo_creates_expected_files_and_gitignore_entries() {
|
||||||
|
let root = temp_dir();
|
||||||
|
fs::create_dir_all(root.join("rust")).expect("create rust dir");
|
||||||
|
fs::write(root.join("rust").join("Cargo.toml"), "[workspace]\n").expect("write cargo");
|
||||||
|
|
||||||
|
let report = initialize_repo(&root).expect("init should succeed");
|
||||||
|
let rendered = report.render();
|
||||||
|
assert!(rendered.contains(".claude/ created"));
|
||||||
|
assert!(rendered.contains(".claude.json created"));
|
||||||
|
assert!(rendered.contains(".gitignore created"));
|
||||||
|
assert!(rendered.contains("CLAUDE.md created"));
|
||||||
|
assert!(root.join(".claude").is_dir());
|
||||||
|
assert!(root.join(".claude.json").is_file());
|
||||||
|
assert!(root.join("CLAUDE.md").is_file());
|
||||||
|
assert_eq!(
|
||||||
|
fs::read_to_string(root.join(".claude.json")).expect("read claude json"),
|
||||||
|
concat!(
|
||||||
|
"{\n",
|
||||||
|
" \"permissions\": {\n",
|
||||||
|
" \"defaultMode\": \"acceptEdits\"\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
)
|
||||||
|
);
|
||||||
|
let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore");
|
||||||
|
assert!(gitignore.contains(".claude/settings.local.json"));
|
||||||
|
assert!(gitignore.contains(".claude/sessions/"));
|
||||||
|
let claude_md = fs::read_to_string(root.join("CLAUDE.md")).expect("read claude md");
|
||||||
|
assert!(claude_md.contains("Languages: Rust."));
|
||||||
|
assert!(claude_md.contains("cargo clippy --workspace --all-targets -- -D warnings"));
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn initialize_repo_is_idempotent_and_preserves_existing_files() {
|
||||||
|
let root = temp_dir();
|
||||||
|
fs::create_dir_all(&root).expect("create root");
|
||||||
|
fs::write(root.join("CLAUDE.md"), "custom guidance\n").expect("write existing claude md");
|
||||||
|
fs::write(root.join(".gitignore"), ".claude/settings.local.json\n")
|
||||||
|
.expect("write gitignore");
|
||||||
|
|
||||||
|
let first = initialize_repo(&root).expect("first init should succeed");
|
||||||
|
assert!(first
|
||||||
|
.render()
|
||||||
|
.contains("CLAUDE.md skipped (already exists)"));
|
||||||
|
let second = initialize_repo(&root).expect("second init should succeed");
|
||||||
|
let second_rendered = second.render();
|
||||||
|
assert!(second_rendered.contains(".claude/ skipped (already exists)"));
|
||||||
|
assert!(second_rendered.contains(".claude.json skipped (already exists)"));
|
||||||
|
assert!(second_rendered.contains(".gitignore skipped (already exists)"));
|
||||||
|
assert!(second_rendered.contains("CLAUDE.md skipped (already exists)"));
|
||||||
|
assert_eq!(
|
||||||
|
fs::read_to_string(root.join("CLAUDE.md")).expect("read existing claude md"),
|
||||||
|
"custom guidance\n"
|
||||||
|
);
|
||||||
|
let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore");
|
||||||
|
assert_eq!(gitignore.matches(".claude/settings.local.json").count(), 1);
|
||||||
|
assert_eq!(gitignore.matches(".claude/sessions/").count(), 1);
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn render_init_template_mentions_detected_python_and_nextjs_markers() {
|
||||||
|
let root = temp_dir();
|
||||||
|
fs::create_dir_all(&root).expect("create root");
|
||||||
|
fs::write(root.join("pyproject.toml"), "[project]\nname = \"demo\"\n")
|
||||||
|
.expect("write pyproject");
|
||||||
|
fs::write(
|
||||||
|
root.join("package.json"),
|
||||||
|
r#"{"dependencies":{"next":"14.0.0","react":"18.0.0"},"devDependencies":{"typescript":"5.0.0"}}"#,
|
||||||
|
)
|
||||||
|
.expect("write package json");
|
||||||
|
|
||||||
|
let rendered = render_init_claude_md(Path::new(&root));
|
||||||
|
assert!(rendered.contains("Languages: Python, TypeScript."));
|
||||||
|
assert!(rendered.contains("Frameworks/tooling markers: Next.js, React."));
|
||||||
|
assert!(rendered.contains("pyproject.toml"));
|
||||||
|
assert!(rendered.contains("Next.js detected"));
|
||||||
|
|
||||||
|
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,166 +1,16 @@
|
|||||||
|
use std::borrow::Cow;
|
||||||
|
use std::cell::RefCell;
|
||||||
use std::io::{self, IsTerminal, Write};
|
use std::io::{self, IsTerminal, Write};
|
||||||
|
|
||||||
use crossterm::cursor::{MoveDown, MoveToColumn, MoveUp};
|
use rustyline::completion::{Completer, Pair};
|
||||||
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
|
use rustyline::error::ReadlineError;
|
||||||
use crossterm::queue;
|
use rustyline::highlight::{CmdKind, Highlighter};
|
||||||
use crossterm::terminal::{disable_raw_mode, enable_raw_mode, Clear, ClearType};
|
use rustyline::hint::Hinter;
|
||||||
|
use rustyline::history::DefaultHistory;
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
use rustyline::validate::Validator;
|
||||||
pub struct InputBuffer {
|
use rustyline::{
|
||||||
buffer: String,
|
Cmd, CompletionType, Config, Context, EditMode, Editor, Helper, KeyCode, KeyEvent, Modifiers,
|
||||||
cursor: usize,
|
};
|
||||||
}
|
|
||||||
|
|
||||||
impl InputBuffer {
|
|
||||||
#[must_use]
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
buffer: String::new(),
|
|
||||||
cursor: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn insert(&mut self, ch: char) {
|
|
||||||
self.buffer.insert(self.cursor, ch);
|
|
||||||
self.cursor += ch.len_utf8();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn insert_newline(&mut self) {
|
|
||||||
self.insert('\n');
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn backspace(&mut self) {
|
|
||||||
if self.cursor == 0 {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let previous = self.buffer[..self.cursor]
|
|
||||||
.char_indices()
|
|
||||||
.last()
|
|
||||||
.map_or(0, |(idx, _)| idx);
|
|
||||||
self.buffer.drain(previous..self.cursor);
|
|
||||||
self.cursor = previous;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn move_left(&mut self) {
|
|
||||||
if self.cursor == 0 {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
self.cursor = self.buffer[..self.cursor]
|
|
||||||
.char_indices()
|
|
||||||
.last()
|
|
||||||
.map_or(0, |(idx, _)| idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn move_right(&mut self) {
|
|
||||||
if self.cursor >= self.buffer.len() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if let Some(next) = self.buffer[self.cursor..].chars().next() {
|
|
||||||
self.cursor += next.len_utf8();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn move_home(&mut self) {
|
|
||||||
self.cursor = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn move_end(&mut self) {
|
|
||||||
self.cursor = self.buffer.len();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn as_str(&self) -> &str {
|
|
||||||
&self.buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
#[must_use]
|
|
||||||
pub fn cursor(&self) -> usize {
|
|
||||||
self.cursor
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn clear(&mut self) {
|
|
||||||
self.buffer.clear();
|
|
||||||
self.cursor = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn replace(&mut self, value: impl Into<String>) {
|
|
||||||
self.buffer = value.into();
|
|
||||||
self.cursor = self.buffer.len();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
fn current_command_prefix(&self) -> Option<&str> {
|
|
||||||
if self.cursor != self.buffer.len() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
let prefix = &self.buffer[..self.cursor];
|
|
||||||
if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
Some(prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn complete_slash_command(&mut self, candidates: &[String]) -> bool {
|
|
||||||
let Some(prefix) = self.current_command_prefix() else {
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
let matches = candidates
|
|
||||||
.iter()
|
|
||||||
.filter(|candidate| candidate.starts_with(prefix))
|
|
||||||
.map(String::as_str)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
if matches.is_empty() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let replacement = longest_common_prefix(&matches);
|
|
||||||
if replacement == prefix {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
self.replace(replacement);
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub struct RenderedBuffer {
|
|
||||||
lines: Vec<String>,
|
|
||||||
cursor_row: u16,
|
|
||||||
cursor_col: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RenderedBuffer {
|
|
||||||
#[must_use]
|
|
||||||
pub fn line_count(&self) -> usize {
|
|
||||||
self.lines.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn write(&self, out: &mut impl Write) -> io::Result<()> {
|
|
||||||
for (index, line) in self.lines.iter().enumerate() {
|
|
||||||
if index > 0 {
|
|
||||||
writeln!(out)?;
|
|
||||||
}
|
|
||||||
write!(out, "{line}")?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
#[must_use]
|
|
||||||
pub fn lines(&self) -> &[String] {
|
|
||||||
&self.lines
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
#[must_use]
|
|
||||||
pub fn cursor_position(&self) -> (u16, u16) {
|
|
||||||
(self.cursor_row, self.cursor_col)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum ReadOutcome {
|
pub enum ReadOutcome {
|
||||||
@@ -169,25 +19,101 @@ pub enum ReadOutcome {
|
|||||||
Exit,
|
Exit,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct SlashCommandHelper {
|
||||||
|
completions: Vec<String>,
|
||||||
|
current_line: RefCell<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SlashCommandHelper {
|
||||||
|
fn new(completions: Vec<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
completions,
|
||||||
|
current_line: RefCell::new(String::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_current_line(&self) {
|
||||||
|
self.current_line.borrow_mut().clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_line(&self) -> String {
|
||||||
|
self.current_line.borrow().clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_current_line(&self, line: &str) {
|
||||||
|
let mut current = self.current_line.borrow_mut();
|
||||||
|
current.clear();
|
||||||
|
current.push_str(line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Completer for SlashCommandHelper {
|
||||||
|
type Candidate = Pair;
|
||||||
|
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
line: &str,
|
||||||
|
pos: usize,
|
||||||
|
_ctx: &Context<'_>,
|
||||||
|
) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
|
||||||
|
let Some(prefix) = slash_command_prefix(line, pos) else {
|
||||||
|
return Ok((0, Vec::new()));
|
||||||
|
};
|
||||||
|
|
||||||
|
let matches = self
|
||||||
|
.completions
|
||||||
|
.iter()
|
||||||
|
.filter(|candidate| candidate.starts_with(prefix))
|
||||||
|
.map(|candidate| Pair {
|
||||||
|
display: candidate.clone(),
|
||||||
|
replacement: candidate.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok((0, matches))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Hinter for SlashCommandHelper {
|
||||||
|
type Hint = String;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Highlighter for SlashCommandHelper {
|
||||||
|
fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
|
||||||
|
self.set_current_line(line);
|
||||||
|
Cow::Borrowed(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn highlight_char(&self, line: &str, _pos: usize, _kind: CmdKind) -> bool {
|
||||||
|
self.set_current_line(line);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Validator for SlashCommandHelper {}
|
||||||
|
impl Helper for SlashCommandHelper {}
|
||||||
|
|
||||||
pub struct LineEditor {
|
pub struct LineEditor {
|
||||||
prompt: String,
|
prompt: String,
|
||||||
continuation_prompt: String,
|
editor: Editor<SlashCommandHelper, DefaultHistory>,
|
||||||
history: Vec<String>,
|
|
||||||
history_index: Option<usize>,
|
|
||||||
draft: Option<String>,
|
|
||||||
completions: Vec<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LineEditor {
|
impl LineEditor {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(prompt: impl Into<String>, completions: Vec<String>) -> Self {
|
pub fn new(prompt: impl Into<String>, completions: Vec<String>) -> Self {
|
||||||
|
let config = Config::builder()
|
||||||
|
.completion_type(CompletionType::List)
|
||||||
|
.edit_mode(EditMode::Emacs)
|
||||||
|
.build();
|
||||||
|
let mut editor = Editor::<SlashCommandHelper, DefaultHistory>::with_config(config)
|
||||||
|
.expect("rustyline editor should initialize");
|
||||||
|
editor.set_helper(Some(SlashCommandHelper::new(completions)));
|
||||||
|
editor.bind_sequence(KeyEvent(KeyCode::Char('J'), Modifiers::CTRL), Cmd::Newline);
|
||||||
|
editor.bind_sequence(KeyEvent(KeyCode::Enter, Modifiers::SHIFT), Cmd::Newline);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
prompt: prompt.into(),
|
prompt: prompt.into(),
|
||||||
continuation_prompt: String::from("> "),
|
editor,
|
||||||
history: Vec::new(),
|
|
||||||
history_index: None,
|
|
||||||
draft: None,
|
|
||||||
completions,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,9 +122,8 @@ impl LineEditor {
|
|||||||
if entry.trim().is_empty() {
|
if entry.trim().is_empty() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
self.history.push(entry);
|
|
||||||
self.history_index = None;
|
let _ = self.editor.add_history_entry(entry);
|
||||||
self.draft = None;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_line(&mut self) -> io::Result<ReadOutcome> {
|
pub fn read_line(&mut self) -> io::Result<ReadOutcome> {
|
||||||
@@ -206,45 +131,43 @@ impl LineEditor {
|
|||||||
return self.read_line_fallback();
|
return self.read_line_fallback();
|
||||||
}
|
}
|
||||||
|
|
||||||
enable_raw_mode()?;
|
if let Some(helper) = self.editor.helper_mut() {
|
||||||
let mut stdout = io::stdout();
|
helper.reset_current_line();
|
||||||
let mut input = InputBuffer::new();
|
}
|
||||||
let mut rendered_lines = 1usize;
|
|
||||||
self.redraw(&mut stdout, &input, rendered_lines)?;
|
|
||||||
|
|
||||||
loop {
|
match self.editor.readline(&self.prompt) {
|
||||||
let event = event::read()?;
|
Ok(line) => Ok(ReadOutcome::Submit(line)),
|
||||||
if let Event::Key(key) = event {
|
Err(ReadlineError::Interrupted) => {
|
||||||
match self.handle_key(key, &mut input) {
|
let has_input = !self.current_line().is_empty();
|
||||||
EditorAction::Continue => {
|
self.finish_interrupted_read()?;
|
||||||
rendered_lines = self.redraw(&mut stdout, &input, rendered_lines)?;
|
if has_input {
|
||||||
}
|
Ok(ReadOutcome::Cancel)
|
||||||
EditorAction::Submit => {
|
} else {
|
||||||
disable_raw_mode()?;
|
Ok(ReadOutcome::Exit)
|
||||||
writeln!(stdout)?;
|
|
||||||
self.history_index = None;
|
|
||||||
self.draft = None;
|
|
||||||
return Ok(ReadOutcome::Submit(input.as_str().to_owned()));
|
|
||||||
}
|
|
||||||
EditorAction::Cancel => {
|
|
||||||
disable_raw_mode()?;
|
|
||||||
writeln!(stdout)?;
|
|
||||||
self.history_index = None;
|
|
||||||
self.draft = None;
|
|
||||||
return Ok(ReadOutcome::Cancel);
|
|
||||||
}
|
|
||||||
EditorAction::Exit => {
|
|
||||||
disable_raw_mode()?;
|
|
||||||
writeln!(stdout)?;
|
|
||||||
self.history_index = None;
|
|
||||||
self.draft = None;
|
|
||||||
return Ok(ReadOutcome::Exit);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(ReadlineError::Eof) => {
|
||||||
|
self.finish_interrupted_read()?;
|
||||||
|
Ok(ReadOutcome::Exit)
|
||||||
|
}
|
||||||
|
Err(error) => Err(io::Error::other(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn current_line(&self) -> String {
|
||||||
|
self.editor
|
||||||
|
.helper()
|
||||||
|
.map_or_else(String::new, SlashCommandHelper::current_line)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finish_interrupted_read(&mut self) -> io::Result<()> {
|
||||||
|
if let Some(helper) = self.editor.helper_mut() {
|
||||||
|
helper.reset_current_line();
|
||||||
|
}
|
||||||
|
let mut stdout = io::stdout();
|
||||||
|
writeln!(stdout)
|
||||||
|
}
|
||||||
|
|
||||||
fn read_line_fallback(&self) -> io::Result<ReadOutcome> {
|
fn read_line_fallback(&self) -> io::Result<ReadOutcome> {
|
||||||
let mut stdout = io::stdout();
|
let mut stdout = io::stdout();
|
||||||
write!(stdout, "{}", self.prompt)?;
|
write!(stdout, "{}", self.prompt)?;
|
||||||
@@ -261,388 +184,86 @@ impl LineEditor {
|
|||||||
}
|
}
|
||||||
Ok(ReadOutcome::Submit(buffer))
|
Ok(ReadOutcome::Submit(buffer))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
|
||||||
fn handle_key(&mut self, key: KeyEvent, input: &mut InputBuffer) -> EditorAction {
|
|
||||||
match key {
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Char('c'),
|
|
||||||
modifiers,
|
|
||||||
..
|
|
||||||
} if modifiers.contains(KeyModifiers::CONTROL) => {
|
|
||||||
if input.as_str().is_empty() {
|
|
||||||
EditorAction::Exit
|
|
||||||
} else {
|
|
||||||
input.clear();
|
|
||||||
self.history_index = None;
|
|
||||||
self.draft = None;
|
|
||||||
EditorAction::Cancel
|
|
||||||
}
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Char('j'),
|
|
||||||
modifiers,
|
|
||||||
..
|
|
||||||
} if modifiers.contains(KeyModifiers::CONTROL) => {
|
|
||||||
input.insert_newline();
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Enter,
|
|
||||||
modifiers,
|
|
||||||
..
|
|
||||||
} if modifiers.contains(KeyModifiers::SHIFT) => {
|
|
||||||
input.insert_newline();
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Enter,
|
|
||||||
..
|
|
||||||
} => EditorAction::Submit,
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Backspace,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
input.backspace();
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Left,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
input.move_left();
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Right,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
input.move_right();
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Up, ..
|
|
||||||
} => {
|
|
||||||
self.navigate_history_up(input);
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Down,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
self.navigate_history_down(input);
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Tab, ..
|
|
||||||
} => {
|
|
||||||
input.complete_slash_command(&self.completions);
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Home,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
input.move_home();
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::End, ..
|
|
||||||
} => {
|
|
||||||
input.move_end();
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Esc, ..
|
|
||||||
} => {
|
|
||||||
input.clear();
|
|
||||||
self.history_index = None;
|
|
||||||
self.draft = None;
|
|
||||||
EditorAction::Cancel
|
|
||||||
}
|
|
||||||
KeyEvent {
|
|
||||||
code: KeyCode::Char(ch),
|
|
||||||
modifiers,
|
|
||||||
..
|
|
||||||
} if modifiers.is_empty() || modifiers == KeyModifiers::SHIFT => {
|
|
||||||
input.insert(ch);
|
|
||||||
self.history_index = None;
|
|
||||||
self.draft = None;
|
|
||||||
EditorAction::Continue
|
|
||||||
}
|
|
||||||
_ => EditorAction::Continue,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn navigate_history_up(&mut self, input: &mut InputBuffer) {
|
|
||||||
if self.history.is_empty() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
match self.history_index {
|
|
||||||
Some(0) => {}
|
|
||||||
Some(index) => {
|
|
||||||
let next_index = index - 1;
|
|
||||||
input.replace(self.history[next_index].clone());
|
|
||||||
self.history_index = Some(next_index);
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
self.draft = Some(input.as_str().to_owned());
|
|
||||||
let next_index = self.history.len() - 1;
|
|
||||||
input.replace(self.history[next_index].clone());
|
|
||||||
self.history_index = Some(next_index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn navigate_history_down(&mut self, input: &mut InputBuffer) {
|
|
||||||
let Some(index) = self.history_index else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
if index + 1 < self.history.len() {
|
|
||||||
let next_index = index + 1;
|
|
||||||
input.replace(self.history[next_index].clone());
|
|
||||||
self.history_index = Some(next_index);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
input.replace(self.draft.take().unwrap_or_default());
|
|
||||||
self.history_index = None;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn redraw(
|
|
||||||
&self,
|
|
||||||
out: &mut impl Write,
|
|
||||||
input: &InputBuffer,
|
|
||||||
previous_line_count: usize,
|
|
||||||
) -> io::Result<usize> {
|
|
||||||
let rendered = render_buffer(&self.prompt, &self.continuation_prompt, input);
|
|
||||||
if previous_line_count > 1 {
|
|
||||||
queue!(out, MoveUp(saturating_u16(previous_line_count - 1)))?;
|
|
||||||
}
|
|
||||||
queue!(out, MoveToColumn(0), Clear(ClearType::FromCursorDown),)?;
|
|
||||||
rendered.write(out)?;
|
|
||||||
queue!(
|
|
||||||
out,
|
|
||||||
MoveUp(saturating_u16(rendered.line_count().saturating_sub(1))),
|
|
||||||
MoveToColumn(0),
|
|
||||||
)?;
|
|
||||||
if rendered.cursor_row > 0 {
|
|
||||||
queue!(out, MoveDown(rendered.cursor_row))?;
|
|
||||||
}
|
|
||||||
queue!(out, MoveToColumn(rendered.cursor_col))?;
|
|
||||||
out.flush()?;
|
|
||||||
Ok(rendered.line_count())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
fn slash_command_prefix(line: &str, pos: usize) -> Option<&str> {
|
||||||
enum EditorAction {
|
if pos != line.len() {
|
||||||
Continue,
|
return None;
|
||||||
Submit,
|
|
||||||
Cancel,
|
|
||||||
Exit,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn render_buffer(
|
|
||||||
prompt: &str,
|
|
||||||
continuation_prompt: &str,
|
|
||||||
input: &InputBuffer,
|
|
||||||
) -> RenderedBuffer {
|
|
||||||
let before_cursor = &input.as_str()[..input.cursor];
|
|
||||||
let cursor_row = saturating_u16(before_cursor.chars().filter(|ch| *ch == '\n').count());
|
|
||||||
let cursor_line = before_cursor.rsplit('\n').next().unwrap_or_default();
|
|
||||||
let cursor_prompt = if cursor_row == 0 {
|
|
||||||
prompt
|
|
||||||
} else {
|
|
||||||
continuation_prompt
|
|
||||||
};
|
|
||||||
let cursor_col = saturating_u16(cursor_prompt.chars().count() + cursor_line.chars().count());
|
|
||||||
|
|
||||||
let mut lines = Vec::new();
|
|
||||||
for (index, line) in input.as_str().split('\n').enumerate() {
|
|
||||||
let prefix = if index == 0 {
|
|
||||||
prompt
|
|
||||||
} else {
|
|
||||||
continuation_prompt
|
|
||||||
};
|
|
||||||
lines.push(format!("{prefix}{line}"));
|
|
||||||
}
|
|
||||||
if lines.is_empty() {
|
|
||||||
lines.push(prompt.to_string());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RenderedBuffer {
|
let prefix = &line[..pos];
|
||||||
lines,
|
if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') {
|
||||||
cursor_row,
|
return None;
|
||||||
cursor_col,
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
Some(prefix)
|
||||||
fn longest_common_prefix(values: &[&str]) -> String {
|
|
||||||
let Some(first) = values.first() else {
|
|
||||||
return String::new();
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut prefix = (*first).to_string();
|
|
||||||
for value in values.iter().skip(1) {
|
|
||||||
while !value.starts_with(&prefix) {
|
|
||||||
prefix.pop();
|
|
||||||
if prefix.is_empty() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
fn saturating_u16(value: usize) -> u16 {
|
|
||||||
u16::try_from(value).unwrap_or(u16::MAX)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{render_buffer, InputBuffer, LineEditor};
|
use super::{slash_command_prefix, LineEditor, SlashCommandHelper};
|
||||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
use rustyline::completion::Completer;
|
||||||
|
use rustyline::highlight::Highlighter;
|
||||||
|
use rustyline::history::{DefaultHistory, History};
|
||||||
|
use rustyline::Context;
|
||||||
|
|
||||||
fn key(code: KeyCode) -> KeyEvent {
|
#[test]
|
||||||
KeyEvent::new(code, KeyModifiers::NONE)
|
fn extracts_only_terminal_slash_command_prefixes() {
|
||||||
|
assert_eq!(slash_command_prefix("/he", 3), Some("/he"));
|
||||||
|
assert_eq!(slash_command_prefix("/help me", 5), None);
|
||||||
|
assert_eq!(slash_command_prefix("hello", 5), None);
|
||||||
|
assert_eq!(slash_command_prefix("/help", 2), None);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn supports_basic_line_editing() {
|
fn completes_matching_slash_commands() {
|
||||||
let mut input = InputBuffer::new();
|
let helper = SlashCommandHelper::new(vec![
|
||||||
input.insert('h');
|
|
||||||
input.insert('i');
|
|
||||||
input.move_end();
|
|
||||||
input.insert_newline();
|
|
||||||
input.insert('x');
|
|
||||||
|
|
||||||
assert_eq!(input.as_str(), "hi\nx");
|
|
||||||
assert_eq!(input.cursor(), 4);
|
|
||||||
|
|
||||||
input.move_left();
|
|
||||||
input.backspace();
|
|
||||||
assert_eq!(input.as_str(), "hix");
|
|
||||||
assert_eq!(input.cursor(), 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn completes_unique_slash_command() {
|
|
||||||
let mut input = InputBuffer::new();
|
|
||||||
for ch in "/he".chars() {
|
|
||||||
input.insert(ch);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert!(input.complete_slash_command(&[
|
|
||||||
"/help".to_string(),
|
"/help".to_string(),
|
||||||
"/hello".to_string(),
|
"/hello".to_string(),
|
||||||
"/status".to_string(),
|
"/status".to_string(),
|
||||||
]));
|
]);
|
||||||
assert_eq!(input.as_str(), "/hel");
|
let history = DefaultHistory::new();
|
||||||
|
let ctx = Context::new(&history);
|
||||||
|
let (start, matches) = helper
|
||||||
|
.complete("/he", 3, &ctx)
|
||||||
|
.expect("completion should work");
|
||||||
|
|
||||||
assert!(input.complete_slash_command(&["/help".to_string(), "/status".to_string()]));
|
assert_eq!(start, 0);
|
||||||
assert_eq!(input.as_str(), "/help");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn ignores_completion_when_prefix_is_not_a_slash_command() {
|
|
||||||
let mut input = InputBuffer::new();
|
|
||||||
for ch in "hello".chars() {
|
|
||||||
input.insert(ch);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert!(!input.complete_slash_command(&["/help".to_string()]));
|
|
||||||
assert_eq!(input.as_str(), "hello");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn history_navigation_restores_current_draft() {
|
|
||||||
let mut editor = LineEditor::new("› ", vec![]);
|
|
||||||
editor.push_history("/help");
|
|
||||||
editor.push_history("status report");
|
|
||||||
|
|
||||||
let mut input = InputBuffer::new();
|
|
||||||
for ch in "draft".chars() {
|
|
||||||
input.insert(ch);
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
|
|
||||||
assert_eq!(input.as_str(), "status report");
|
|
||||||
|
|
||||||
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
|
|
||||||
assert_eq!(input.as_str(), "/help");
|
|
||||||
|
|
||||||
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
|
|
||||||
assert_eq!(input.as_str(), "status report");
|
|
||||||
|
|
||||||
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
|
|
||||||
assert_eq!(input.as_str(), "draft");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn tab_key_completes_from_editor_candidates() {
|
|
||||||
let mut editor = LineEditor::new(
|
|
||||||
"› ",
|
|
||||||
vec![
|
|
||||||
"/help".to_string(),
|
|
||||||
"/status".to_string(),
|
|
||||||
"/session".to_string(),
|
|
||||||
],
|
|
||||||
);
|
|
||||||
let mut input = InputBuffer::new();
|
|
||||||
for ch in "/st".chars() {
|
|
||||||
input.insert(ch);
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = editor.handle_key(key(KeyCode::Tab), &mut input);
|
|
||||||
assert_eq!(input.as_str(), "/status");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn renders_multiline_buffers_with_continuation_prompt() {
|
|
||||||
let mut input = InputBuffer::new();
|
|
||||||
for ch in "hello\nworld".chars() {
|
|
||||||
if ch == '\n' {
|
|
||||||
input.insert_newline();
|
|
||||||
} else {
|
|
||||||
input.insert(ch);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let rendered = render_buffer("› ", "> ", &input);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
rendered.lines(),
|
matches
|
||||||
&["› hello".to_string(), "> world".to_string()]
|
.into_iter()
|
||||||
|
.map(|candidate| candidate.replacement)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
vec!["/help".to_string(), "/hello".to_string()]
|
||||||
);
|
);
|
||||||
assert_eq!(rendered.cursor_position(), (1, 7));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ctrl_c_exits_only_when_buffer_is_empty() {
|
fn ignores_non_slash_command_completion_requests() {
|
||||||
let mut editor = LineEditor::new("› ", vec![]);
|
let helper = SlashCommandHelper::new(vec!["/help".to_string()]);
|
||||||
let mut empty = InputBuffer::new();
|
let history = DefaultHistory::new();
|
||||||
assert!(matches!(
|
let ctx = Context::new(&history);
|
||||||
editor.handle_key(
|
let (_, matches) = helper
|
||||||
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
|
.complete("hello", 5, &ctx)
|
||||||
&mut empty,
|
.expect("completion should work");
|
||||||
),
|
|
||||||
super::EditorAction::Exit
|
|
||||||
));
|
|
||||||
|
|
||||||
let mut filled = InputBuffer::new();
|
assert!(matches.is_empty());
|
||||||
filled.insert('x');
|
}
|
||||||
assert!(matches!(
|
|
||||||
editor.handle_key(
|
#[test]
|
||||||
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
|
fn tracks_current_buffer_through_highlighter() {
|
||||||
&mut filled,
|
let helper = SlashCommandHelper::new(Vec::new());
|
||||||
),
|
let _ = helper.highlight("draft", 5);
|
||||||
super::EditorAction::Cancel
|
|
||||||
));
|
assert_eq!(helper.current_line(), "draft");
|
||||||
assert!(filled.as_str().is_empty());
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_history_ignores_blank_entries() {
|
||||||
|
let mut editor = LineEditor::new("> ", vec!["/help".to_string()]);
|
||||||
|
editor.push_history(" ");
|
||||||
|
editor.push_history("/help");
|
||||||
|
|
||||||
|
assert_eq!(editor.editor.history().len(), 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,7 @@ pub struct ColorTheme {
|
|||||||
inline_code: Color,
|
inline_code: Color,
|
||||||
link: Color,
|
link: Color,
|
||||||
quote: Color,
|
quote: Color,
|
||||||
|
table_border: Color,
|
||||||
spinner_active: Color,
|
spinner_active: Color,
|
||||||
spinner_done: Color,
|
spinner_done: Color,
|
||||||
spinner_failed: Color,
|
spinner_failed: Color,
|
||||||
@@ -35,6 +36,7 @@ impl Default for ColorTheme {
|
|||||||
inline_code: Color::Green,
|
inline_code: Color::Green,
|
||||||
link: Color::Blue,
|
link: Color::Blue,
|
||||||
quote: Color::DarkGrey,
|
quote: Color::DarkGrey,
|
||||||
|
table_border: Color::DarkCyan,
|
||||||
spinner_active: Color::Blue,
|
spinner_active: Color::Blue,
|
||||||
spinner_done: Color::Green,
|
spinner_done: Color::Green,
|
||||||
spinner_failed: Color::Red,
|
spinner_failed: Color::Red,
|
||||||
@@ -113,24 +115,70 @@ impl Spinner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
enum ListKind {
|
||||||
|
Unordered,
|
||||||
|
Ordered { next_index: u64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||||
|
struct TableState {
|
||||||
|
headers: Vec<String>,
|
||||||
|
rows: Vec<Vec<String>>,
|
||||||
|
current_row: Vec<String>,
|
||||||
|
current_cell: String,
|
||||||
|
in_head: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TableState {
|
||||||
|
fn push_cell(&mut self) {
|
||||||
|
let cell = self.current_cell.trim().to_string();
|
||||||
|
self.current_row.push(cell);
|
||||||
|
self.current_cell.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn finish_row(&mut self) {
|
||||||
|
if self.current_row.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let row = std::mem::take(&mut self.current_row);
|
||||||
|
if self.in_head {
|
||||||
|
self.headers = row;
|
||||||
|
} else {
|
||||||
|
self.rows.push(row);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||||
struct RenderState {
|
struct RenderState {
|
||||||
emphasis: usize,
|
emphasis: usize,
|
||||||
strong: usize,
|
strong: usize,
|
||||||
quote: usize,
|
quote: usize,
|
||||||
list: usize,
|
list_stack: Vec<ListKind>,
|
||||||
|
table: Option<TableState>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RenderState {
|
impl RenderState {
|
||||||
fn style_text(&self, text: &str, theme: &ColorTheme) -> String {
|
fn style_text(&self, text: &str, theme: &ColorTheme) -> String {
|
||||||
|
let mut styled = text.to_string();
|
||||||
if self.strong > 0 {
|
if self.strong > 0 {
|
||||||
format!("{}", text.bold().with(theme.strong))
|
styled = format!("{}", styled.bold().with(theme.strong));
|
||||||
} else if self.emphasis > 0 {
|
}
|
||||||
format!("{}", text.italic().with(theme.emphasis))
|
if self.emphasis > 0 {
|
||||||
} else if self.quote > 0 {
|
styled = format!("{}", styled.italic().with(theme.emphasis));
|
||||||
format!("{}", text.with(theme.quote))
|
}
|
||||||
|
if self.quote > 0 {
|
||||||
|
styled = format!("{}", styled.with(theme.quote));
|
||||||
|
}
|
||||||
|
styled
|
||||||
|
}
|
||||||
|
|
||||||
|
fn capture_target_mut<'a>(&'a mut self, output: &'a mut String) -> &'a mut String {
|
||||||
|
if let Some(table) = self.table.as_mut() {
|
||||||
|
&mut table.current_cell
|
||||||
} else {
|
} else {
|
||||||
text.to_string()
|
output
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -190,6 +238,7 @@ impl TerminalRenderer {
|
|||||||
output.trim_end().to_string()
|
output.trim_end().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
fn render_event(
|
fn render_event(
|
||||||
&self,
|
&self,
|
||||||
event: Event<'_>,
|
event: Event<'_>,
|
||||||
@@ -203,12 +252,22 @@ impl TerminalRenderer {
|
|||||||
Event::Start(Tag::Heading { level, .. }) => self.start_heading(level as u8, output),
|
Event::Start(Tag::Heading { level, .. }) => self.start_heading(level as u8, output),
|
||||||
Event::End(TagEnd::Heading(..) | TagEnd::Paragraph) => output.push_str("\n\n"),
|
Event::End(TagEnd::Heading(..) | TagEnd::Paragraph) => output.push_str("\n\n"),
|
||||||
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
|
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
|
||||||
Event::End(TagEnd::BlockQuote(..) | TagEnd::Item)
|
Event::End(TagEnd::BlockQuote(..)) => {
|
||||||
| Event::SoftBreak
|
state.quote = state.quote.saturating_sub(1);
|
||||||
| Event::HardBreak => output.push('\n'),
|
output.push('\n');
|
||||||
Event::Start(Tag::List(_)) => state.list += 1,
|
}
|
||||||
|
Event::End(TagEnd::Item) | Event::SoftBreak | Event::HardBreak => {
|
||||||
|
state.capture_target_mut(output).push('\n');
|
||||||
|
}
|
||||||
|
Event::Start(Tag::List(first_item)) => {
|
||||||
|
let kind = match first_item {
|
||||||
|
Some(index) => ListKind::Ordered { next_index: index },
|
||||||
|
None => ListKind::Unordered,
|
||||||
|
};
|
||||||
|
state.list_stack.push(kind);
|
||||||
|
}
|
||||||
Event::End(TagEnd::List(..)) => {
|
Event::End(TagEnd::List(..)) => {
|
||||||
state.list = state.list.saturating_sub(1);
|
state.list_stack.pop();
|
||||||
output.push('\n');
|
output.push('\n');
|
||||||
}
|
}
|
||||||
Event::Start(Tag::Item) => Self::start_item(state, output),
|
Event::Start(Tag::Item) => Self::start_item(state, output),
|
||||||
@@ -232,57 +291,85 @@ impl TerminalRenderer {
|
|||||||
Event::Start(Tag::Strong) => state.strong += 1,
|
Event::Start(Tag::Strong) => state.strong += 1,
|
||||||
Event::End(TagEnd::Strong) => state.strong = state.strong.saturating_sub(1),
|
Event::End(TagEnd::Strong) => state.strong = state.strong.saturating_sub(1),
|
||||||
Event::Code(code) => {
|
Event::Code(code) => {
|
||||||
let _ = write!(
|
let rendered =
|
||||||
output,
|
format!("{}", format!("`{code}`").with(self.color_theme.inline_code));
|
||||||
"{}",
|
state.capture_target_mut(output).push_str(&rendered);
|
||||||
format!("`{code}`").with(self.color_theme.inline_code)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
Event::Rule => output.push_str("---\n"),
|
Event::Rule => output.push_str("---\n"),
|
||||||
Event::Text(text) => {
|
Event::Text(text) => {
|
||||||
self.push_text(text.as_ref(), state, output, code_buffer, *in_code_block);
|
self.push_text(text.as_ref(), state, output, code_buffer, *in_code_block);
|
||||||
}
|
}
|
||||||
Event::Html(html) | Event::InlineHtml(html) => output.push_str(&html),
|
Event::Html(html) | Event::InlineHtml(html) => {
|
||||||
Event::FootnoteReference(reference) => {
|
state.capture_target_mut(output).push_str(&html);
|
||||||
let _ = write!(output, "[{reference}]");
|
}
|
||||||
|
Event::FootnoteReference(reference) => {
|
||||||
|
let _ = write!(state.capture_target_mut(output), "[{reference}]");
|
||||||
|
}
|
||||||
|
Event::TaskListMarker(done) => {
|
||||||
|
state
|
||||||
|
.capture_target_mut(output)
|
||||||
|
.push_str(if done { "[x] " } else { "[ ] " });
|
||||||
|
}
|
||||||
|
Event::InlineMath(math) | Event::DisplayMath(math) => {
|
||||||
|
state.capture_target_mut(output).push_str(&math);
|
||||||
}
|
}
|
||||||
Event::TaskListMarker(done) => output.push_str(if done { "[x] " } else { "[ ] " }),
|
|
||||||
Event::InlineMath(math) | Event::DisplayMath(math) => output.push_str(&math),
|
|
||||||
Event::Start(Tag::Link { dest_url, .. }) => {
|
Event::Start(Tag::Link { dest_url, .. }) => {
|
||||||
let _ = write!(
|
let rendered = format!(
|
||||||
output,
|
|
||||||
"{}",
|
"{}",
|
||||||
format!("[{dest_url}]")
|
format!("[{dest_url}]")
|
||||||
.underlined()
|
.underlined()
|
||||||
.with(self.color_theme.link)
|
.with(self.color_theme.link)
|
||||||
);
|
);
|
||||||
|
state.capture_target_mut(output).push_str(&rendered);
|
||||||
}
|
}
|
||||||
Event::Start(Tag::Image { dest_url, .. }) => {
|
Event::Start(Tag::Image { dest_url, .. }) => {
|
||||||
let _ = write!(
|
let rendered = format!(
|
||||||
output,
|
|
||||||
"{}",
|
"{}",
|
||||||
format!("[image:{dest_url}]").with(self.color_theme.link)
|
format!("[image:{dest_url}]").with(self.color_theme.link)
|
||||||
);
|
);
|
||||||
|
state.capture_target_mut(output).push_str(&rendered);
|
||||||
}
|
}
|
||||||
Event::Start(
|
Event::Start(Tag::Table(..)) => state.table = Some(TableState::default()),
|
||||||
Tag::Paragraph
|
Event::End(TagEnd::Table) => {
|
||||||
| Tag::Table(..)
|
if let Some(table) = state.table.take() {
|
||||||
| Tag::TableHead
|
output.push_str(&self.render_table(&table));
|
||||||
| Tag::TableRow
|
output.push_str("\n\n");
|
||||||
| Tag::TableCell
|
}
|
||||||
| Tag::MetadataBlock(..)
|
}
|
||||||
| _,
|
Event::Start(Tag::TableHead) => {
|
||||||
)
|
if let Some(table) = state.table.as_mut() {
|
||||||
| Event::End(
|
table.in_head = true;
|
||||||
TagEnd::Link
|
}
|
||||||
| TagEnd::Image
|
}
|
||||||
| TagEnd::Table
|
Event::End(TagEnd::TableHead) => {
|
||||||
| TagEnd::TableHead
|
if let Some(table) = state.table.as_mut() {
|
||||||
| TagEnd::TableRow
|
table.finish_row();
|
||||||
| TagEnd::TableCell
|
table.in_head = false;
|
||||||
| TagEnd::MetadataBlock(..)
|
}
|
||||||
| _,
|
}
|
||||||
) => {}
|
Event::Start(Tag::TableRow) => {
|
||||||
|
if let Some(table) = state.table.as_mut() {
|
||||||
|
table.current_row.clear();
|
||||||
|
table.current_cell.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::End(TagEnd::TableRow) => {
|
||||||
|
if let Some(table) = state.table.as_mut() {
|
||||||
|
table.finish_row();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::Start(Tag::TableCell) => {
|
||||||
|
if let Some(table) = state.table.as_mut() {
|
||||||
|
table.current_cell.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::End(TagEnd::TableCell) => {
|
||||||
|
if let Some(table) = state.table.as_mut() {
|
||||||
|
table.push_cell();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::Start(Tag::Paragraph | Tag::MetadataBlock(..) | _)
|
||||||
|
| Event::End(TagEnd::Link | TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,9 +389,19 @@ impl TerminalRenderer {
|
|||||||
let _ = write!(output, "{}", "│ ".with(self.color_theme.quote));
|
let _ = write!(output, "{}", "│ ".with(self.color_theme.quote));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_item(state: &RenderState, output: &mut String) {
|
fn start_item(state: &mut RenderState, output: &mut String) {
|
||||||
output.push_str(&" ".repeat(state.list.saturating_sub(1)));
|
let depth = state.list_stack.len().saturating_sub(1);
|
||||||
output.push_str("• ");
|
output.push_str(&" ".repeat(depth));
|
||||||
|
|
||||||
|
let marker = match state.list_stack.last_mut() {
|
||||||
|
Some(ListKind::Ordered { next_index }) => {
|
||||||
|
let value = *next_index;
|
||||||
|
*next_index += 1;
|
||||||
|
format!("{value}. ")
|
||||||
|
}
|
||||||
|
_ => "• ".to_string(),
|
||||||
|
};
|
||||||
|
output.push_str(&marker);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_code_block(&self, code_language: &str, output: &mut String) {
|
fn start_code_block(&self, code_language: &str, output: &mut String) {
|
||||||
@@ -328,7 +425,7 @@ impl TerminalRenderer {
|
|||||||
fn push_text(
|
fn push_text(
|
||||||
&self,
|
&self,
|
||||||
text: &str,
|
text: &str,
|
||||||
state: &RenderState,
|
state: &mut RenderState,
|
||||||
output: &mut String,
|
output: &mut String,
|
||||||
code_buffer: &mut String,
|
code_buffer: &mut String,
|
||||||
in_code_block: bool,
|
in_code_block: bool,
|
||||||
@@ -336,10 +433,82 @@ impl TerminalRenderer {
|
|||||||
if in_code_block {
|
if in_code_block {
|
||||||
code_buffer.push_str(text);
|
code_buffer.push_str(text);
|
||||||
} else {
|
} else {
|
||||||
output.push_str(&state.style_text(text, &self.color_theme));
|
let rendered = state.style_text(text, &self.color_theme);
|
||||||
|
state.capture_target_mut(output).push_str(&rendered);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_table(&self, table: &TableState) -> String {
|
||||||
|
let mut rows = Vec::new();
|
||||||
|
if !table.headers.is_empty() {
|
||||||
|
rows.push(table.headers.clone());
|
||||||
|
}
|
||||||
|
rows.extend(table.rows.iter().cloned());
|
||||||
|
|
||||||
|
if rows.is_empty() {
|
||||||
|
return String::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let column_count = rows.iter().map(Vec::len).max().unwrap_or(0);
|
||||||
|
let widths = (0..column_count)
|
||||||
|
.map(|column| {
|
||||||
|
rows.iter()
|
||||||
|
.filter_map(|row| row.get(column))
|
||||||
|
.map(|cell| visible_width(cell))
|
||||||
|
.max()
|
||||||
|
.unwrap_or(0)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let border = format!("{}", "│".with(self.color_theme.table_border));
|
||||||
|
let separator = widths
|
||||||
|
.iter()
|
||||||
|
.map(|width| "─".repeat(*width + 2))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(&format!("{}", "┼".with(self.color_theme.table_border)));
|
||||||
|
let separator = format!("{border}{separator}{border}");
|
||||||
|
|
||||||
|
let mut output = String::new();
|
||||||
|
if !table.headers.is_empty() {
|
||||||
|
output.push_str(&self.render_table_row(&table.headers, &widths, true));
|
||||||
|
output.push('\n');
|
||||||
|
output.push_str(&separator);
|
||||||
|
if !table.rows.is_empty() {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (index, row) in table.rows.iter().enumerate() {
|
||||||
|
output.push_str(&self.render_table_row(row, &widths, false));
|
||||||
|
if index + 1 < table.rows.len() {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_table_row(&self, row: &[String], widths: &[usize], is_header: bool) -> String {
|
||||||
|
let border = format!("{}", "│".with(self.color_theme.table_border));
|
||||||
|
let mut line = String::new();
|
||||||
|
line.push_str(&border);
|
||||||
|
|
||||||
|
for (index, width) in widths.iter().enumerate() {
|
||||||
|
let cell = row.get(index).map_or("", String::as_str);
|
||||||
|
line.push(' ');
|
||||||
|
if is_header {
|
||||||
|
let _ = write!(line, "{}", cell.bold().with(self.color_theme.heading));
|
||||||
|
} else {
|
||||||
|
line.push_str(cell);
|
||||||
|
}
|
||||||
|
let padding = width.saturating_sub(visible_width(cell));
|
||||||
|
line.push_str(&" ".repeat(padding + 1));
|
||||||
|
line.push_str(&border);
|
||||||
|
}
|
||||||
|
|
||||||
|
line
|
||||||
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn highlight_code(&self, code: &str, language: &str) -> String {
|
pub fn highlight_code(&self, code: &str, language: &str) -> String {
|
||||||
let syntax = self
|
let syntax = self
|
||||||
@@ -372,32 +541,36 @@ impl TerminalRenderer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
fn visible_width(input: &str) -> usize {
|
||||||
mod tests {
|
strip_ansi(input).chars().count()
|
||||||
use super::{Spinner, TerminalRenderer};
|
}
|
||||||
|
|
||||||
fn strip_ansi(input: &str) -> String {
|
fn strip_ansi(input: &str) -> String {
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
let mut chars = input.chars().peekable();
|
let mut chars = input.chars().peekable();
|
||||||
|
|
||||||
while let Some(ch) = chars.next() {
|
while let Some(ch) = chars.next() {
|
||||||
if ch == '\u{1b}' {
|
if ch == '\u{1b}' {
|
||||||
if chars.peek() == Some(&'[') {
|
if chars.peek() == Some(&'[') {
|
||||||
chars.next();
|
chars.next();
|
||||||
for next in chars.by_ref() {
|
for next in chars.by_ref() {
|
||||||
if next.is_ascii_alphabetic() {
|
if next.is_ascii_alphabetic() {
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
output.push(ch);
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
output.push(ch);
|
||||||
}
|
}
|
||||||
|
|
||||||
output
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{strip_ansi, Spinner, TerminalRenderer};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn renders_markdown_with_styling_and_lists() {
|
fn renders_markdown_with_styling_and_lists() {
|
||||||
let terminal_renderer = TerminalRenderer::new();
|
let terminal_renderer = TerminalRenderer::new();
|
||||||
@@ -422,6 +595,34 @@ mod tests {
|
|||||||
assert!(markdown_output.contains('\u{1b}'));
|
assert!(markdown_output.contains('\u{1b}'));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn renders_ordered_and_nested_lists() {
|
||||||
|
let terminal_renderer = TerminalRenderer::new();
|
||||||
|
let markdown_output =
|
||||||
|
terminal_renderer.render_markdown("1. first\n2. second\n - nested\n - child");
|
||||||
|
let plain_text = strip_ansi(&markdown_output);
|
||||||
|
|
||||||
|
assert!(plain_text.contains("1. first"));
|
||||||
|
assert!(plain_text.contains("2. second"));
|
||||||
|
assert!(plain_text.contains(" • nested"));
|
||||||
|
assert!(plain_text.contains(" • child"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn renders_tables_with_alignment() {
|
||||||
|
let terminal_renderer = TerminalRenderer::new();
|
||||||
|
let markdown_output = terminal_renderer
|
||||||
|
.render_markdown("| Name | Value |\n| ---- | ----- |\n| alpha | 1 |\n| beta | 22 |");
|
||||||
|
let plain_text = strip_ansi(&markdown_output);
|
||||||
|
let lines = plain_text.lines().collect::<Vec<_>>();
|
||||||
|
|
||||||
|
assert_eq!(lines[0], "│ Name │ Value │");
|
||||||
|
assert_eq!(lines[1], "│───────┼───────│");
|
||||||
|
assert_eq!(lines[2], "│ alpha │ 1 │");
|
||||||
|
assert_eq!(lines[3], "│ beta │ 22 │");
|
||||||
|
assert!(markdown_output.contains('\u{1b}'));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn spinner_advances_frames() {
|
fn spinner_advances_frames() {
|
||||||
let terminal_renderer = TerminalRenderer::new();
|
let terminal_renderer = TerminalRenderer::new();
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use std::time::{Duration, Instant};
|
|||||||
use reqwest::blocking::Client;
|
use reqwest::blocking::Client;
|
||||||
use runtime::{
|
use runtime::{
|
||||||
edit_file, execute_bash, glob_search, grep_search, read_file, write_file, BashCommandInput,
|
edit_file, execute_bash, glob_search, grep_search, read_file, write_file, BashCommandInput,
|
||||||
GrepSearchInput,
|
GrepSearchInput, PermissionMode,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
@@ -45,6 +45,7 @@ pub struct ToolSpec {
|
|||||||
pub name: &'static str,
|
pub name: &'static str,
|
||||||
pub description: &'static str,
|
pub description: &'static str,
|
||||||
pub input_schema: Value,
|
pub input_schema: Value,
|
||||||
|
pub required_permission: PermissionMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
@@ -66,6 +67,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["command"],
|
"required": ["command"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::DangerFullAccess,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "read_file",
|
name: "read_file",
|
||||||
@@ -80,6 +82,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["path"],
|
"required": ["path"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "write_file",
|
name: "write_file",
|
||||||
@@ -93,6 +96,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["path", "content"],
|
"required": ["path", "content"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::WorkspaceWrite,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "edit_file",
|
name: "edit_file",
|
||||||
@@ -108,6 +112,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["path", "old_string", "new_string"],
|
"required": ["path", "old_string", "new_string"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::WorkspaceWrite,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "glob_search",
|
name: "glob_search",
|
||||||
@@ -121,6 +126,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["pattern"],
|
"required": ["pattern"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "grep_search",
|
name: "grep_search",
|
||||||
@@ -146,6 +152,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["pattern"],
|
"required": ["pattern"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "WebFetch",
|
name: "WebFetch",
|
||||||
@@ -160,6 +167,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["url", "prompt"],
|
"required": ["url", "prompt"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "WebSearch",
|
name: "WebSearch",
|
||||||
@@ -180,6 +188,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["query"],
|
"required": ["query"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "TodoWrite",
|
name: "TodoWrite",
|
||||||
@@ -207,6 +216,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["todos"],
|
"required": ["todos"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::WorkspaceWrite,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "Skill",
|
name: "Skill",
|
||||||
@@ -220,6 +230,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["skill"],
|
"required": ["skill"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "Agent",
|
name: "Agent",
|
||||||
@@ -236,6 +247,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["description", "prompt"],
|
"required": ["description", "prompt"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::DangerFullAccess,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "ToolSearch",
|
name: "ToolSearch",
|
||||||
@@ -249,6 +261,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["query"],
|
"required": ["query"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "NotebookEdit",
|
name: "NotebookEdit",
|
||||||
@@ -265,6 +278,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["notebook_path"],
|
"required": ["notebook_path"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::WorkspaceWrite,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "Sleep",
|
name: "Sleep",
|
||||||
@@ -277,6 +291,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["duration_ms"],
|
"required": ["duration_ms"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "SendUserMessage",
|
name: "SendUserMessage",
|
||||||
@@ -297,6 +312,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["message", "status"],
|
"required": ["message", "status"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "Config",
|
name: "Config",
|
||||||
@@ -312,6 +328,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["setting"],
|
"required": ["setting"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::WorkspaceWrite,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "StructuredOutput",
|
name: "StructuredOutput",
|
||||||
@@ -320,6 +337,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": true
|
"additionalProperties": true
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::ReadOnly,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "REPL",
|
name: "REPL",
|
||||||
@@ -334,6 +352,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["code", "language"],
|
"required": ["code", "language"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::DangerFullAccess,
|
||||||
},
|
},
|
||||||
ToolSpec {
|
ToolSpec {
|
||||||
name: "PowerShell",
|
name: "PowerShell",
|
||||||
@@ -349,6 +368,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> {
|
|||||||
"required": ["command"],
|
"required": ["command"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
}),
|
||||||
|
required_permission: PermissionMode::DangerFullAccess,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -2194,7 +2214,8 @@ fn execute_shell_command(
|
|||||||
structured_content: None,
|
structured_content: None,
|
||||||
persisted_output_path: None,
|
persisted_output_path: None,
|
||||||
persisted_output_size: None,
|
persisted_output_size: None,
|
||||||
});
|
sandbox_status: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut process = std::process::Command::new(shell);
|
let mut process = std::process::Command::new(shell);
|
||||||
@@ -2231,6 +2252,7 @@ fn execute_shell_command(
|
|||||||
structured_content: None,
|
structured_content: None,
|
||||||
persisted_output_path: None,
|
persisted_output_path: None,
|
||||||
persisted_output_size: None,
|
persisted_output_size: None,
|
||||||
|
sandbox_status: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
if started.elapsed() >= Duration::from_millis(timeout_ms) {
|
if started.elapsed() >= Duration::from_millis(timeout_ms) {
|
||||||
@@ -2261,7 +2283,8 @@ Command exceeded timeout of {timeout_ms} ms",
|
|||||||
structured_content: None,
|
structured_content: None,
|
||||||
persisted_output_path: None,
|
persisted_output_path: None,
|
||||||
persisted_output_size: None,
|
persisted_output_size: None,
|
||||||
});
|
sandbox_status: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
std::thread::sleep(Duration::from_millis(10));
|
std::thread::sleep(Duration::from_millis(10));
|
||||||
}
|
}
|
||||||
@@ -2287,6 +2310,7 @@ Command exceeded timeout of {timeout_ms} ms",
|
|||||||
structured_content: None,
|
structured_content: None,
|
||||||
persisted_output_path: None,
|
persisted_output_path: None,
|
||||||
persisted_output_size: None,
|
persisted_output_size: None,
|
||||||
|
sandbox_status: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2349,8 +2373,10 @@ fn parse_skill_description(contents: &str) -> Option<String> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::fs;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use std::net::{SocketAddr, TcpListener};
|
use std::net::{SocketAddr, TcpListener};
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::sync::{Arc, Mutex, OnceLock};
|
use std::sync::{Arc, Mutex, OnceLock};
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
@@ -2363,6 +2389,14 @@ mod tests {
|
|||||||
LOCK.get_or_init(|| Mutex::new(()))
|
LOCK.get_or_init(|| Mutex::new(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn temp_path(name: &str) -> PathBuf {
|
||||||
|
let unique = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.expect("time")
|
||||||
|
.as_nanos();
|
||||||
|
std::env::temp_dir().join(format!("clawd-tools-{unique}-{name}"))
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn exposes_mvp_tools() {
|
fn exposes_mvp_tools() {
|
||||||
let names = mvp_tool_specs()
|
let names = mvp_tool_specs()
|
||||||
@@ -2432,6 +2466,40 @@ mod tests {
|
|||||||
assert!(titled_summary.contains("Title: Ignored"));
|
assert!(titled_summary.contains("Title: Ignored"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn web_fetch_supports_plain_text_and_rejects_invalid_url() {
|
||||||
|
let server = TestServer::spawn(Arc::new(|request_line: &str| {
|
||||||
|
assert!(request_line.starts_with("GET /plain "));
|
||||||
|
HttpResponse::text(200, "OK", "plain text response")
|
||||||
|
}));
|
||||||
|
|
||||||
|
let result = execute_tool(
|
||||||
|
"WebFetch",
|
||||||
|
&json!({
|
||||||
|
"url": format!("http://{}/plain", server.addr()),
|
||||||
|
"prompt": "Show me the content"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect("WebFetch should succeed for text content");
|
||||||
|
|
||||||
|
let output: serde_json::Value = serde_json::from_str(&result).expect("valid json");
|
||||||
|
assert_eq!(output["url"], format!("http://{}/plain", server.addr()));
|
||||||
|
assert!(output["result"]
|
||||||
|
.as_str()
|
||||||
|
.expect("result")
|
||||||
|
.contains("plain text response"));
|
||||||
|
|
||||||
|
let error = execute_tool(
|
||||||
|
"WebFetch",
|
||||||
|
&json!({
|
||||||
|
"url": "not a url",
|
||||||
|
"prompt": "Summarize"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("invalid URL should fail");
|
||||||
|
assert!(error.contains("relative URL without a base") || error.contains("invalid"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn web_search_extracts_and_filters_results() {
|
fn web_search_extracts_and_filters_results() {
|
||||||
let server = TestServer::spawn(Arc::new(|request_line: &str| {
|
let server = TestServer::spawn(Arc::new(|request_line: &str| {
|
||||||
@@ -2476,15 +2544,63 @@ mod tests {
|
|||||||
assert_eq!(content[0]["url"], "https://docs.rs/reqwest");
|
assert_eq!(content[0]["url"], "https://docs.rs/reqwest");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn web_search_handles_generic_links_and_invalid_base_url() {
|
||||||
|
let _guard = env_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
let server = TestServer::spawn(Arc::new(|request_line: &str| {
|
||||||
|
assert!(request_line.contains("GET /fallback?q=generic+links "));
|
||||||
|
HttpResponse::html(
|
||||||
|
200,
|
||||||
|
"OK",
|
||||||
|
r#"
|
||||||
|
<html><body>
|
||||||
|
<a href="https://example.com/one">Example One</a>
|
||||||
|
<a href="https://example.com/one">Duplicate Example One</a>
|
||||||
|
<a href="https://docs.rs/tokio">Tokio Docs</a>
|
||||||
|
</body></html>
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
|
||||||
|
std::env::set_var(
|
||||||
|
"CLAWD_WEB_SEARCH_BASE_URL",
|
||||||
|
format!("http://{}/fallback", server.addr()),
|
||||||
|
);
|
||||||
|
let result = execute_tool(
|
||||||
|
"WebSearch",
|
||||||
|
&json!({
|
||||||
|
"query": "generic links"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect("WebSearch fallback parsing should succeed");
|
||||||
|
std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL");
|
||||||
|
|
||||||
|
let output: serde_json::Value = serde_json::from_str(&result).expect("valid json");
|
||||||
|
let results = output["results"].as_array().expect("results array");
|
||||||
|
let search_result = results
|
||||||
|
.iter()
|
||||||
|
.find(|item| item.get("content").is_some())
|
||||||
|
.expect("search result block present");
|
||||||
|
let content = search_result["content"].as_array().expect("content array");
|
||||||
|
assert_eq!(content.len(), 2);
|
||||||
|
assert_eq!(content[0]["url"], "https://example.com/one");
|
||||||
|
assert_eq!(content[1]["url"], "https://docs.rs/tokio");
|
||||||
|
|
||||||
|
std::env::set_var("CLAWD_WEB_SEARCH_BASE_URL", "://bad-base-url");
|
||||||
|
let error = execute_tool("WebSearch", &json!({ "query": "generic links" }))
|
||||||
|
.expect_err("invalid base URL should fail");
|
||||||
|
std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL");
|
||||||
|
assert!(error.contains("relative URL without a base") || error.contains("empty host"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn todo_write_persists_and_returns_previous_state() {
|
fn todo_write_persists_and_returns_previous_state() {
|
||||||
let path = std::env::temp_dir().join(format!(
|
let _guard = env_lock()
|
||||||
"clawd-tools-todos-{}.json",
|
.lock()
|
||||||
std::time::SystemTime::now()
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
let path = temp_path("todos.json");
|
||||||
.expect("time")
|
|
||||||
.as_nanos()
|
|
||||||
));
|
|
||||||
std::env::set_var("CLAWD_TODO_STORE", &path);
|
std::env::set_var("CLAWD_TODO_STORE", &path);
|
||||||
|
|
||||||
let first = execute_tool(
|
let first = execute_tool(
|
||||||
@@ -2526,6 +2642,59 @@ mod tests {
|
|||||||
assert!(second_output["verificationNudgeNeeded"].is_null());
|
assert!(second_output["verificationNudgeNeeded"].is_null());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn todo_write_rejects_invalid_payloads_and_sets_verification_nudge() {
|
||||||
|
let _guard = env_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
let path = temp_path("todos-errors.json");
|
||||||
|
std::env::set_var("CLAWD_TODO_STORE", &path);
|
||||||
|
|
||||||
|
let empty = execute_tool("TodoWrite", &json!({ "todos": [] }))
|
||||||
|
.expect_err("empty todos should fail");
|
||||||
|
assert!(empty.contains("todos must not be empty"));
|
||||||
|
|
||||||
|
let too_many_active = execute_tool(
|
||||||
|
"TodoWrite",
|
||||||
|
&json!({
|
||||||
|
"todos": [
|
||||||
|
{"content": "One", "activeForm": "Doing one", "status": "in_progress"},
|
||||||
|
{"content": "Two", "activeForm": "Doing two", "status": "in_progress"}
|
||||||
|
]
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("multiple in-progress todos should fail");
|
||||||
|
assert!(too_many_active.contains("zero or one todo items may be in_progress"));
|
||||||
|
|
||||||
|
let blank_content = execute_tool(
|
||||||
|
"TodoWrite",
|
||||||
|
&json!({
|
||||||
|
"todos": [
|
||||||
|
{"content": " ", "activeForm": "Doing it", "status": "pending"}
|
||||||
|
]
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("blank content should fail");
|
||||||
|
assert!(blank_content.contains("todo content must not be empty"));
|
||||||
|
|
||||||
|
let nudge = execute_tool(
|
||||||
|
"TodoWrite",
|
||||||
|
&json!({
|
||||||
|
"todos": [
|
||||||
|
{"content": "Write tests", "activeForm": "Writing tests", "status": "completed"},
|
||||||
|
{"content": "Fix errors", "activeForm": "Fixing errors", "status": "completed"},
|
||||||
|
{"content": "Ship branch", "activeForm": "Shipping branch", "status": "completed"}
|
||||||
|
]
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect("completed todos should succeed");
|
||||||
|
std::env::remove_var("CLAWD_TODO_STORE");
|
||||||
|
let _ = fs::remove_file(path);
|
||||||
|
|
||||||
|
let output: serde_json::Value = serde_json::from_str(&nudge).expect("valid json");
|
||||||
|
assert_eq!(output["verificationNudgeNeeded"], true);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn skill_loads_local_skill_prompt() {
|
fn skill_loads_local_skill_prompt() {
|
||||||
let result = execute_tool(
|
let result = execute_tool(
|
||||||
@@ -2599,13 +2768,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn agent_persists_handoff_metadata() {
|
fn agent_persists_handoff_metadata() {
|
||||||
let dir = std::env::temp_dir().join(format!(
|
let _guard = env_lock()
|
||||||
"clawd-agent-store-{}",
|
.lock()
|
||||||
std::time::SystemTime::now()
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
let dir = temp_path("agent-store");
|
||||||
.expect("time")
|
|
||||||
.as_nanos()
|
|
||||||
));
|
|
||||||
std::env::set_var("CLAWD_AGENT_STORE", &dir);
|
std::env::set_var("CLAWD_AGENT_STORE", &dir);
|
||||||
|
|
||||||
let result = execute_tool(
|
let result = execute_tool(
|
||||||
@@ -2661,15 +2827,32 @@ mod tests {
|
|||||||
let _ = std::fs::remove_dir_all(dir);
|
let _ = std::fs::remove_dir_all(dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_rejects_blank_required_fields() {
|
||||||
|
let missing_description = execute_tool(
|
||||||
|
"Agent",
|
||||||
|
&json!({
|
||||||
|
"description": " ",
|
||||||
|
"prompt": "Inspect"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("blank description should fail");
|
||||||
|
assert!(missing_description.contains("description must not be empty"));
|
||||||
|
|
||||||
|
let missing_prompt = execute_tool(
|
||||||
|
"Agent",
|
||||||
|
&json!({
|
||||||
|
"description": "Inspect branch",
|
||||||
|
"prompt": " "
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("blank prompt should fail");
|
||||||
|
assert!(missing_prompt.contains("prompt must not be empty"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn notebook_edit_replaces_inserts_and_deletes_cells() {
|
fn notebook_edit_replaces_inserts_and_deletes_cells() {
|
||||||
let path = std::env::temp_dir().join(format!(
|
let path = temp_path("notebook.ipynb");
|
||||||
"clawd-notebook-{}.ipynb",
|
|
||||||
std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.expect("time")
|
|
||||||
.as_nanos()
|
|
||||||
));
|
|
||||||
std::fs::write(
|
std::fs::write(
|
||||||
&path,
|
&path,
|
||||||
r#"{
|
r#"{
|
||||||
@@ -2747,6 +2930,270 @@ mod tests {
|
|||||||
let _ = std::fs::remove_file(path);
|
let _ = std::fs::remove_file(path);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn notebook_edit_rejects_invalid_inputs() {
|
||||||
|
let text_path = temp_path("notebook.txt");
|
||||||
|
fs::write(&text_path, "not a notebook").expect("write text file");
|
||||||
|
let wrong_extension = execute_tool(
|
||||||
|
"NotebookEdit",
|
||||||
|
&json!({
|
||||||
|
"notebook_path": text_path.display().to_string(),
|
||||||
|
"new_source": "print(1)\n"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("non-ipynb file should fail");
|
||||||
|
assert!(wrong_extension.contains("Jupyter notebook"));
|
||||||
|
let _ = fs::remove_file(&text_path);
|
||||||
|
|
||||||
|
let empty_notebook = temp_path("empty.ipynb");
|
||||||
|
fs::write(
|
||||||
|
&empty_notebook,
|
||||||
|
r#"{"cells":[],"metadata":{"kernelspec":{"language":"python"}},"nbformat":4,"nbformat_minor":5}"#,
|
||||||
|
)
|
||||||
|
.expect("write empty notebook");
|
||||||
|
|
||||||
|
let missing_source = execute_tool(
|
||||||
|
"NotebookEdit",
|
||||||
|
&json!({
|
||||||
|
"notebook_path": empty_notebook.display().to_string(),
|
||||||
|
"edit_mode": "insert"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("insert without source should fail");
|
||||||
|
assert!(missing_source.contains("new_source is required"));
|
||||||
|
|
||||||
|
let missing_cell = execute_tool(
|
||||||
|
"NotebookEdit",
|
||||||
|
&json!({
|
||||||
|
"notebook_path": empty_notebook.display().to_string(),
|
||||||
|
"edit_mode": "delete"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect_err("delete on empty notebook should fail");
|
||||||
|
assert!(missing_cell.contains("Notebook has no cells to edit"));
|
||||||
|
let _ = fs::remove_file(empty_notebook);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bash_tool_reports_success_exit_failure_timeout_and_background() {
|
||||||
|
let success = execute_tool("bash", &json!({ "command": "printf 'hello'" }))
|
||||||
|
.expect("bash should succeed");
|
||||||
|
let success_output: serde_json::Value = serde_json::from_str(&success).expect("json");
|
||||||
|
assert_eq!(success_output["stdout"], "hello");
|
||||||
|
assert_eq!(success_output["interrupted"], false);
|
||||||
|
|
||||||
|
let failure = execute_tool("bash", &json!({ "command": "printf 'oops' >&2; exit 7" }))
|
||||||
|
.expect("bash failure should still return structured output");
|
||||||
|
let failure_output: serde_json::Value = serde_json::from_str(&failure).expect("json");
|
||||||
|
assert_eq!(failure_output["returnCodeInterpretation"], "exit_code:7");
|
||||||
|
assert!(failure_output["stderr"]
|
||||||
|
.as_str()
|
||||||
|
.expect("stderr")
|
||||||
|
.contains("oops"));
|
||||||
|
|
||||||
|
let timeout = execute_tool("bash", &json!({ "command": "sleep 1", "timeout": 10 }))
|
||||||
|
.expect("bash timeout should return output");
|
||||||
|
let timeout_output: serde_json::Value = serde_json::from_str(&timeout).expect("json");
|
||||||
|
assert_eq!(timeout_output["interrupted"], true);
|
||||||
|
assert_eq!(timeout_output["returnCodeInterpretation"], "timeout");
|
||||||
|
assert!(timeout_output["stderr"]
|
||||||
|
.as_str()
|
||||||
|
.expect("stderr")
|
||||||
|
.contains("Command exceeded timeout"));
|
||||||
|
|
||||||
|
let background = execute_tool(
|
||||||
|
"bash",
|
||||||
|
&json!({ "command": "sleep 1", "run_in_background": true }),
|
||||||
|
)
|
||||||
|
.expect("bash background should succeed");
|
||||||
|
let background_output: serde_json::Value = serde_json::from_str(&background).expect("json");
|
||||||
|
assert!(background_output["backgroundTaskId"].as_str().is_some());
|
||||||
|
assert_eq!(background_output["noOutputExpected"], true);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn file_tools_cover_read_write_and_edit_behaviors() {
|
||||||
|
let _guard = env_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
let root = temp_path("fs-suite");
|
||||||
|
fs::create_dir_all(&root).expect("create root");
|
||||||
|
let original_dir = std::env::current_dir().expect("cwd");
|
||||||
|
std::env::set_current_dir(&root).expect("set cwd");
|
||||||
|
|
||||||
|
let write_create = execute_tool(
|
||||||
|
"write_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\nalpha\n" }),
|
||||||
|
)
|
||||||
|
.expect("write create should succeed");
|
||||||
|
let write_create_output: serde_json::Value =
|
||||||
|
serde_json::from_str(&write_create).expect("json");
|
||||||
|
assert_eq!(write_create_output["type"], "create");
|
||||||
|
assert!(root.join("nested/demo.txt").exists());
|
||||||
|
|
||||||
|
let write_update = execute_tool(
|
||||||
|
"write_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\ngamma\n" }),
|
||||||
|
)
|
||||||
|
.expect("write update should succeed");
|
||||||
|
let write_update_output: serde_json::Value =
|
||||||
|
serde_json::from_str(&write_update).expect("json");
|
||||||
|
assert_eq!(write_update_output["type"], "update");
|
||||||
|
assert_eq!(write_update_output["originalFile"], "alpha\nbeta\nalpha\n");
|
||||||
|
|
||||||
|
let read_full = execute_tool("read_file", &json!({ "path": "nested/demo.txt" }))
|
||||||
|
.expect("read full should succeed");
|
||||||
|
let read_full_output: serde_json::Value = serde_json::from_str(&read_full).expect("json");
|
||||||
|
assert_eq!(read_full_output["file"]["content"], "alpha\nbeta\ngamma");
|
||||||
|
assert_eq!(read_full_output["file"]["startLine"], 1);
|
||||||
|
|
||||||
|
let read_slice = execute_tool(
|
||||||
|
"read_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "offset": 1, "limit": 1 }),
|
||||||
|
)
|
||||||
|
.expect("read slice should succeed");
|
||||||
|
let read_slice_output: serde_json::Value = serde_json::from_str(&read_slice).expect("json");
|
||||||
|
assert_eq!(read_slice_output["file"]["content"], "beta");
|
||||||
|
assert_eq!(read_slice_output["file"]["startLine"], 2);
|
||||||
|
|
||||||
|
let read_past_end = execute_tool(
|
||||||
|
"read_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "offset": 50 }),
|
||||||
|
)
|
||||||
|
.expect("read past EOF should succeed");
|
||||||
|
let read_past_end_output: serde_json::Value =
|
||||||
|
serde_json::from_str(&read_past_end).expect("json");
|
||||||
|
assert_eq!(read_past_end_output["file"]["content"], "");
|
||||||
|
assert_eq!(read_past_end_output["file"]["startLine"], 4);
|
||||||
|
|
||||||
|
let read_error = execute_tool("read_file", &json!({ "path": "missing.txt" }))
|
||||||
|
.expect_err("missing file should fail");
|
||||||
|
assert!(!read_error.is_empty());
|
||||||
|
|
||||||
|
let edit_once = execute_tool(
|
||||||
|
"edit_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "old_string": "alpha", "new_string": "omega" }),
|
||||||
|
)
|
||||||
|
.expect("single edit should succeed");
|
||||||
|
let edit_once_output: serde_json::Value = serde_json::from_str(&edit_once).expect("json");
|
||||||
|
assert_eq!(edit_once_output["replaceAll"], false);
|
||||||
|
assert_eq!(
|
||||||
|
fs::read_to_string(root.join("nested/demo.txt")).expect("read file"),
|
||||||
|
"omega\nbeta\ngamma\n"
|
||||||
|
);
|
||||||
|
|
||||||
|
execute_tool(
|
||||||
|
"write_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "content": "alpha\nbeta\nalpha\n" }),
|
||||||
|
)
|
||||||
|
.expect("reset file");
|
||||||
|
let edit_all = execute_tool(
|
||||||
|
"edit_file",
|
||||||
|
&json!({
|
||||||
|
"path": "nested/demo.txt",
|
||||||
|
"old_string": "alpha",
|
||||||
|
"new_string": "omega",
|
||||||
|
"replace_all": true
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect("replace all should succeed");
|
||||||
|
let edit_all_output: serde_json::Value = serde_json::from_str(&edit_all).expect("json");
|
||||||
|
assert_eq!(edit_all_output["replaceAll"], true);
|
||||||
|
assert_eq!(
|
||||||
|
fs::read_to_string(root.join("nested/demo.txt")).expect("read file"),
|
||||||
|
"omega\nbeta\nomega\n"
|
||||||
|
);
|
||||||
|
|
||||||
|
let edit_same = execute_tool(
|
||||||
|
"edit_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "old_string": "omega", "new_string": "omega" }),
|
||||||
|
)
|
||||||
|
.expect_err("identical old/new should fail");
|
||||||
|
assert!(edit_same.contains("must differ"));
|
||||||
|
|
||||||
|
let edit_missing = execute_tool(
|
||||||
|
"edit_file",
|
||||||
|
&json!({ "path": "nested/demo.txt", "old_string": "missing", "new_string": "omega" }),
|
||||||
|
)
|
||||||
|
.expect_err("missing substring should fail");
|
||||||
|
assert!(edit_missing.contains("old_string not found"));
|
||||||
|
|
||||||
|
std::env::set_current_dir(&original_dir).expect("restore cwd");
|
||||||
|
let _ = fs::remove_dir_all(root);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn glob_and_grep_tools_cover_success_and_errors() {
|
||||||
|
let _guard = env_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
let root = temp_path("search-suite");
|
||||||
|
fs::create_dir_all(root.join("nested")).expect("create root");
|
||||||
|
let original_dir = std::env::current_dir().expect("cwd");
|
||||||
|
std::env::set_current_dir(&root).expect("set cwd");
|
||||||
|
|
||||||
|
fs::write(
|
||||||
|
root.join("nested/lib.rs"),
|
||||||
|
"fn main() {}\nlet alpha = 1;\nlet alpha = 2;\n",
|
||||||
|
)
|
||||||
|
.expect("write rust file");
|
||||||
|
fs::write(root.join("nested/notes.txt"), "alpha\nbeta\n").expect("write txt file");
|
||||||
|
|
||||||
|
let globbed = execute_tool("glob_search", &json!({ "pattern": "nested/*.rs" }))
|
||||||
|
.expect("glob should succeed");
|
||||||
|
let globbed_output: serde_json::Value = serde_json::from_str(&globbed).expect("json");
|
||||||
|
assert_eq!(globbed_output["numFiles"], 1);
|
||||||
|
assert!(globbed_output["filenames"][0]
|
||||||
|
.as_str()
|
||||||
|
.expect("filename")
|
||||||
|
.ends_with("nested/lib.rs"));
|
||||||
|
|
||||||
|
let glob_error = execute_tool("glob_search", &json!({ "pattern": "[" }))
|
||||||
|
.expect_err("invalid glob should fail");
|
||||||
|
assert!(!glob_error.is_empty());
|
||||||
|
|
||||||
|
let grep_content = execute_tool(
|
||||||
|
"grep_search",
|
||||||
|
&json!({
|
||||||
|
"pattern": "alpha",
|
||||||
|
"path": "nested",
|
||||||
|
"glob": "*.rs",
|
||||||
|
"output_mode": "content",
|
||||||
|
"-n": true,
|
||||||
|
"head_limit": 1,
|
||||||
|
"offset": 1
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.expect("grep content should succeed");
|
||||||
|
let grep_content_output: serde_json::Value =
|
||||||
|
serde_json::from_str(&grep_content).expect("json");
|
||||||
|
assert_eq!(grep_content_output["numFiles"], 0);
|
||||||
|
assert!(grep_content_output["appliedLimit"].is_null());
|
||||||
|
assert_eq!(grep_content_output["appliedOffset"], 1);
|
||||||
|
assert!(grep_content_output["content"]
|
||||||
|
.as_str()
|
||||||
|
.expect("content")
|
||||||
|
.contains("let alpha = 2;"));
|
||||||
|
|
||||||
|
let grep_count = execute_tool(
|
||||||
|
"grep_search",
|
||||||
|
&json!({ "pattern": "alpha", "path": "nested", "output_mode": "count" }),
|
||||||
|
)
|
||||||
|
.expect("grep count should succeed");
|
||||||
|
let grep_count_output: serde_json::Value = serde_json::from_str(&grep_count).expect("json");
|
||||||
|
assert_eq!(grep_count_output["numMatches"], 3);
|
||||||
|
|
||||||
|
let grep_error = execute_tool(
|
||||||
|
"grep_search",
|
||||||
|
&json!({ "pattern": "(alpha", "path": "nested" }),
|
||||||
|
)
|
||||||
|
.expect_err("invalid regex should fail");
|
||||||
|
assert!(!grep_error.is_empty());
|
||||||
|
|
||||||
|
std::env::set_current_dir(&original_dir).expect("restore cwd");
|
||||||
|
let _ = fs::remove_dir_all(root);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sleep_waits_and_reports_duration() {
|
fn sleep_waits_and_reports_duration() {
|
||||||
let started = std::time::Instant::now();
|
let started = std::time::Instant::now();
|
||||||
@@ -3038,6 +3485,15 @@ printf 'pwsh:%s' "$1"
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn text(status: u16, reason: &'static str, body: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
status,
|
||||||
|
reason,
|
||||||
|
content_type: "text/plain; charset=utf-8",
|
||||||
|
body: body.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn to_bytes(&self) -> Vec<u8> {
|
fn to_bytes(&self) -> Vec<u8> {
|
||||||
format!(
|
format!(
|
||||||
"HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
"HTTP/1.1 {} {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||||
|
|||||||
Reference in New Issue
Block a user