Files
claw-code/rust/crates/api/src/providers/openai_compat.rs
Jobdori adcea6bceb fix(api): route DashScope models to dashscope config, not openai
ProviderClient::from_model_with_anthropic_auth was dispatching every
ProviderKind::OpenAi match to OpenAiCompatConfig::openai(), which reads
OPENAI_API_KEY and points at api.openai.com. But DashScope models
(qwen-plus, qwen/qwen3-coder, etc.) also return ProviderKind::OpenAi
from detect_provider_kind because DashScope speaks the OpenAI wire
format. The metadata layer correctly identifies them as needing
DASHSCOPE_API_KEY and the DashScope compatible-mode endpoint, but that
metadata was being ignored at dispatch time.

Result: users running `claw --model qwen-plus` with DASHSCOPE_API_KEY
set would get a "missing OPENAI_API_KEY" error instead of being routed
to DashScope.

Fix: consult providers::metadata_for_model in the OpenAi dispatch arm
and pick dashscope() vs openai() based on metadata.auth_env.

Adds a regression test asserting ProviderClient::from_model("qwen-plus")
builds with the DashScope base URL. Exposes a pub base_url() accessor
on OpenAiCompatClient so the test can verify the routing.

Authored by droid (Kimi K2.5 Turbo) via acpx, cleaned up by Jobdori
(removed unsafe blocks unnecessary under edition 2021, imported
ProviderClient from super, adopted EnvVarGuard pattern from
providers/mod.rs tests).

Co-Authored-By: Droid <noreply@factory.ai>
2026-04-08 18:04:37 +09:00

1327 lines
42 KiB
Rust

use std::collections::{BTreeMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::Deserialize;
use serde_json::{json, Value};
use crate::error::ApiError;
use crate::http_client::build_http_client_or_default;
use crate::types::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
};
use super::{preflight_message_request, Provider, ProviderFuture};
/// Base URL used for xAI when `XAI_BASE_URL` is not set.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Base URL used for OpenAI when `OPENAI_BASE_URL` is not set.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Base URL used for DashScope (compatible mode) when `DASHSCOPE_BASE_URL` is not set.
pub const DEFAULT_DASHSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
/// Response header consulted first when extracting a request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Response header consulted when `request-id` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
/// Default base delay for the first retry; later retries double it.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Default cap applied to the exponential backoff (before jitter is added).
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(128);
/// Default number of retries performed after the initial attempt.
const DEFAULT_MAX_RETRIES: u32 = 8;
/// Static description of one OpenAI-compatible provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
/// Human-readable provider name; used in error reporting and for
/// provider-specific behaviour switches in this module.
pub provider_name: &'static str,
/// Name of the environment variable that holds the API key.
pub api_key_env: &'static str,
/// Name of the environment variable that can override the base URL.
pub base_url_env: &'static str,
/// Base URL used when `base_url_env` is not set.
pub default_base_url: &'static str,
}
// Credential environment variables reported per provider by
// `OpenAiCompatConfig::credential_env_vars`.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
const DASHSCOPE_ENV_VARS: &[&str] = &["DASHSCOPE_API_KEY"];
impl OpenAiCompatConfig {
    /// Shared constructor behind the provider presets below.
    const fn preset(
        provider_name: &'static str,
        api_key_env: &'static str,
        base_url_env: &'static str,
        default_base_url: &'static str,
    ) -> Self {
        Self {
            provider_name,
            api_key_env,
            base_url_env,
            default_base_url,
        }
    }

    /// Preset for the xAI API.
    #[must_use]
    pub const fn xai() -> Self {
        Self::preset("xAI", "XAI_API_KEY", "XAI_BASE_URL", DEFAULT_XAI_BASE_URL)
    }

    /// Preset for the OpenAI API.
    #[must_use]
    pub const fn openai() -> Self {
        Self::preset(
            "OpenAI",
            "OPENAI_API_KEY",
            "OPENAI_BASE_URL",
            DEFAULT_OPENAI_BASE_URL,
        )
    }

    /// Alibaba DashScope compatible-mode endpoint (Qwen family models).
    /// Uses the OpenAI-compatible REST shape at /compatible-mode/v1.
    /// Requested via Discord #clawcode-get-help: native Alibaba API for
    /// higher rate limits than going through OpenRouter.
    #[must_use]
    pub const fn dashscope() -> Self {
        Self::preset(
            "DashScope",
            "DASHSCOPE_API_KEY",
            "DASHSCOPE_BASE_URL",
            DEFAULT_DASHSCOPE_BASE_URL,
        )
    }

    /// Lists the credential environment variables for this provider.
    ///
    /// The lookup is keyed on `provider_name`; an unrecognised name yields an
    /// empty slice.
    #[must_use]
    pub fn credential_env_vars(self) -> &'static [&'static str] {
        let provider = self.provider_name;
        if provider == "xAI" {
            XAI_ENV_VARS
        } else if provider == "OpenAI" {
            OPENAI_ENV_VARS
        } else if provider == "DashScope" {
            DASHSCOPE_ENV_VARS
        } else {
            &[]
        }
    }
}
/// HTTP client for providers that speak the OpenAI chat-completions wire format.
#[derive(Debug, Clone)]
pub struct OpenAiCompatClient {
// Underlying HTTP client used for every request.
http: reqwest::Client,
// Sent as a bearer token on each request.
api_key: String,
// Provider preset this client was built from.
config: OpenAiCompatConfig,
// Base URL the chat-completions endpoint is derived from.
base_url: String,
// Retry policy: retry count plus exponential backoff bounds.
max_retries: u32,
initial_backoff: Duration,
max_backoff: Duration,
}
impl OpenAiCompatClient {
    /// Returns a copy of the provider configuration this client was built with.
    const fn config(&self) -> OpenAiCompatConfig {
        self.config
    }

