Custom Providers
This guide walks through implementing a custom LLM provider for agent-air. You will create a provider struct, implement the LlmProvider trait, and integrate it with the configuration system.
Step 1: Create Provider Struct
Create a new module for your provider:
src/client/providers/
├── anthropic/
├── openai/
└── custom/ <- New provider
├── mod.rs
└── types.rs
Define the provider struct in mod.rs:
// src/client/providers/custom/mod.rs
mod types;
use crate::client::error::LlmError;
use crate::client::http::HttpClient;
use crate::client::models::{Message, MessageOptions};
use crate::client::traits::LlmProvider;
use std::future::Future;
use std::pin::Pin;
/// Provider for a custom chat-completions-style LLM HTTP API.
///
/// Holds the credentials and routing information needed to build
/// authenticated requests.
pub struct CustomProvider {
    /// Bearer token sent in the `Authorization` header.
    pub api_key: String,
    /// Model identifier included in each request body.
    pub model: String,
    /// API root; override via [`CustomProvider::with_base_url`].
    pub base_url: String,
}

impl CustomProvider {
    /// Creates a provider pointing at the default production endpoint.
    ///
    /// Accepts anything convertible into `String` (`&str`, `String`, ...)
    /// so callers are not forced to allocate up front. This mirrors the
    /// `LLMSessionConfig::custom` builder's signature.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: "https://api.custom-llm.com".to_string(),
        }
    }

    /// Builder-style override of the API root — useful for proxies,
    /// self-hosted deployments, or tests against a mock server.
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }
}
Step 2: Implement LlmProvider Trait
Implement the required send_msg method:
impl LlmProvider for CustomProvider {
    /// Sends the conversation to the custom chat-completions API and
    /// resolves to the assistant's reply.
    fn send_msg(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
        // The returned future must own everything it captures, so take
        // owned copies before moving into the async block.
        let http = client.clone();
        let owned_messages = messages.to_vec();
        let owned_options = options.clone();
        let (key, model, base) = (
            self.api_key.clone(),
            self.model.clone(),
            self.base_url.clone(),
        );

        Box::pin(async move {
            // Serialize the request payload.
            let payload = types::build_request_body(&owned_messages, &owned_options, &model)?;

            // Bearer auth plus JSON content type.
            let bearer = format!("Bearer {}", key);
            let headers = [
                ("Authorization", bearer.as_str()),
                ("Content-Type", "application/json"),
            ];

            // Issue the call and decode the reply.
            let endpoint = format!("{}/v1/chat/completions", base);
            let raw = http.post(&endpoint, &headers, &payload).await?;
            types::parse_response(&raw)
        })
    }
}
Step 3: Implement Request Building
Create types.rs for request/response handling:
// src/client/providers/custom/types.rs
use crate::client::error::LlmError;
use crate::client::models::{Message, MessageOptions, MessageRole, ContentBlock};
use serde::{Deserialize, Serialize};
use serde_json::Value;
// Wire format for the outbound chat-completion request body.
#[derive(Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ApiMessage>,
    // Omitted from the JSON when unset so the API applies its own defaults.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Omitted entirely when no tools are configured.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ApiTool>,
}

// A single role/content pair in the shape the API expects.
#[derive(Serialize)]
struct ApiMessage {
    role: String,
    content: String,
}

// Tool definition advertised to the model; `parameters` carries the
// tool's input schema as arbitrary JSON.
#[derive(Serialize)]
struct ApiTool {
    name: String,
    description: String,
    parameters: Value,
}
pub fn build_request_body(
messages: &[Message],
options: &MessageOptions,
model: &str,
) -> Result<String, LlmError> {
// Convert messages to API format
let api_messages: Vec<ApiMessage> = messages
.iter()
.filter_map(|m| {
let role = match m.role {
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::System => "system",
};
// Extract text content
let text = m.content.iter()
.filter_map(|c| match c {
ContentBlock::Text(t) => Some(t.clone()),
_ => None,
})
.collect::<Vec<_>>()
.join("\n");
if text.is_empty() {
None
} else {
Some(ApiMessage {
role: role.to_string(),
content: text,
})
}
})
.collect();
// Convert tools
let tools: Vec<ApiTool> = options.tools.iter()
.map(|t| ApiTool {
name: t.name.clone(),
description: t.description.clone(),
parameters: serde_json::from_str(&t.input_schema).unwrap_or(Value::Null),
})
.collect();
let request = ChatRequest {
model: model.to_string(),
messages: api_messages,
max_tokens: options.max_tokens,
temperature: options.temperature,
tools,
};
serde_json::to_string(&request)
.map_err(|e| LlmError::new("SERIALIZE_ERROR", &e.to_string()))
}
Step 4: Implement Response Parsing
Add response parsing to types.rs:
// Top-level shape of a non-streaming chat-completions response.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

