Custom Providers

This guide walks through implementing a custom LLM provider for agent-air. You will create a provider struct, implement the LlmProvider trait, and integrate it with the configuration system.


Step 1: Create Provider Struct

Create a new module for your provider:

src/client/providers/
├── anthropic/
├── openai/
└── custom/          <- New provider
    ├── mod.rs
    └── types.rs

Define the provider struct in mod.rs:

// src/client/providers/custom/mod.rs

mod types;

use crate::client::error::LlmError;
use crate::client::http::HttpClient;
use crate::client::models::{Message, MessageOptions};
use crate::client::traits::LlmProvider;
use std::future::Future;
use std::pin::Pin;

/// Connection settings for the custom LLM backend.
pub struct CustomProvider {
    pub api_key: String,
    pub model: String,
    pub base_url: String,
}

impl CustomProvider {
    /// Creates a provider pointed at the default production endpoint.
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            base_url: String::from("https://api.custom-llm.com"),
        }
    }

    /// Replaces the endpoint, e.g. for a staging or self-hosted deployment.
    pub fn with_base_url(mut self, url: String) -> Self {
        self.base_url = url;
        self
    }
}

Step 2: Implement LlmProvider Trait

Implement the required send_msg method:

impl LlmProvider for CustomProvider {
    /// Sends the conversation to the chat-completions endpoint and returns
    /// the assistant's reply as a boxed future.
    fn send_msg(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
        // The returned future must be 'static, so own everything it touches.
        let http = client.clone();
        let (api_key, model, base_url) =
            (self.api_key.clone(), self.model.clone(), self.base_url.clone());
        let messages: Vec<Message> = messages.to_vec();
        let options = options.clone();

        Box::pin(async move {
            // Serialize the request body up front so errors surface early.
            let body = types::build_request_body(&messages, &options, &model)?;

            let bearer = format!("Bearer {}", api_key);
            let endpoint = format!("{}/v1/chat/completions", base_url);

            let raw = http
                .post(
                    &endpoint,
                    &[
                        ("Authorization", bearer.as_str()),
                        ("Content-Type", "application/json"),
                    ],
                    &body,
                )
                .await?;

            // Translate the provider's JSON into our internal Message type.
            types::parse_response(&raw)
        })
    }
}

Step 3: Implement Request Building

Create types.rs for request/response handling:

// src/client/providers/custom/types.rs

use crate::client::error::LlmError;
use crate::client::models::{Message, MessageOptions, MessageRole, ContentBlock};
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Wire format of the request sent to the chat-completions endpoint.
#[derive(Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ApiMessage>,
    // Omit optional fields entirely instead of serializing `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Omit the array when no tools are configured.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ApiTool>,
}

/// One chat turn; content is flattened to plain text by build_request_body.
#[derive(Serialize)]
struct ApiMessage {
    role: String,
    content: String,
}

/// A tool definition advertised to the model.
#[derive(Serialize)]
struct ApiTool {
    name: String,
    description: String,
    // JSON Schema describing the tool's expected input.
    parameters: Value,
}

/// Serializes the conversation plus options into the provider's JSON body.
///
/// Messages with no text content are dropped, and tool input schemas that
/// fail to parse fall back to `Value::Null` rather than aborting the request.
pub fn build_request_body(
    messages: &[Message],
    options: &MessageOptions,
    model: &str,
) -> Result<String, LlmError> {
    // Flatten each message's text blocks into one newline-joined string.
    let mut api_messages = Vec::with_capacity(messages.len());
    for message in messages {
        let role = match message.role {
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::System => "system",
        };

        let mut lines: Vec<String> = Vec::new();
        for block in &message.content {
            if let ContentBlock::Text(text) = block {
                lines.push(text.clone());
            }
        }
        let content = lines.join("\n");

        // Skip messages that carry no text at all.
        if !content.is_empty() {
            api_messages.push(ApiMessage {
                role: role.to_string(),
                content,
            });
        }
    }

    // Advertise each configured tool; unparseable schemas degrade to Null.
    let mut tools = Vec::with_capacity(options.tools.len());
    for tool in &options.tools {
        tools.push(ApiTool {
            name: tool.name.clone(),
            description: tool.description.clone(),
            parameters: serde_json::from_str(&tool.input_schema).unwrap_or(Value::Null),
        });
    }

    let request = ChatRequest {
        model: model.to_string(),
        messages: api_messages,
        max_tokens: options.max_tokens,
        temperature: options.temperature,
        tools,
    };

    serde_json::to_string(&request)
        .map_err(|e| LlmError::new("SERIALIZE_ERROR", &e.to_string()))
}

Step 4: Implement Response Parsing