    /// Base URL that the chat-completions endpoint is derived from.
    #[must_use]
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Builds a client from an explicit API key and provider preset.
    ///
    /// The base URL is resolved through `read_base_url(config)` and the retry
    /// policy starts at the module defaults.
    #[must_use]
    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
        Self {
            http: build_http_client_or_default(),
            api_key: api_key.into(),
            config,
            base_url: read_base_url(config),
            max_retries: DEFAULT_MAX_RETRIES,
            initial_backoff: DEFAULT_INITIAL_BACKOFF,
            max_backoff: DEFAULT_MAX_BACKOFF,
        }
    }

    /// Builds a client whose API key is read from `config.api_key_env`.
    ///
    /// # Errors
    ///
    /// Returns a missing-credentials error naming the provider when no
    /// non-empty key is found, or the error produced while reading the
    /// environment variable.
    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
            return Err(ApiError::missing_credentials(
                config.provider_name,
                config.credential_env_vars(),
            ));
        };
        Ok(Self::new(api_key, config))
    }

    /// Replaces the base URL resolved at construction time.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Replaces the retry count and the exponential backoff bounds.
    #[must_use]
    pub fn with_retry_policy(
        mut self,
        max_retries: u32,
        initial_backoff: Duration,
        max_backoff: Duration,
    ) -> Self {
        self.max_retries = max_retries;
        self.initial_backoff = initial_backoff;
        self.max_backoff = max_backoff;
        self
    }

    /// Sends a non-streaming request and normalizes the provider response.
    ///
    /// The request is always sent with `stream: false`, regardless of the
    /// caller's setting. When the normalized response carries no request id,
    /// the id from the response headers is used instead.
    ///
    /// # Errors
    ///
    /// Returns preflight, transport, API, or JSON deserialization errors.
    pub async fn send_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageResponse, ApiError> {
        let request = MessageRequest {
            stream: false,
            ..request.clone()
        };
        preflight_message_request(&request)?;
        let response = self.send_with_retry(&request).await?;
        let request_id = request_id_from_headers(response.headers());
        let body = response.text().await.map_err(ApiError::from)?;
        let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
            ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error)
        })?;
        let mut normalized = normalize_response(&request.model, payload)?;
        if normalized.request_id.is_none() {
            normalized.request_id = request_id;
        }
        Ok(normalized)
    }

    /// Sends a streaming request and returns a stream of translated events.
    ///
    /// # Errors
    ///
    /// Returns preflight, transport, or API errors raised before the first
    /// byte of the stream body is consumed.
    pub async fn stream_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageStream, ApiError> {
        preflight_message_request(request)?;
        let response = self
            .send_with_retry(&request.clone().with_streaming())
            .await?;
        Ok(MessageStream {
            request_id: request_id_from_headers(response.headers()),
            response,
            parser: OpenAiSseParser::with_context(self.config.provider_name, request.model.clone()),
            pending: VecDeque::new(),
            done: false,
            state: StreamState::new(request.model.clone()),
        })
    }

    /// Sends the request, retrying retryable failures with jittered backoff.
    ///
    /// Performs at most `max_retries + 1` attempts. Non-retryable errors are
    /// returned immediately; once the retry budget is spent the last retryable
    /// error is wrapped in `ApiError::RetriesExhausted`.
    async fn send_with_retry(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let mut attempts: u32 = 0;
        let last_error = loop {
            attempts += 1;
            let outcome = match self.send_raw_request(request).await {
                Ok(response) => expect_success(response).await,
                Err(error) => Err(error),
            };
            // The retry budget is enforced solely by the check below. No
            // `max_retries + 1` arithmetic is done here, so a policy of
            // `u32::MAX` retries cannot overflow.
            let retryable_error = match outcome {
                Ok(response) => return Ok(response),
                Err(error) if error.is_retryable() => error,
                Err(error) => return Err(error),
            };
            if attempts > self.max_retries {
                break retryable_error;
            }
            tokio::time::sleep(self.jittered_backoff_for_attempt(attempts)?).await;
        };
        Err(ApiError::RetriesExhausted {
            attempts,
            last_error: Box::new(last_error),
        })
    }

    /// Issues a single POST to the chat-completions endpoint without retrying.
    async fn send_raw_request(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let request_url = chat_completions_endpoint(&self.base_url);
        self.http
            .post(&request_url)
            .header("content-type", "application/json")
            .bearer_auth(&self.api_key)
            .json(&build_chat_completion_request(request, self.config()))
            .send()
            .await
            .map_err(ApiError::from)
    }

    /// Computes `initial_backoff * 2^(attempt - 1)`, capped at `max_backoff`.
    ///
    /// # Errors
    ///
    /// Returns `ApiError::BackoffOverflow` when the power-of-two multiplier no
    /// longer fits in a `u32`.
    fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
            return Err(ApiError::BackoffOverflow {
                attempt,
                base_delay: self.initial_backoff,
            });
        };
        Ok(self
            .initial_backoff
            .checked_mul(multiplier)
            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
    }

    /// Backoff for `attempt` plus an additive jitter in `[0, backoff]`.
    fn jittered_backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
        let base = self.backoff_for_attempt(attempt)?;
        // Saturate rather than panic when a very large `max_backoff` would
        // push the sum past `Duration::MAX`.
        Ok(base.saturating_add(jitter_for_base(base)))
    }
}
/// Process-wide sequence number mixed into every jitter sample, so that two
/// samples taken within the same clock tick still differ.
static JITTER_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Produces an additive retry jitter in the inclusive range `[0, base]`.
///
/// The sample is seeded from the wall clock (nanoseconds since the Unix
/// epoch) plus the process-wide counter and then scrambled with the
/// splitmix64 finalizer. This is not cryptographically secure, which retry
/// jitter does not need. A zero `base` returns zero without touching the
/// counter.
fn jitter_for_base(base: Duration) -> Duration {
    let ceiling = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX);
    if ceiling == 0 {
        return Duration::ZERO;
    }

    // A clock set before the epoch contributes nothing; the counter still
    // varies the seed.
    let clock = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX),
        Err(_) => 0,
    };
    let sequence = JITTER_COUNTER.fetch_add(1, Ordering::Relaxed);

    let seed = clock
        .wrapping_add(sequence)
        .wrapping_add(0x9E37_79B9_7F4A_7C15);
    let first = (seed ^ (seed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    let second = (first ^ (first >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    let scrambled = second ^ (second >> 31);

    Duration::from_nanos(scrambled % ceiling.saturating_add(1))
}
// `Provider` implementation that boxes the client's own async methods.
impl Provider for OpenAiCompatClient {
type Stream = MessageStream;
// Boxes the future returned by the `send_message` method defined on
// `OpenAiCompatClient` itself.
fn send_message<'a>(
&'a self,
request: &'a MessageRequest,
) -> ProviderFuture<'a, MessageResponse> {
Box::pin(async move { self.send_message(request).await })
}
// Boxes the future returned by the `stream_message` method defined on
// `OpenAiCompatClient` itself.
fn stream_message<'a>(
&'a self,
request: &'a MessageRequest,
) -> ProviderFuture<'a, Self::Stream> {
Box::pin(async move { self.stream_message(request).await })
}
}
/// Streaming response that yields `StreamEvent`s translated from
/// OpenAI-style SSE chunks.
#[derive(Debug)]
pub struct MessageStream {
// Request id taken from the response headers, if any.
request_id: Option<String>,
// HTTP response whose body is read chunk by chunk.
response: reqwest::Response,
// Splits the byte stream into SSE frames and decodes them.
parser: OpenAiSseParser,
// Events already produced but not yet handed to the caller.
pending: VecDeque<StreamEvent>,
// Set once the response body has been fully read.
done: bool,
// Tracks which message/content-block events have been emitted.
state: StreamState,
}
impl MessageStream {
    /// Request id captured from the response headers, if one was present.
    #[must_use]
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }

    /// Yields the next translated event, or `None` once the stream is drained.
    ///
    /// Queued events are returned first. After the body ends, the closing
    /// events produced by the stream state are queued and drained the same way.
    ///
    /// # Errors
    ///
    /// Returns transport errors from reading the body and any error raised
    /// while parsing or translating a chunk.
    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
        loop {
            if let Some(queued) = self.pending.pop_front() {
                return Ok(Some(queued));
            }

            if self.done {
                // `finish` returns nothing after its first call, so this
                // settles on `None`.
                self.pending.extend(self.state.finish()?);
                return Ok(self.pending.pop_front());
            }

            let Some(bytes) = self.response.chunk().await? else {
                self.done = true;
                continue;
            };
            for parsed in self.parser.push(&bytes)? {
                self.pending.extend(self.state.ingest_chunk(parsed)?);
            }
        }
    }
}
/// Incremental parser that turns raw response bytes into chat-completion chunks.
#[derive(Debug, Default)]
struct OpenAiSseParser {
// Bytes received so far that have not yet been split off as a frame.
buffer: Vec<u8>,
// Provider name and model, passed along when reporting JSON decode errors.
provider: String,
model: String,
}
impl OpenAiSseParser {
    /// Creates an empty parser that remembers the provider and model for
    /// error reporting.
    fn with_context(provider: impl Into<String>, model: impl Into<String>) -> Self {
        let (provider, model) = (provider.into(), model.into());
        Self {
            buffer: Vec::new(),
            provider,
            model,
        }
    }