// One candidate completion; only the first is consumed by parse_response.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}

// The assistant turn returned by the API. `content` may be absent when
// the reply consists solely of tool calls.
#[derive(Deserialize)]
struct ResponseMessage {
    role: String,
    content: Option<String>,
    // Defaults to empty so responses without a `tool_calls` key still parse.
    #[serde(default)]
    tool_calls: Vec<ToolCall>,
}

// A single tool invocation requested by the model.
#[derive(Deserialize)]
struct ToolCall {
    id: String,
    function: FunctionCall,
}

// Function name plus its arguments as a raw JSON string (decoded later).
#[derive(Deserialize)]
struct FunctionCall {
    name: String,
    arguments: String,
}
/// Deserializes an API response body into the internal [`Message`] type.
///
/// Extracts the text of the first choice (if any) followed by its tool
/// calls as `ToolUse` blocks.
pub fn parse_response(response: &str) -> Result<Message, LlmError> {
    let parsed: ChatResponse = serde_json::from_str(response)
        .map_err(|e| LlmError::new("PARSE_ERROR", &e.to_string()))?;

    let Some(choice) = parsed.choices.first() else {
        return Err(LlmError::new("NO_RESPONSE", "No choices in response"));
    };

    let mut blocks = Vec::new();

    // Text comes first, mirroring the order the API reports it.
    match &choice.message.content {
        Some(text) if !text.is_empty() => blocks.push(ContentBlock::Text(text.clone())),
        _ => {}
    }

    // Tool invocations follow; malformed argument JSON degrades to null
    // rather than failing the whole response.
    for call in &choice.message.tool_calls {
        let input: Value = serde_json::from_str(&call.function.arguments)
            .unwrap_or(Value::Null);
        blocks.push(ContentBlock::ToolUse {
            id: call.id.clone(),
            name: call.function.name.clone(),
            input,
        });
    }

    Ok(Message {
        role: MessageRole::Assistant,
        content: blocks,
    })
}
Step 5: Add Streaming Support (Optional)
Override send_msg_stream if your API supports streaming:
use async_stream::stream;
use futures::Stream;
use crate::client::models::StreamEvent;
impl LlmProvider for CustomProvider {
    // ... send_msg implementation (shown in Step 2) ...

    /// Streams the model's reply as a sequence of [`StreamEvent`]s.
    ///
    /// The outer future resolves once the HTTP connection is established;
    /// the inner stream then yields parsed SSE events until the server
    /// closes the connection or a transport error occurs.
    fn send_msg_stream(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
        // Owned copies so the returned 'static future can capture them.
        let client = client.clone();
        let api_key = self.api_key.clone();
        let model = self.model.clone();
        let base_url = self.base_url.clone();
        let messages = messages.to_vec();
        let options = options.clone();
        Box::pin(async move {
            // Build streaming request (add stream: true)
            let body = types::build_streaming_request_body(&messages, &options, &model)?;
            let auth_header = format!("Bearer {}", api_key);
            let headers = [
                ("Authorization", auth_header.as_str()),
                ("Content-Type", "application/json"),
            ];
            let url = format!("{}/v1/chat/completions", base_url);
            // post_stream yields the raw byte chunks of the SSE response.
            let byte_stream = client.post_stream(&url, &headers, &body).await?;
            // Parse SSE stream
            use futures::StreamExt;
            let event_stream = stream! {
                // A single SSE event may be split across chunks, so
                // accumulate into a buffer and let the parser drain only
                // complete events each pass.
                let mut buffer = String::new();
                let mut byte_stream = byte_stream;
                while let Some(chunk_result) = byte_stream.next().await {
                    match chunk_result {
                        Ok(bytes) => {
                            // NOTE(review): a chunk that is not valid UTF-8 is
                            // silently dropped here; a multi-byte character split
                            // across a chunk boundary would be lost — confirm the
                            // transport guarantees UTF-8-aligned chunks.
                            if let Ok(text) = std::str::from_utf8(&bytes) {
                                buffer.push_str(text);
                            }
                            // Parse SSE events from buffer
                            for event in parse_sse_events(&mut buffer) {
                                yield event;
                            }
                        }
                        Err(e) => {
                            // Transport failure: surface it, then end the stream.
                            yield Err(e);
                            break;
                        }
                    }
                }
            };
            Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
        })
    }
}
Step 6: Export the Provider
Add to src/client/providers/mod.rs:
pub mod anthropic;
pub mod openai;
pub mod custom; // Add this line
pub use anthropic::AnthropicProvider;
pub use openai::OpenAIProvider;
pub use custom::CustomProvider; // Add this line
Step 7: Add LLMProvider Enum Variant
Update src/controller/session/config.rs:
/// Which backend implementation a session should talk to.
///
/// Fieldless discriminant, so it is cheap to `Copy` and can be used as a
/// map key (`Eq` + `Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LLMProvider {
    Anthropic,
    OpenAI,
    Custom, // Add this variant
}
Step 8: Add Session Config Builder
Add a builder method in LLMSessionConfig:
impl LLMSessionConfig {
    // ... existing methods ...