Add response parsing to types.rs:

/// Top-level response envelope from the chat-completions endpoint.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}

#[derive(Deserialize)]
struct ResponseMessage {
    // Deserialized for completeness; parse_response does not read it.
    role: String,
    content: Option<String>,
    // `#[serde(default)]` yields an empty Vec when the field is absent.
    #[serde(default)]
    tool_calls: Vec<ToolCall>,
}

#[derive(Deserialize)]
struct ToolCall {
    id: String,
    function: FunctionCall,
}

#[derive(Deserialize)]
struct FunctionCall {
    name: String,
    // Arguments arrive as a JSON-encoded string, not a nested JSON object.
    arguments: String,
}

pub fn parse_response(response: &str) -> Result<Message, LlmError> {
    let parsed: ChatResponse = serde_json::from_str(response)
        .map_err(|e| LlmError::new("PARSE_ERROR", &e.to_string()))?;

    let choice = parsed.choices.first()
        .ok_or_else(|| LlmError::new("NO_RESPONSE", "No choices in response"))?;

    let mut content = Vec::new();

    // Add text content
    if let Some(text) = &choice.message.content {
        if !text.is_empty() {
            content.push(ContentBlock::Text(text.clone()));
        }
    }

    // Add tool calls
    for tool_call in &choice.message.tool_calls {
        let input: Value = serde_json::from_str(&tool_call.function.arguments)
            .unwrap_or(Value::Null);

        content.push(ContentBlock::ToolUse {
            id: tool_call.id.clone(),
            name: tool_call.function.name.clone(),
            input,
        });
    }

    Ok(Message {
        role: MessageRole::Assistant,
        content,
    })
}

Step 5: Add Streaming Support (Optional)

Override send_msg_stream if your API supports streaming. The snippet below also depends on two helpers you must supply yourself: types::build_streaming_request_body (identical to build_request_body but with "stream": true set in the body) and parse_sse_events (which drains complete SSE events from the front of the buffer):

use async_stream::stream;
use futures::Stream;
use crate::client::models::StreamEvent;

impl LlmProvider for CustomProvider {
    // ... send_msg implementation ...

    /// Streaming variant of send_msg: resolves to a stream of incremental
    /// StreamEvents parsed from the provider's SSE response.
    ///
    /// NOTE(review): this snippet relies on two helpers not shown here —
    /// `types::build_streaming_request_body` (like build_request_body but
    /// with `"stream": true`) and `parse_sse_events` (drains complete SSE
    /// events off the front of the buffer) — implement both before use.
    fn send_msg_stream(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
        // Clone everything the 'static future will need.
        let client = client.clone();
        let api_key = self.api_key.clone();
        let model = self.model.clone();
        let base_url = self.base_url.clone();
        let messages = messages.to_vec();
        let options = options.clone();

        Box::pin(async move {
            // Build streaming request (add stream: true)
            let body = types::build_streaming_request_body(&messages, &options, &model)?;

            let auth_header = format!("Bearer {}", api_key);
            let headers = [
                ("Authorization", auth_header.as_str()),
                ("Content-Type", "application/json"),
            ];

            let url = format!("{}/v1/chat/completions", base_url);
            let byte_stream = client.post_stream(&url, &headers, &body).await?;

            // Parse SSE stream
            use futures::StreamExt;
            let event_stream = stream! {
                let mut buffer = String::new();
                let mut byte_stream = byte_stream;

                while let Some(chunk_result) = byte_stream.next().await {
                    match chunk_result {
                        Ok(bytes) => {
                            // Non-UTF-8 chunks are silently skipped; note a
                            // multi-byte char split across chunk boundaries
                            // would also be dropped by this check.
                            if let Ok(text) = std::str::from_utf8(&bytes) {
                                buffer.push_str(text);
                            }

                            // Parse SSE events from buffer
                            for event in parse_sse_events(&mut buffer) {
                                yield event;
                            }
                        }
                        Err(e) => {
                            // Transport error: surface it and end the stream.
                            yield Err(e);
                            break;
                        }
                    }
                }
            };

            Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
        })
    }
}

Step 6: Export the Provider

Add to src/client/providers/mod.rs:

pub mod anthropic;
pub mod openai;
pub mod custom;  // Add this line

pub use anthropic::AnthropicProvider;
pub use openai::OpenAIProvider;
pub use custom::CustomProvider;  // Add this line

Step 7: Add LLMProvider Enum Variant

Update src/controller/session/config.rs:

/// Identifies which LLM backend a session should talk to.
#[derive(Debug, Clone, PartialEq)]
pub enum LLMProvider {
    Anthropic,
    OpenAI,
    Custom,  // Add this variant
}