    /// Appends `chunk` to the buffer and decodes every complete frame in it.
    ///
    /// Frames that carry no chunk (comments, `[DONE]`, blank frames) are
    /// skipped. Bytes after the last complete frame stay buffered for the
    /// next call.
    fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
        self.buffer.extend_from_slice(chunk);
        let mut decoded = Vec::new();
        loop {
            let Some(frame) = next_sse_frame(&mut self.buffer) else {
                break;
            };
            decoded.extend(parse_sse_frame(&frame, &self.provider, &self.model)?);
        }
        Ok(decoded)
    }
}
/// Bookkeeping used to translate a sequence of OpenAI-style chunks into
/// message/content-block stream events.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug)]
struct StreamState {
// Requested model; used when a chunk does not name one.
model: String,
// Whether the `MessageStart` event has been emitted.
message_started: bool,
// Whether the text block (index 0) has been opened / closed.
text_started: bool,
text_finished: bool,
// Whether `finish` has already run.
finished: bool,
// Normalized finish reason from the most recent choice that carried one.
stop_reason: Option<String>,
// Usage from the most recent chunk that carried a usage object.
usage: Option<Usage>,
// Per-tool-call accumulation, keyed by the provider's tool-call index.
tool_calls: BTreeMap<u32, ToolCallState>,
}
impl StreamState {
/// Creates the initial state for a stream requested for `model`.
fn new(model: String) -> Self {
Self {
model,
message_started: false,
text_started: false,
text_finished: false,
finished: false,
stop_reason: None,
usage: None,
tool_calls: BTreeMap::new(),
}
}
/// Translates one decoded chunk into zero or more stream events.
///
/// The first chunk also produces `MessageStart`. Text deltas go to content
/// block 0; each tool call goes to the block given by
/// `ToolCallState::block_index`.
fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
let mut events = Vec::new();
if !self.message_started {
self.message_started = true;
events.push(StreamEvent::MessageStart(MessageStartEvent {
message: MessageResponse {
id: chunk.id.clone(),
kind: "message".to_string(),
role: "assistant".to_string(),
content: Vec::new(),
// Prefer the model reported by the provider, else the requested one.
model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
stop_reason: None,
stop_sequence: None,
usage: Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
output_tokens: 0,
},
request_id: None,
},
}));
}
// Remember the latest usage; it is reported in `finish`.
if let Some(usage) = chunk.usage {
self.usage = Some(Usage {
input_tokens: usage.prompt_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
output_tokens: usage.completion_tokens,
});
}
for choice in chunk.choices {
// Empty content strings are ignored so they never open a text block.
if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
if !self.text_started {
self.text_started = true;
events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
index: 0,
content_block: OutputContentBlock::Text {
text: String::new(),
},
}));
}
events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
index: 0,
delta: ContentBlockDelta::TextDelta { text: content },
}));
}
for tool_call in choice.delta.tool_calls {
let state = self.tool_calls.entry(tool_call.index).or_default();
state.apply(tool_call);
let block_index = state.block_index();
if !state.started {
if let Some(start_event) = state.start_event()? {
state.started = true;
events.push(StreamEvent::ContentBlockStart(start_event));
} else {
// No function name yet: keep accumulating; the block is opened
// later, once a name has arrived.
continue;
}
}
if let Some(delta_event) = state.delta_event() {
events.push(StreamEvent::ContentBlockDelta(delta_event));
}
if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
state.stopped = true;
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
index: block_index,
}));
}
}
if let Some(finish_reason) = choice.finish_reason {
self.stop_reason = Some(normalize_finish_reason(&finish_reason));
if finish_reason == "tool_calls" {
// Close every tool block that is still open, including ones that
// were not part of this chunk.
for state in self.tool_calls.values_mut() {
if state.started && !state.stopped {
state.stopped = true;
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
index: state.block_index(),
}));
}
}
}
}
}
Ok(events)
}
/// Produces the closing events once the response body has ended.
///
/// Only the first call emits anything; later calls return an empty list.
fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
if self.finished {
return Ok(Vec::new());
}
self.finished = true;
let mut events = Vec::new();
if self.text_started && !self.text_finished {
self.text_finished = true;
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
index: 0,
}));
}
for state in self.tool_calls.values_mut() {
// Tool calls that were never opened are emitted now, provided a name
// is known; ones without a name produce no events.
if !state.started {
if let Some(start_event) = state.start_event()? {
state.started = true;
events.push(StreamEvent::ContentBlockStart(start_event));
if let Some(delta_event) = state.delta_event() {
events.push(StreamEvent::ContentBlockDelta(delta_event));
}
}
}
if state.started && !state.stopped {
state.stopped = true;
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
index: state.block_index(),
}));
}
}
// The message-level closing events are only sent if a message was started.
if self.message_started {
events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
delta: MessageDelta {
stop_reason: Some(
self.stop_reason
.clone()
.unwrap_or_else(|| "end_turn".to_string()),
),
stop_sequence: None,
},
usage: self.usage.clone().unwrap_or(Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
output_tokens: 0,
}),
}));
events.push(StreamEvent::MessageStop(MessageStopEvent {}));
}
Ok(events)
}
}
/// Accumulated state of one streamed tool call.
#[derive(Debug, Default)]
struct ToolCallState {
// Tool-call index as reported by the provider.
openai_index: u32,
// Latest id and function name seen for this call, if any.
id: Option<String>,
name: Option<String>,
// Concatenation of every argument fragment received so far.
arguments: String,
// Byte length of `arguments` already emitted as deltas.
emitted_len: usize,
// Whether the content block start / stop events have been emitted.
started: bool,
stopped: bool,
}
impl ToolCallState {
fn apply(&mut self, tool_call: DeltaToolCall) {
self.openai_index = tool_call.index;
if let Some(id) = tool_call.id {
self.id = Some(id);
}
if let Some(name) = tool_call.function.name {
self.name = Some(name);
}
if let Some(arguments) = tool_call.function.arguments {
self.arguments.push_str(&arguments);
}
}
const fn block_index(&self) -> u32 {
self.openai_index + 1
}
#[allow(clippy::unnecessary_wraps)]
fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
let Some(name) = self.name.clone() else {
return Ok(None);
};
let id = self
.id
.clone()
.unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
Ok(Some(ContentBlockStartEvent {
index: self.block_index(),
content_block: OutputContentBlock::ToolUse {
id,
name,
input: json!({}),
},
}))
}
fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
if self.emitted_len >= self.arguments.len() {
return None;
}
let delta = self.arguments[self.emitted_len..].to_string();
self.emitted_len = self.arguments.len();
Some(ContentBlockDeltaEvent {
index: self.block_index(),
delta: ContentBlockDelta::InputJsonDelta {
partial_json: delta,
},
})
}
}
/// Body of a non-streaming chat-completion response.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
id: String,
model: String,
choices: Vec<ChatChoice>,
// Optional: treated as zero token counts when absent.
#[serde(default)]
usage: Option<OpenAiUsage>,
}
/// One entry of `choices` in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ChatChoice {
message: ChatMessage,
#[serde(default)]
finish_reason: Option<String>,
}
/// Message carried by a non-streaming choice.
#[derive(Debug, Deserialize)]
struct ChatMessage {
role: String,
// Text content; may be missing.
#[serde(default)]
content: Option<String>,
// Defaults to an empty list when the field is missing.
#[serde(default)]
tool_calls: Vec<ResponseToolCall>,
}
/// Complete tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ResponseToolCall {
id: String,
function: ResponseToolFunction,
}
/// Function name and argument string of a non-streaming tool call.
#[derive(Debug, Deserialize)]
struct ResponseToolFunction {
name: String,
// Arguments as a string; decoded by `parse_tool_arguments`.
arguments: String,
}
/// Token usage reported by the provider; missing counters default to zero.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
#[serde(default)]
prompt_tokens: u32,
#[serde(default)]
completion_tokens: u32,
}
/// One decoded SSE data payload of a streaming chat completion.
#[derive(Debug, Deserialize)]
struct ChatCompletionChunk {
// Required: a payload without `id` fails to deserialize.
id: String,
#[serde(default)]
model: Option<String>,
#[serde(default)]
choices: Vec<ChunkChoice>,
#[serde(default)]
usage: Option<OpenAiUsage>,
}
/// One entry of `choices` in a streaming chunk.
#[derive(Debug, Deserialize)]
struct ChunkChoice {
delta: ChunkDelta,
#[serde(default)]
finish_reason: Option<String>,
}
/// Incremental content of a streaming choice: text and/or tool-call fragments.
#[derive(Debug, Default, Deserialize)]
struct ChunkDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Vec<DeltaToolCall>,
}
/// Fragment of a streamed tool call; every field is optional on the wire.
#[derive(Debug, Deserialize)]
struct DeltaToolCall {
// Defaults to 0 when the provider omits it.
#[serde(default)]
index: u32,
#[serde(default)]
id: Option<String>,
#[serde(default)]
function: DeltaFunction,
}
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Default, Deserialize)]
struct DeltaFunction {
#[serde(default)]
name: Option<String>,
// Partial argument text; fragments are concatenated by `ToolCallState`.
#[serde(default)]
arguments: Option<String>,
}
/// Error response body of the form `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct ErrorEnvelope {
error: ErrorBody,
}
/// Details inside an error envelope.
#[derive(Debug, Deserialize)]
struct ErrorBody {
// Read from the JSON key `type`.
#[serde(rename = "type")]
error_type: Option<String>,
message: Option<String>,
}
/// Reports whether `model` is a known reasoning model, i.e. one that rejects
/// tuning parameters such as `temperature`, `top_p`, `frequency_penalty`, and
/// `presence_penalty`.
///
/// Matching is ASCII case-insensitive and ignores any `provider/` prefix
/// (for example `qwen/qwen-qwq` is checked as `qwen-qwq`).
fn is_reasoning_model(model: &str) -> bool {
    // OpenAI o-series plus the Alibaba QwQ family are recognised by prefix.
    const REASONING_PREFIXES: [&str; 5] = ["o1", "o3", "o4", "qwen-qwq", "qwq"];

    let lowered = model.to_ascii_lowercase();
    let canonical = match lowered.rfind('/') {
        Some(slash) => &lowered[slash + 1..],
        None => lowered.as_str(),
    };

    // xAI: grok-3-mini always runs in reasoning mode.
    if canonical == "grok-3-mini" {
        return true;
    }
    // Qwen3 "thinking" variants carry the marker anywhere in the name.
    if canonical.contains("thinking") {
        return true;
    }
    REASONING_PREFIXES
        .iter()
        .any(|prefix| canonical.starts_with(prefix))
}
fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
let mut messages = Vec::new();
if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
messages.push(json!({
"role": "system",
"content": system,
}));
}
for message in &request.messages {
messages.extend(translate_message(message));
}
let mut payload = json!({
"model": request.model,
"max_tokens": request.max_tokens,
"messages": messages,
"stream": request.stream,
});
if request.stream && should_request_stream_usage(config) {
payload["stream_options"] = json!({ "include_usage": true });
}
if let Some(tools) = &request.tools {
payload["tools"] =
Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
}
if let Some(tool_choice) = &request.tool_choice {
payload["tool_choice"] = openai_tool_choice(tool_choice);
}
// OpenAI-compatible tuning parameters — only included when explicitly set.
// Reasoning models (o1/o3/o4/grok-3-mini) reject these params with 400;
// silently strip them to avoid cryptic provider errors.
if !is_reasoning_model(&request.model) {
if let Some(temperature) = request.temperature {
payload["temperature"] = json!(temperature);
}
if let Some(top_p) = request.top_p {
payload["top_p"] = json!(top_p);
}
if let Some(frequency_penalty) = request.frequency_penalty {
payload["frequency_penalty"] = json!(frequency_penalty);
}
if let Some(presence_penalty) = request.presence_penalty {
payload["presence_penalty"] = json!(presence_penalty);
}
}
// stop is generally safe for all providers
if let Some(stop) = &request.stop {
if !stop.is_empty() {
payload["stop"] = json!(stop);
}
}
payload
}
fn translate_message(message: &InputMessage) -> Vec<Value> {
match message.role.as_str() {
"assistant" => {
let mut text = String::new();
let mut tool_calls = Vec::new();
for block in &message.content {
match block {
InputContentBlock::Text { text: value } => text.push_str(value),
InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
"id": id,
"type": "function",
"function": {
"name": name,
"arguments": input.to_string(),
}
})),
InputContentBlock::ToolResult { .. } => {}
}
}
if text.is_empty() && tool_calls.is_empty() {
Vec::new()
} else {
vec![json!({
"role": "assistant",
"content": (!text.is_empty()).then_some(text),
"tool_calls": tool_calls,
})]
}
}
_ => message
.content
.iter()
.filter_map(|block| match block {
InputContentBlock::Text { text } => Some(json!({
"role": "user",
"content": text,
})),
InputContentBlock::ToolResult {
tool_use_id,
content,
is_error,
} => Some(json!({
"role": "tool",
"tool_call_id": tool_use_id,
"content": flatten_tool_result_content(content),
"is_error": is_error,
})),
InputContentBlock::ToolUse { .. } => None,
})
.collect(),
}
}
/// Renders tool-result blocks as one string, one block per line: text blocks
/// verbatim, JSON blocks in their serialized form.
fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
    let mut rendered = Vec::with_capacity(content.len());
    for block in content {
        rendered.push(match block {
            ToolResultContentBlock::Text { text } => text.clone(),
            ToolResultContentBlock::Json { value } => value.to_string(),
        });
    }
    rendered.join("\n")
}
/// Wraps a tool definition in the OpenAI `function` tool envelope, mapping
/// the input schema to `parameters`.
fn openai_tool_definition(tool: &ToolDefinition) -> Value {
    let function = json!({
        "name": tool.name,
        "description": tool.description,
        "parameters": tool.input_schema,
    });
    json!({
        "type": "function",
        "function": function,
    })
}
/// Maps a `ToolChoice` to its OpenAI form: `"auto"`, `"required"`, or a
/// function selector object naming the tool.
fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
    let keyword = match tool_choice {
        ToolChoice::Tool { name } => {
            return json!({
                "type": "function",
                "function": { "name": name },
            });
        }
        ToolChoice::Auto => "auto",
        ToolChoice::Any => "required",
    };
    Value::from(keyword)
}
/// Whether streaming requests should opt in to usage reporting via
/// `stream_options`; currently only for the provider named "OpenAI".
fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool {
    config.provider_name == "OpenAI"
}
fn normalize_response(
model: &str,
response: ChatCompletionResponse,
) -> Result<MessageResponse, ApiError> {
let choice = response
.choices
.into_iter()
.next()
.ok_or(ApiError::InvalidSseFrame(
"chat completion response missing choices",
))?;
let mut content = Vec::new();
if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
content.push(OutputContentBlock::Text { text });
}
for tool_call in choice.message.tool_calls {
content.push(OutputContentBlock::ToolUse {
id: tool_call.id,
name: tool_call.function.name,
input: parse_tool_arguments(&tool_call.function.arguments),
});
}
Ok(MessageResponse {
id: response.id,
kind: "message".to_string(),
role: choice.message.role,
content,
model: response.model.if_empty_then(model.to_string()),
stop_reason: choice
.finish_reason
.map(|value| normalize_finish_reason(&value)),
stop_sequence: None,
usage: Usage {
input_tokens: response
.usage
.as_ref()
.map_or(0, |usage| usage.prompt_tokens),
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
output_tokens: response
.usage
.as_ref()
.map_or(0, |usage| usage.completion_tokens),
},
request_id: None,
})
}
/// Decodes the argument string of a tool call into JSON.
///
/// * An empty or whitespace-only string yields `{}`, so a tool invoked
///   without parameters does not receive a spurious `raw` field.
/// * Any other string that is not valid JSON is preserved as
///   `{"raw": <original string>}` so the text is not lost.
fn parse_tool_arguments(arguments: &str) -> Value {
    if arguments.trim().is_empty() {
        return json!({});
    }
    serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
}
/// Removes and returns the next complete SSE frame from `buffer`.
///
/// A frame ends at the first blank line, written either as `\n\n` or as
/// `\r\n\r\n`; whichever separator occurs earliest in the buffer wins, so
/// streams that mix line endings are still split one frame at a time. The
/// separator is consumed but not included in the returned text. Returns
/// `None`, leaving `buffer` untouched, when no complete frame is buffered.
///
/// Bytes are decoded lossily as UTF-8 only after the frame has been cut out
/// of the byte buffer.
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    let lf = buffer
        .windows(2)
        .position(|window| window == b"\n\n")
        .map(|position| (position, 2_usize));
    let crlf = buffer
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .map(|position| (position, 4_usize));

    let (position, separator_len) = match (lf, crlf) {
        (Some(lf), Some(crlf)) => {
            if crlf.0 < lf.0 {
                crlf
            } else {
                lf
            }
        }
        (lf, crlf) => lf.or(crlf)?,
    };

    let frame: Vec<u8> = buffer.drain(..position + separator_len).collect();
    Some(String::from_utf8_lossy(&frame[..position]).into_owned())
}
fn parse_sse_frame(
frame: &str,
provider: &str,
model: &str,
) -> Result<Option<ChatCompletionChunk>, ApiError> {
let trimmed = frame.trim();
if trimmed.is_empty() {
return Ok(None);
}
let mut data_lines = Vec::new();
for line in trimmed.lines() {
if line.starts_with(':') {
continue;
}
if let Some(data) = line.strip_prefix("data:") {
data_lines.push(data.trim_start());
}
}
if data_lines.is_empty() {
return Ok(None);
}
let payload = data_lines.join("\n");
if payload == "[DONE]" {
return Ok(None);
}
serde_json::from_str::<ChatCompletionChunk>(&payload)
.map(Some)
.map_err(|error| ApiError::json_deserialize(provider, model, &payload, error))
}
/// Reads `key` from the process environment, treating an empty value as
/// unset. When the variable is unset or empty, the result of
/// `super::dotenv_value(key)` is returned instead.
///
/// # Errors
///
/// Returns the converted `VarError` for any failure other than "not present".
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
    use std::env::VarError;

    match std::env::var(key) {
        Ok(value) => {
            if value.is_empty() {
                Ok(super::dotenv_value(key))
            } else {
                Ok(Some(value))
            }
        }
        Err(VarError::NotPresent) => Ok(super::dotenv_value(key)),
        Err(error) => Err(ApiError::from(error)),
    }
}
/// Reports whether a non-empty value for `key` can be resolved through
/// `read_env_non_empty`. Lookup errors count as "no key".
#[must_use]
pub fn has_api_key(key: &str) -> bool {
    matches!(read_env_non_empty(key), Ok(Some(_)))
}
/// Resolves the base URL for `config`.
///
/// The value of the `config.base_url_env` environment variable is used when
/// it is set to something other than whitespace; otherwise
/// `config.default_base_url` is returned. A blank override falls back to the
/// default (mirroring how empty API-key variables are treated) instead of
/// producing an unusable, empty request URL.
#[must_use]
pub fn read_base_url(config: OpenAiCompatConfig) -> String {
    std::env::var(config.base_url_env)
        .ok()
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| config.default_base_url.to_string())
}
/// Builds the chat-completions URL from a base URL.
///
/// Trailing slashes are ignored. A base URL that already ends in
/// `/chat/completions` is returned as is; otherwise the suffix is appended.
fn chat_completions_endpoint(base_url: &str) -> String {
    const SUFFIX: &str = "/chat/completions";

    let base = base_url.trim_end_matches('/');
    match base.strip_suffix(SUFFIX) {
        Some(_) => base.to_owned(),
        None => [base, SUFFIX].concat(),
    }
}
/// Extracts the request id from response headers, preferring `request-id`
/// over `x-request-id`. Returns `None` when neither header is present or the
/// chosen header value cannot be read as a string.
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
    let raw = match headers.get(REQUEST_ID_HEADER) {
        Some(value) => value,
        None => headers.get(ALT_REQUEST_ID_HEADER)?,
    };
    raw.to_str().ok().map(str::to_owned)
}
/// Passes successful responses through and converts everything else into
/// `ApiError::Api`.
///
/// For a failure, the body is read (an unreadable body becomes an empty
/// string) and parsed as an error envelope on a best-effort basis to fill in
/// the error type and message. `retryable` is derived from the status code.
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
    let status = response.status();
    if status.is_success() {
        return Ok(response);
    }

    let request_id = request_id_from_headers(response.headers());
    let body = response.text().await.unwrap_or_default();
    let (error_type, message) = match serde_json::from_str::<ErrorEnvelope>(&body) {
        Ok(envelope) => (envelope.error.error_type, envelope.error.message),
        Err(_) => (None, None),
    };

    Err(ApiError::Api {
        status,
        error_type,
        message,
        request_id,
        body,
        retryable: is_retryable_status(status),
    })
}
/// Whether a response with this status is worth retrying: 408, 409, 429,
/// 500, 502, 503, or 504.
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
    let code = status.as_u16();
    matches!(code, 408 | 409 | 429 | 500) || matches!(code, 502..=504)
}
/// Maps an OpenAI finish reason to the stop reason used by this crate:
/// `stop` becomes `end_turn` and `tool_calls` becomes `tool_use`; any other
/// value is passed through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    let mapped = if value == "stop" {
        "end_turn"
    } else if value == "tool_calls" {
        "tool_use"
    } else {
        value
    };
    String::from(mapped)
}
/// Helper for substituting a fallback when a string is empty.
trait StringExt {
    /// Returns `self` unless it is empty, in which case `fallback` is returned.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        if !self.is_empty() {
            return self;
        }
        fallback
    }
}
#[cfg(test)]
mod tests {
    use super::{
        build_chat_completion_request, chat_completions_endpoint, is_reasoning_model,
        normalize_finish_reason, openai_tool_choice, parse_tool_arguments, OpenAiCompatClient,
        OpenAiCompatConfig,
    };
    use crate::error::ApiError;
    use crate::types::{
        InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
        ToolResultContentBlock,
    };
    use serde_json::json;
    use std::sync::{Mutex, OnceLock};

