diff --git a/async-openai/README.md b/async-openai/README.md
index 4cc594de..f5bd50b4 100644
--- a/async-openai/README.md
+++ b/async-openai/README.md
@@ -141,7 +141,24 @@ This can be useful in many scenarios:
 - To avoid verbose types.
 - To escape deserialization errors.
 
-Visit [examples/bring-your-own-type](https://github.com/64bit/async-openai/tree/main/examples/bring-your-own-type) directory to learn more.
+Visit the [examples/bring-your-own-type](https://github.com/64bit/async-openai/tree/main/examples/bring-your-own-type)
+directory to learn more.
+
+## Dynamic Dispatch for Different Providers
+
+For any struct that implements the `Config` trait, you can wrap it in a smart pointer and cast the pointer to a
+`dyn Config` trait object; the client can then accept any wrapped configuration type.
+
+For example,
+
+```rust
+use async_openai::{Client, config::Config, config::OpenAIConfig};
+
+let openai_config = OpenAIConfig::default();
+// You can use `std::sync::Arc` to wrap the config as well
+let config = Box::new(openai_config) as Box<dyn Config>;
+let client: Client<Box<dyn Config>> = Client::with_config(config);
+```
 
 ## Contributing
 
diff --git a/async-openai/src/config.rs b/async-openai/src/config.rs
index 4c5468c2..82ab043c 100644
--- a/async-openai/src/config.rs
+++ b/async-openai/src/config.rs
@@ -15,7 +15,7 @@ pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
 
 /// [crate::Client] relies on this for every API call on OpenAI
 /// or Azure OpenAI service
-pub trait Config: Clone {
+pub trait Config: Send + Sync {
     fn headers(&self) -> HeaderMap;
     fn url(&self, path: &str) -> String;
     fn query(&self) -> Vec<(&str, &str)>;
@@ -25,6 +25,32 @@ pub trait Config: Clone {
     fn api_key(&self) -> &SecretString;
 }
 
+/// Macro to implement the Config trait for smart pointer types holding `dyn Config` objects
+macro_rules! impl_config_for_ptr {
+    ($t:ty) => {
+        impl Config for $t {
+            fn headers(&self) -> HeaderMap {
+                self.as_ref().headers()
+            }
+            fn url(&self, path: &str) -> String {
+                self.as_ref().url(path)
+            }
+            fn query(&self) -> Vec<(&str, &str)> {
+                self.as_ref().query()
+            }
+            fn api_base(&self) -> &str {
+                self.as_ref().api_base()
+            }
+            fn api_key(&self) -> &SecretString {
+                self.as_ref().api_key()
+            }
+        }
+    };
+}
+
+impl_config_for_ptr!(Box<dyn Config>);
+impl_config_for_ptr!(std::sync::Arc<dyn Config>);
+
 /// Configuration for OpenAI API
 #[derive(Clone, Debug, Deserialize)]
 #[serde(default)]
@@ -211,3 +237,55 @@ impl Config for AzureConfig {
         vec![("api-version", &self.api_version)]
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::types::{
+        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
+    };
+    use crate::Client;
+    use std::sync::Arc;
+    #[test]
+    fn test_client_creation() {
+        unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
+        let openai_config = OpenAIConfig::default();
+        let config = Box::new(openai_config.clone()) as Box<dyn Config>;
+        let client = Client::with_config(config);
+        assert!(client.config().url("").ends_with("/v1"));
+
+        let config = Arc::new(openai_config) as Arc<dyn Config>;
+        let client = Client::with_config(config);
+        assert!(client.config().url("").ends_with("/v1"));
+        let cloned_client = client.clone();
+        assert!(cloned_client.config().url("").ends_with("/v1"));
+    }
+
+    async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
+        let _ = client.chat().create(CreateChatCompletionRequest {
+            model: "gpt-4o".to_string(),
+            messages: vec![ChatCompletionRequestMessage::User(
+                ChatCompletionRequestUserMessage {
+                    content: "Hello, world!".into(),
+                    ..Default::default()
+                },
+            )],
+            ..Default::default()
+        });
+    }
+
+    #[tokio::test]
+    async fn test_dynamic_dispatch() {
+        let openai_config = OpenAIConfig::default();
+        let azure_config = AzureConfig::default();
+
+        let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
+        let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);
+
+        let _ = dynamic_dispatch_compiles(&azure_client).await;
+        let _ = dynamic_dispatch_compiles(&oai_client).await;
+
+        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
+        let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
+    }
+}
diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs
index 1c290d61..6165069e 100644
--- a/async-openai/src/lib.rs
+++ b/async-openai/src/lib.rs
@@ -94,6 +94,22 @@
 //! # });
 //!```
 //!
+//! ## Dynamic Dispatch for Different Providers
+//!
+//! For any struct that implements the `Config` trait, you can wrap it in a smart pointer and cast the pointer to a
+//! `dyn Config` trait object; the client can then accept any wrapped configuration type.
+//!
+//! For example,
+//! ```
+//! use async_openai::{Client, config::Config, config::OpenAIConfig};
+//! unsafe { std::env::set_var("OPENAI_API_KEY", "only for doc test") }
+//!
+//! let openai_config = OpenAIConfig::default();
+//! // You can use `std::sync::Arc` to wrap the config as well
+//! let config = Box::new(openai_config) as Box<dyn Config>;
+//! let client: Client<Box<dyn Config>> = Client::with_config(config);
+//! ```
+//!
 //! ## Microsoft Azure
 //!
 //! ```
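For a sense of how the new trait objects compose, here is a minimal usage sketch of selecting a provider at runtime behind a single client type. The `build_client` helper and its `use_azure` flag are hypothetical application code; `OpenAIConfig::default()`, `AzureConfig::default()`, and `Client::with_config` come from the diff above.

```rust
use async_openai::{config::AzureConfig, config::Config, config::OpenAIConfig, Client};

// Hypothetical helper: both configs are erased to `Box<dyn Config>`,
// so callers hold a single `Client<Box<dyn Config>>` regardless of provider.
fn build_client(use_azure: bool) -> Client<Box<dyn Config>> {
    let config: Box<dyn Config> = if use_azure {
        Box::new(AzureConfig::default()) as Box<dyn Config>
    } else {
        Box::new(OpenAIConfig::default()) as Box<dyn Config>
    };
    Client::with_config(config)
}
```

Because `Config` is now `Send + Sync`, such a client can also be moved into `tokio::spawn`, which is what the `test_dynamic_dispatch` test above exercises.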