Step 8: Add Session Config Builder

Add a builder method in LLMSessionConfig:

impl LLMSessionConfig {
    // ... existing methods ...

    /// Builds a session config wired to the custom provider with reasonable
    /// starting defaults; tune the commented fields for your model.
    pub fn custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            provider: LLMProvider::Custom,
            api_key: api_key.into(),
            model: model.into(),
            max_tokens: Some(4096),
            system_prompt: None,
            temperature: None,
            streaming: false,  // Set based on your implementation
            context_limit: 100_000,  // Set based on your model
            compaction: Some(CompactorType::default()),
        }
    }
}

Step 9: Update Provider Factory

Update the factory function that creates providers (in session.rs or executor.rs):

/// Instantiates the provider selected by `config` and wraps it in an LLMClient.
fn create_llm_client(config: &LLMSessionConfig) -> Result<LLMClient, LlmError> {
    // Every provider constructor takes the same owned credentials.
    let api_key = config.api_key.clone();
    let model = config.model.clone();

    match config.provider {
        LLMProvider::Anthropic => {
            LLMClient::new(Box::new(AnthropicProvider::new(api_key, model)))
        }
        LLMProvider::OpenAI => {
            LLMClient::new(Box::new(OpenAIProvider::new(api_key, model)))
        }
        LLMProvider::Custom => {
            LLMClient::new(Box::new(CustomProvider::new(api_key, model)))
        }
    }
}

Step 10: Add YAML Configuration Support

Update src/agent/config.rs to recognize your provider:

/// Translates a raw YAML ProviderConfig into a typed LLMSessionConfig.
///
/// Returns ConfigError::UnknownProvider for provider strings that are not
/// registered in the match below.
fn create_session_config(
    config: &ProviderConfig,
    default_system_prompt: &str
) -> Result<LLMSessionConfig, ConfigError> {
    // Map the YAML string onto the strongly-typed provider enum.
    let provider = match config.provider.as_str() {
        "anthropic" => LLMProvider::Anthropic,
        "openai" => LLMProvider::OpenAI,
        "custom" => LLMProvider::Custom,  // Add this
        other => {
            return Err(ConfigError::UnknownProvider {
                provider: other.to_string(),
            })
        }
    };

    // Each variant delegates to its dedicated builder (see Step 8).
    let mut session_config = match provider {
        LLMProvider::Anthropic => {
            LLMSessionConfig::anthropic(&config.api_key, &config.model)
        }
        LLMProvider::OpenAI => {
            LLMSessionConfig::openai(&config.api_key, &config.model)
        }
        LLMProvider::Custom => {
            LLMSessionConfig::custom(&config.api_key, &config.model)
        }
    };

    // ... rest of configuration ...
}

Usage Example

After implementing, use your provider:

YAML Configuration

providers:
  - provider: custom
    api_key: your-api-key
    model: custom-model-v1
    system_prompt: "You are helpful."

default_provider: custom

Programmatic Usage

use agent_air::controller::LLMSessionConfig;

let config = LLMSessionConfig::custom("api-key", "model-name")
    .with_system_prompt("You are helpful.")
    .with_max_tokens(4096);

Direct Provider Usage

use agent_air::client::{LLMClient, CustomProvider};

let provider = CustomProvider::new(
    "api-key".to_string(),
    "model-name".to_string()
).with_base_url("https://custom-api.example.com".to_string());

let client = LLMClient::new(Box::new(provider))?;

let response = client.send_message(&messages, &options).await?;

Testing Your Provider

Create tests in src/client/providers/custom/mod.rs:

#[cfg(test)]
mod tests {
    use super::*;
    // mod.rs only imports Message and MessageOptions, so the test module
    // must pull in the content types it constructs itself.
    use crate::client::models::{ContentBlock, MessageRole};

    /// The constructor stores credentials and applies the default base URL.
    #[test]
    fn test_provider_creation() {
        let provider = CustomProvider::new(
            "test-key".to_string(),
            "test-model".to_string()
        );
        assert_eq!(provider.api_key, "test-key");
        assert_eq!(provider.model, "test-model");
    }

    /// A single user message serializes into a body containing its text
    /// and the model name.
    #[test]
    fn test_request_building() {
        let messages = vec![Message {
            role: MessageRole::User,
            content: vec![ContentBlock::Text("Hello".to_string())],
        }];

        let options = MessageOptions::default();
        // `?` is not allowed in a test fn returning `()`; use expect so a
        // serialization failure fails the test with a clear message.
        let body = types::build_request_body(&messages, &options, "model")
            .expect("request body should serialize");

        assert!(body.contains("Hello"));
        assert!(body.contains("model"));
    }
}