    /// System prompt, user text, and tool results map to `system`, `user`,
    /// and `tool` messages, and tools use the function envelope.
    #[test]
    fn request_translation_uses_openai_compatible_shape() {
        let payload = build_chat_completion_request(
            &MessageRequest {
                model: "grok-3".to_string(),
                max_tokens: 64,
                messages: vec![InputMessage {
                    role: "user".to_string(),
                    content: vec![
                        InputContentBlock::Text {
                            text: "hello".to_string(),
                        },
                        InputContentBlock::ToolResult {
                            tool_use_id: "tool_1".to_string(),
                            content: vec![ToolResultContentBlock::Json {
                                value: json!({"ok": true}),
                            }],
                            is_error: false,
                        },
                    ],
                }],
                system: Some("be helpful".to_string()),
                tools: Some(vec![ToolDefinition {
                    name: "weather".to_string(),
                    description: Some("Get weather".to_string()),
                    input_schema: json!({"type": "object"}),
                }]),
                tool_choice: Some(ToolChoice::Auto),
                stream: false,
                ..Default::default()
            },
            OpenAiCompatConfig::xai(),
        );
        assert_eq!(payload["messages"][0]["role"], json!("system"));
        assert_eq!(payload["messages"][1]["role"], json!("user"));
        assert_eq!(payload["messages"][2]["role"], json!("tool"));
        assert_eq!(payload["tools"][0]["type"], json!("function"));
        assert_eq!(payload["tool_choice"], json!("auto"));
    }