    /// Convenience constructor for a session backed by the custom provider.
    ///
    /// Starts from conservative defaults — 4096 max tokens, a 100k-token
    /// context window, streaming disabled, default compaction — which you
    /// should tune to match your model's actual limits.
    pub fn custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            provider: LLMProvider::Custom,
            api_key: api_key.into(),
            model: model.into(),
            max_tokens: Some(4096),
            system_prompt: None,
            temperature: None,
            streaming: false, // Set based on your implementation
            context_limit: 100_000, // Set based on your model
            compaction: Some(CompactorType::default()),
        }
    }
}
Step 9: Update Provider Factory
Update the factory function that creates providers (in session.rs or executor.rs):
fn create_llm_client(config: &LLMSessionConfig) -> Result<LLMClient, LlmError> {
match config.provider {
LLMProvider::Anthropic => {
let provider = AnthropicProvider::new(
config.api_key.clone(),
config.model.clone()
);
LLMClient::new(Box::new(provider))
}
LLMProvider::OpenAI => {
let provider = OpenAIProvider::new(
config.api_key.clone(),
config.model.clone()
);
LLMClient::new(Box::new(provider))
}
LLMProvider::Custom => {
let provider = CustomProvider::new(
config.api_key.clone(),
config.model.clone()
);
LLMClient::new(Box::new(provider))
}
}
}
Step 10: Add YAML Configuration Support
Update src/agent/config.rs to recognize your provider:
/// Maps a parsed YAML `ProviderConfig` onto a typed `LLMSessionConfig`.
///
/// # Errors
/// Returns `ConfigError::UnknownProvider` when the YAML `provider` string
/// is not a recognized backend name.
fn create_session_config(
    config: &ProviderConfig,
    default_system_prompt: &str
) -> Result<LLMSessionConfig, ConfigError> {
    // First map the free-form string from YAML onto the typed enum,
    // rejecting anything unrecognized up front.
    let provider = match config.provider.as_str() {
        "anthropic" => LLMProvider::Anthropic,
        "openai" => LLMProvider::OpenAI,
        "custom" => LLMProvider::Custom, // Add this
        other => {
            return Err(ConfigError::UnknownProvider {
                provider: other.to_string(),
            })
        }
    };
    // Then dispatch to the matching convenience constructor.
    let mut session_config = match provider {
        LLMProvider::Anthropic => {
            LLMSessionConfig::anthropic(&config.api_key, &config.model)
        }
        LLMProvider::OpenAI => {
            LLMSessionConfig::openai(&config.api_key, &config.model)
        }
        LLMProvider::Custom => {
            LLMSessionConfig::custom(&config.api_key, &config.model)
        }
    };
    // ... rest of configuration ...
}
Usage Example
After implementing, use your provider:
YAML Configuration
providers:
- provider: custom
api_key: your-api-key
model: custom-model-v1
system_prompt: "You are helpful."
default_provider: custom
Programmatic Usage
use agent_air::controller::LLMSessionConfig;
let config = LLMSessionConfig::custom("api-key", "model-name")
.with_system_prompt("You are helpful.")
.with_max_tokens(4096);
Direct Provider Usage
use agent_air::client::{LLMClient, CustomProvider};
let provider = CustomProvider::new(
"api-key".to_string(),
"model-name".to_string()
).with_base_url("https://custom-api.example.com".to_string());
let client = LLMClient::new(Box::new(provider))?;
let response = client.send_message(&messages, &options).await?;
Testing Your Provider
Create tests in src/client/providers/custom/mod.rs:
#[cfg(test)]
mod tests {
    use super::*;
    // `super::*` only re-exports what mod.rs itself imports; the request
    // test additionally constructs these model types directly.
    use crate::client::models::{ContentBlock, MessageRole};

    /// Constructor stores credentials and model name verbatim.
    #[test]
    fn test_provider_creation() {
        let provider = CustomProvider::new(
            "test-key".to_string(),
            "test-model".to_string()
        );
        assert_eq!(provider.api_key, "test-key");
        assert_eq!(provider.model, "test-model");
    }

    /// The serialized request body carries both the message text and the
    /// model name.
    #[test]
    fn test_request_building() {
        let messages = vec![Message {
            role: MessageRole::User,
            content: vec![ContentBlock::Text("Hello".to_string())],
        }];
        let options = MessageOptions::default();
        // `?` cannot be used in a test fn returning `()`; `expect` fails
        // the test with a clear message instead.
        let body = types::build_request_body(&messages, &options, "model")
            .expect("request body should serialize");
        assert!(body.contains("Hello"));
        assert!(body.contains("model"));
    }
}