From 0f67cf6d71fc67d77bf974803458b72ebbb9cdde Mon Sep 17 00:00:00 2001 From: b0xtch Date: Fri, 19 Jul 2024 13:15:24 -0700 Subject: [PATCH] chore: add with_headers method --- async-openai/src/config.rs | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/async-openai/src/config.rs b/async-openai/src/config.rs index 91b3699a..230d2b10 100644 --- a/async-openai/src/config.rs +++ b/async-openai/src/config.rs @@ -1,5 +1,7 @@ //! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service. -use reqwest::header::{HeaderMap, AUTHORIZATION}; +use std::collections::HashMap; + +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION}; use secrecy::{ExposeSecret, Secret}; use serde::Deserialize; @@ -31,10 +33,20 @@ pub trait Config: Clone { pub struct OpenAIConfig { api_base: String, api_key: Secret, + #[serde(deserialize_with = "deserialize_header_map")] + headers: HashMap, org_id: String, project_id: String, } +fn deserialize_header_map<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let header_map: HashMap = HashMap::deserialize(deserializer)?; + Ok(header_map) +} + impl Default for OpenAIConfig { fn default() -> Self { Self { @@ -42,6 +54,7 @@ impl Default for OpenAIConfig { api_key: std::env::var("OPENAI_API_KEY") .unwrap_or_else(|_| "".to_string()) .into(), + headers: HashMap::new(), org_id: Default::default(), project_id: Default::default(), } @@ -78,6 +91,12 @@ impl OpenAIConfig { self } + /// Add custom headers to the existing headers + pub fn with_headers(mut self, headers: HashMap) -> Self { + self.headers.extend(headers); + self + } + pub fn org_id(&self) -> &str { &self.org_id } @@ -112,6 +131,13 @@ impl Config for OpenAIConfig { // Calls to the Assistants API require that you pass a Beta header headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap()); + headers.extend(self.headers.iter().map(|(k, v)| { + ( + HeaderName::from_bytes(k.as_bytes()).unwrap(), + HeaderValue::from_str(v).unwrap(), + ) + })); + headers }