    /// Streaming requests to OpenAI opt in to usage reporting.
    #[test]
    fn openai_streaming_requests_include_usage_opt_in() {
        let payload = build_chat_completion_request(
            &MessageRequest {
                model: "gpt-5".to_string(),
                max_tokens: 64,
                messages: vec![InputMessage::user_text("hello")],
                system: None,
                tools: None,
                tool_choice: None,
                stream: true,
                ..Default::default()
            },
            OpenAiCompatConfig::openai(),
        );
        assert_eq!(payload["stream_options"], json!({"include_usage": true}));
    }

    /// Streaming requests to xAI do not carry `stream_options`.
    #[test]
    fn xai_streaming_requests_skip_openai_specific_usage_opt_in() {
        let payload = build_chat_completion_request(
            &MessageRequest {
                model: "grok-3".to_string(),
                max_tokens: 64,
                messages: vec![InputMessage::user_text("hello")],
                system: None,
                tools: None,
                tool_choice: None,
                stream: true,
                ..Default::default()
            },
            OpenAiCompatConfig::xai(),
        );
        assert!(payload.get("stream_options").is_none());
    }

    /// `Any` maps to `"required"` and a named tool to a function selector.
    #[test]
    fn tool_choice_translation_supports_required_function() {
        assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
        assert_eq!(
            openai_tool_choice(&ToolChoice::Tool {
                name: "weather".to_string(),
            }),
            json!({"type": "function", "function": {"name": "weather"}})
        );
    }

