mirror of
https://github.com/instructkr/claw-code.git
synced 2026-06-07 18:25:22 -04:00
Compare commits
3 Commits
main
...
fix/openai
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1da1ca8e6 | ||
|
|
27acfe1014 | ||
|
|
c1646613d1 |
@@ -296,9 +296,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
|
||||
None
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
#[must_use]
|
||||
pub fn strip_provider_prefix(canonical_model: &str) -> String {
|
||||
if let Some(pos) = canonical_model.find('/') {
|
||||
@@ -308,8 +305,6 @@ pub fn strip_provider_prefix(canonical_model: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[must_use]
|
||||
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
|
||||
let resolved_model = resolve_model_alias(model);
|
||||
|
||||
@@ -16,8 +16,7 @@ use crate::types::{
|
||||
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
||||
};
|
||||
|
||||
use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix};
|
||||
|
||||
use super::{preflight_message_request, resolve_model_alias, Provider, ProviderFuture};
|
||||
|
||||
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
|
||||
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
|
||||
@@ -213,80 +212,22 @@ impl OpenAiCompatClient {
|
||||
}
|
||||
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageResponse, ApiError> {
|
||||
// 1. Keep track of what Claw originally asked for
|
||||
let original_model = request.model.clone();
|
||||
let canonical = resolve_model_alias(&request.model);
|
||||
|
||||
// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
|
||||
let downstream_model = strip_provider_prefix(&canonical);
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageResponse, ApiError> {
|
||||
let original_model = request.model.clone();
|
||||
let canonical = resolve_model_alias(&request.model);
|
||||
|
||||
let mut request = MessageRequest {
|
||||
stream: false,
|
||||
..request.clone()
|
||||
};
|
||||
request.model = downstream_model; // Use the clean name for the API payload
|
||||
|
||||
preflight_message_request(&request)?;
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
let body = response.text().await.map_err(ApiError::from)?;
|
||||
let mut request = MessageRequest {
|
||||
stream: false,
|
||||
..request.clone()
|
||||
};
|
||||
request.model = canonical;
|
||||
|
||||
// Some backends return {"error":{"message":"...","type":"...","code":...}}
|
||||
// instead of a valid completion object. Check for this before attempting
|
||||
// full deserialization so the user sees the actual error, not a cryptic.
|
||||
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
|
||||
if let Some(err_obj) = raw.get("error") {
|
||||
let msg = err_obj
|
||||
.get("message")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("provider returned an error")
|
||||
.to_string();
|
||||
let code = err_obj
|
||||
.get("code")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|c| c as u16);
|
||||
return Err(ApiError::Api {
|
||||
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
|
||||
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
|
||||
error_type: err_obj
|
||||
.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(str::to_owned),
|
||||
message: Some(msg),
|
||||
request_id,
|
||||
body,
|
||||
retryable: false,
|
||||
suggested_action: suggested_action_for_status(
|
||||
reqwest::StatusCode::from_u16(code.unwrap_or(400))
|
||||
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
|
||||
),
|
||||
retry_after: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Pass original_model to the deserializer error context so debugging logs are accurate
|
||||
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
|
||||
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
|
||||
})?;
|
||||
|
||||
let mut normalized = normalize_response(&request.model, payload)?;
|
||||
if normalized.request_id.is_none() {
|
||||
normalized.request_id = request_id;
|
||||
}
|
||||
|
||||
// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
|
||||
normalized.model = original_model;
|
||||
|
||||
Ok(normalized)
|
||||
}
|
||||
// Some backends return {"error":{"message":"...","type":"...","code":...}}
|
||||
// instead of a valid completion object. Check for this before attempting
|
||||
// full deserialization so the user sees the actual error, not a cryptic
|
||||
// "missing field 'id'" parse failure.
|
||||
preflight_message_request(&request)?;
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
let body = response.text().await.map_err(ApiError::from)?;
|
||||
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
|
||||
if let Some(err_obj) = raw.get("error") {
|
||||
let msg = err_obj
|
||||
@@ -318,41 +259,41 @@ impl OpenAiCompatClient {
|
||||
}
|
||||
}
|
||||
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
|
||||
ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error)
|
||||
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
|
||||
})?;
|
||||
let mut normalized = normalize_response(&request.model, payload)?;
|
||||
if normalized.request_id.is_none() {
|
||||
normalized.request_id = request_id;
|
||||
}
|
||||
normalized.model = original_model;
|
||||
Ok(normalized)
|
||||
}
|
||||
|
||||
pub async fn stream_message(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
// 1. Keep track of the original model name
|
||||
let original_model = request.model.clone();
|
||||
let canonical = resolve_model_alias(&request.model);
|
||||
|
||||
// 2. Clean it up for DeepSeek
|
||||
let downstream_model = strip_provider_prefix(&canonical);
|
||||
pub async fn stream_message(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
let original_model = request.model.clone();
|
||||
let canonical = resolve_model_alias(&request.model);
|
||||
|
||||
let mut streaming_request = request.clone().with_streaming();
|
||||
streaming_request.model = downstream_model;
|
||||
let mut streaming_request = request.clone().with_streaming();
|
||||
streaming_request.model = canonical;
|
||||
|
||||
preflight_message_request(&streaming_request)?;
|
||||
let response = self.send_with_retry(&streaming_request).await?;
|
||||
preflight_message_request(&streaming_request)?;
|
||||
let response = self.send_with_retry(&streaming_request).await?;
|
||||
|
||||
Ok(MessageStream {
|
||||
request_id: request_id_from_headers(response.headers()),
|
||||
response,
|
||||
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()),
|
||||
pending: VecDeque::new(),
|
||||
done: false,
|
||||
state: StreamState::new(original_model), // 3. Use the original name here
|
||||
})
|
||||
}
|
||||
Ok(MessageStream {
|
||||
request_id: request_id_from_headers(response.headers()),
|
||||
response,
|
||||
parser: OpenAiSseParser::with_context(
|
||||
self.config.provider_name,
|
||||
original_model.clone(),
|
||||
),
|
||||
pending: VecDeque::new(),
|
||||
done: false,
|
||||
state: StreamState::new(original_model),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_with_retry(
|
||||
&self,
|
||||
|
||||
@@ -548,12 +548,13 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
|
||||
.with_base_url("http://origin.invalid/v1");
|
||||
let response = client
|
||||
.send_message(&MessageRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
model: "openai/gpt-4.1-mini".to_string(),
|
||||
..sample_request(false)
|
||||
})
|
||||
.await
|
||||
.expect("proxy should return the OpenAI-compatible response");
|
||||
|
||||
assert_eq!(response.model, "openai/gpt-4.1-mini");
|
||||
assert_eq!(response.total_tokens(), 7);
|
||||
let captured = state.lock().await;
|
||||
let request = captured.first().expect("proxy should capture request");
|
||||
@@ -562,6 +563,8 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
|
||||
request.headers.get("authorization").map(String::as_str),
|
||||
Some("Bearer openai-test-key")
|
||||
);
|
||||
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
|
||||
assert_eq!(body["model"], json!("openai/gpt-4.1-mini"));
|
||||
}
|
||||
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
|
||||
@@ -832,6 +832,28 @@ mod tests {
|
||||
|
||||
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
struct EnvVarGuard {
|
||||
key: &'static str,
|
||||
previous: Option<std::ffi::OsString>,
|
||||
}
|
||||
|
||||
impl EnvVarGuard {
|
||||
fn set(key: &'static str, value: &Path) -> Self {
|
||||
let previous = std::env::var_os(key);
|
||||
std::env::set_var(key, value);
|
||||
Self { key, previous }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvVarGuard {
|
||||
fn drop(&mut self) {
|
||||
match &self.previous {
|
||||
Some(value) => std::env::set_var(self.key, value),
|
||||
None => std::env::remove_var(self.key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn temp_dir() -> PathBuf {
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
@@ -1290,8 +1312,11 @@ mod tests {
|
||||
#[test]
|
||||
fn latest_session_returns_all_empty_error_when_sessions_exist_but_have_no_messages() {
|
||||
// given — create sessions with 0 messages (empty)
|
||||
let _env_guard = crate::test_env_lock();
|
||||
let base = temp_dir();
|
||||
fs::create_dir_all(&base).expect("base dir should exist");
|
||||
let isolated_config_home = base.join("config-home");
|
||||
let _claw_config_home = EnvVarGuard::set("CLAW_CONFIG_HOME", &isolated_config_home);
|
||||
let store = SessionStore::from_cwd(&base).expect("store should build");
|
||||
|
||||
let empty_handle = store.create_handle("empty-session");
|
||||
|
||||
@@ -1644,16 +1644,13 @@ mod tests {
|
||||
|
||||
let tmp = tempfile::tempdir().expect("tempdir");
|
||||
let worktree = tmp.path().join("worktree");
|
||||
let git_dir = tmp.path().join("external-gitdir");
|
||||
fs::create_dir_all(&worktree).expect("worktree dir");
|
||||
fs::create_dir_all(git_dir.join("objects")).expect("objects dir");
|
||||
fs::create_dir_all(git_dir.join("refs/heads")).expect("refs dir");
|
||||
fs::write(git_dir.join("HEAD"), "ref: refs/heads/main\n").expect("HEAD");
|
||||
fs::write(
|
||||
worktree.join(".git"),
|
||||
format!("gitdir: {}\n", git_dir.display()),
|
||||
)
|
||||
.expect(".git file");
|
||||
Command::new("git")
|
||||
.arg("init")
|
||||
.current_dir(&worktree)
|
||||
.output()
|
||||
.expect("git init should run");
|
||||
let git_dir = worktree.join(".git");
|
||||
|
||||
let original_permissions = fs::metadata(&git_dir)
|
||||
.expect("gitdir metadata")
|
||||
|
||||
@@ -13737,8 +13737,15 @@ fn push_output_block(
|
||||
};
|
||||
*pending_tool = Some((id, name, initial_input));
|
||||
}
|
||||
OutputContentBlock::Thinking { thinking, .. } => {
|
||||
OutputContentBlock::Thinking {
|
||||
thinking,
|
||||
signature,
|
||||
} => {
|
||||
render_thinking_block_summary(out, Some(thinking.chars().count()), false)?;
|
||||
events.push(AssistantEvent::Thinking {
|
||||
thinking,
|
||||
signature,
|
||||
});
|
||||
*block_has_thinking_summary = true;
|
||||
}
|
||||
OutputContentBlock::RedactedThinking { .. } => {
|
||||
@@ -19073,6 +19080,13 @@ UU conflicted.rs",
|
||||
|
||||
assert!(matches!(
|
||||
&events[0],
|
||||
AssistantEvent::Thinking {
|
||||
thinking,
|
||||
signature
|
||||
} if thinking == "step 1" && signature.as_deref() == Some("sig_123")
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[1],
|
||||
AssistantEvent::TextDelta(text) if text == "Final answer"
|
||||
));
|
||||
let rendered = String::from_utf8(out).expect("utf8");
|
||||
@@ -19649,6 +19663,41 @@ mod dump_manifests_tests {
|
||||
|
||||
#[cfg(test)]
|
||||
mod alias_resolution_tests {
|
||||
fn ollama_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
|
||||
LOCK.get_or_init(|| std::sync::Mutex::new(()))
|
||||
.lock()
|
||||
.expect("ollama env lock poisoned")
|
||||
}
|
||||
|
||||
struct EnvVarGuard {
|
||||
key: &'static str,
|
||||
previous: Option<String>,
|
||||
}
|
||||
|
||||
impl EnvVarGuard {
|
||||
fn unset(key: &'static str) -> Self {
|
||||
let previous = std::env::var(key).ok();
|
||||
std::env::remove_var(key);
|
||||
Self { key, previous }
|
||||
}
|
||||
|
||||
fn set(key: &'static str, value: &str) -> Self {
|
||||
let previous = std::env::var(key).ok();
|
||||
std::env::set_var(key, value);
|
||||
Self { key, previous }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvVarGuard {
|
||||
fn drop(&mut self) {
|
||||
match &self.previous {
|
||||
Some(value) => std::env::set_var(self.key, value),
|
||||
None => std::env::remove_var(self.key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use super::{resolve_model_alias_with_config, validate_model_syntax};
|
||||
|
||||
#[test]
|
||||
@@ -19670,6 +19719,8 @@ mod alias_resolution_tests {
|
||||
|
||||
#[test]
|
||||
fn test_alias_resolution_syntax_validation() {
|
||||
let _guard = ollama_env_lock();
|
||||
let _env = EnvVarGuard::unset("OLLAMA_HOST");
|
||||
// Resolved aliases should pass syntax validation
|
||||
let resolved = resolve_model_alias_with_config("opus");
|
||||
assert!(validate_model_syntax(&resolved).is_ok());
|
||||
@@ -19680,6 +19731,8 @@ mod alias_resolution_tests {
|
||||
|
||||
#[test]
|
||||
fn test_unknown_alias_fails_validation() {
|
||||
let _guard = ollama_env_lock();
|
||||
let _env = EnvVarGuard::unset("OLLAMA_HOST");
|
||||
// Unknown aliases resolve to themselves
|
||||
let resolved = resolve_model_alias_with_config("unknown-alias");
|
||||
assert_eq!(resolved, "unknown-alias");
|
||||
@@ -19699,14 +19752,13 @@ mod alias_resolution_tests {
|
||||
}
|
||||
#[test]
|
||||
fn test_ollama_host_bypasses_model_validation() {
|
||||
// Safety: test sets and clears env var within the test.
|
||||
std::env::set_var("OLLAMA_HOST", "http://127.0.0.1:11434");
|
||||
let _guard = ollama_env_lock();
|
||||
let _env = EnvVarGuard::set("OLLAMA_HOST", "http://127.0.0.1:11434");
|
||||
// Ollama model names with colons pass
|
||||
assert!(validate_model_syntax("qwen3:8b").is_ok());
|
||||
assert!(validate_model_syntax("gemma4:e2b").is_ok());
|
||||
assert!(validate_model_syntax("qwen3.6:27b-nvfp4").is_ok());
|
||||
// Empty model still rejected
|
||||
assert!(validate_model_syntax("").is_err());
|
||||
std::env::remove_var("OLLAMA_HOST");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user