    /// Valid JSON arguments are decoded; invalid ones are kept under `raw`.
    #[test]
    fn parses_tool_arguments_fallback() {
        assert_eq!(
            parse_tool_arguments("{\"city\":\"Paris\"}"),
            json!({"city": "Paris"})
        );
        assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"}));
    }

    /// A missing xAI key is reported as a credentials error naming xAI.
    #[test]
    fn missing_xai_api_key_is_provider_specific() {
        let _lock = env_lock();
        std::env::remove_var("XAI_API_KEY");
        let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai())
            .expect_err("missing key should error");
        assert!(matches!(
            error,
            ApiError::MissingCredentials {
                provider: "xAI",
                ..
            }
        ));
    }

    /// Base URLs with or without a trailing slash, and full endpoint URLs,
    /// all resolve to the same endpoint.
    #[test]
    fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
            "https://api.x.ai/v1/chat/completions"
        );
    }

    /// Serializes tests that touch process environment variables.
    ///
    /// A poisoned lock is recovered rather than unwrapped: the mutex guards
    /// no data, so one failing env test must not cascade into "env lock"
    /// panics in every other test that takes the lock.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// `stop` and `tool_calls` map to `end_turn` and `tool_use`.
    #[test]
    fn normalizes_stop_reasons() {
        assert_eq!(normalize_finish_reason("stop"), "end_turn");
        assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
    }

    /// Explicitly set tuning parameters are forwarded for regular models.
    #[test]
    fn tuning_params_included_in_payload_when_set() {
        let request = MessageRequest {
            model: "gpt-4o".to_string(),
            max_tokens: 1024,
            messages: vec![],
            system: None,
            tools: None,
            tool_choice: None,
            stream: false,
            temperature: Some(0.7),
            top_p: Some(0.9),
            frequency_penalty: Some(0.5),
            presence_penalty: Some(0.3),
            stop: Some(vec!["\n".to_string()]),
        };
        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
        assert_eq!(payload["temperature"], 0.7);
        assert_eq!(payload["top_p"], 0.9);
        assert_eq!(payload["frequency_penalty"], 0.5);
        assert_eq!(payload["presence_penalty"], 0.3);
        assert_eq!(payload["stop"], json!(["\n"]));
    }

    /// Reasoning models drop the tuning parameters but keep `stop`.
    #[test]
    fn reasoning_model_strips_tuning_params() {
        let request = MessageRequest {
            model: "o1-mini".to_string(),
            max_tokens: 1024,
            messages: vec![],
            stream: false,
            temperature: Some(0.7),
            top_p: Some(0.9),
            frequency_penalty: Some(0.5),
            presence_penalty: Some(0.3),
            stop: Some(vec!["\n".to_string()]),
            ..Default::default()
        };
        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
        assert!(
            payload.get("temperature").is_none(),
            "reasoning model should strip temperature"
        );
        assert!(
            payload.get("top_p").is_none(),
            "reasoning model should strip top_p"
        );
        assert!(payload.get("frequency_penalty").is_none());
        assert!(payload.get("presence_penalty").is_none());
        // stop is safe for all providers
        assert_eq!(payload["stop"], json!(["\n"]));
    }

    /// OpenAI o-series and grok-3-mini are reasoning models; others are not.
    #[test]
    fn grok_3_mini_is_reasoning_model() {
        assert!(is_reasoning_model("grok-3-mini"));
        assert!(is_reasoning_model("o1"));
        assert!(is_reasoning_model("o1-mini"));
        assert!(is_reasoning_model("o3-mini"));
        assert!(!is_reasoning_model("gpt-4o"));
        assert!(!is_reasoning_model("grok-3"));
        assert!(!is_reasoning_model("claude-sonnet-4-6"));
    }

    /// QwQ and Qwen3 "thinking" variants are detected, with or without a
    /// provider prefix; regular Qwen models are not.
    #[test]
    fn qwen_reasoning_variants_are_detected() {
        // QwQ reasoning model
        assert!(is_reasoning_model("qwen-qwq-32b"));
        assert!(is_reasoning_model("qwen/qwen-qwq-32b"));
        // Qwen3 thinking family
        assert!(is_reasoning_model("qwen3-30b-a3b-thinking"));
        assert!(is_reasoning_model("qwen/qwen3-30b-a3b-thinking"));
        // Bare qwq
        assert!(is_reasoning_model("qwq-plus"));
        // Regular Qwen models must NOT be classified as reasoning
        assert!(!is_reasoning_model("qwen-max"));
        assert!(!is_reasoning_model("qwen/qwen-plus"));
        assert!(!is_reasoning_model("qwen-turbo"));
    }

    /// Unset tuning parameters and `stop` are absent from the payload.
    #[test]
    fn tuning_params_omitted_from_payload_when_none() {
        let request = MessageRequest {
            model: "gpt-4o".to_string(),
            max_tokens: 1024,
            messages: vec![],
            stream: false,
            ..Default::default()
        };
        let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai());
        assert!(
            payload.get("temperature").is_none(),
            "temperature should be absent"
        );
        assert!(payload.get("top_p").is_none(), "top_p should be absent");
        assert!(payload.get("frequency_penalty").is_none());
        assert!(payload.get("presence_penalty").is_none());
        assert!(payload.get("stop").is_none());
    }
}