use crate::error::{MojenticError, Result};
use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
use crate::llm::gateways::openai_messages_adapter::{adapt_messages_to_openai, convert_tool_calls};
use crate::llm::gateways::openai_model_registry::{get_model_registry, ModelType};
use crate::llm::models::{LlmGatewayResponse, LlmMessage, LlmToolCall};
use crate::llm::tools::LlmTool;
use async_trait::async_trait;
use futures::stream::{Stream, StreamExt};
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
use tracing::{debug, info, warn};

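/// Configuration for the OpenAI gateway.
///
/// Defaults are drawn from the `OPENAI_API_KEY` and `OPENAI_API_ENDPOINT`
/// environment variables; `base_url` falls back to the public OpenAI endpoint.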
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    pub api_key: String,
    pub base_url: String,
    pub timeout: Option<std::time::Duration>,
}

impl Default for OpenAIConfig {
    fn default() -> Self {
        Self {
            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            base_url: std::env::var("OPENAI_API_ENDPOINT")
                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
            timeout: None,
        }
    }
}

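/// Gateway for OpenAI's chat completion, embedding, and model-listing APIs
/// over HTTP via [`reqwest`].
///
/// # Example
///
/// A minimal sketch (marked `ignore` since it needs a live endpoint and an
/// async context):
///
/// ```ignore
/// let gateway = OpenAIGateway::with_api_key("sk-...");
/// let models = gateway.get_available_models().await?;
/// ```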
pub struct OpenAIGateway {
    client: Client,
    config: OpenAIConfig,
}

impl OpenAIGateway {
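    /// Creates a gateway configured from the `OPENAI_API_KEY` and
    /// `OPENAI_API_ENDPOINT` environment variables.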
    pub fn new() -> Self {
        Self::with_config(OpenAIConfig::default())
    }

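    /// Creates a gateway from an explicit [`OpenAIConfig`].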
    pub fn with_config(config: OpenAIConfig) -> Self {
        let mut client_builder = Client::builder();

        if let Some(timeout) = config.timeout {
            client_builder = client_builder.timeout(timeout);
        }

        // Building the client should only fail on TLS or system configuration
        // problems, so surface that with a descriptive panic message.
        let client = client_builder
            .build()
            .expect("failed to build reqwest HTTP client");

        Self { client, config }
    }

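    /// Creates a gateway with the given API key and the default base URL.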
    pub fn with_api_key(api_key: impl Into<String>) -> Self {
        Self::with_config(OpenAIConfig {
            api_key: api_key.into(),
            ..Default::default()
        })
    }

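    /// Creates a gateway with the given API key and a custom base URL, e.g.
    /// for OpenAI-compatible servers or test doubles.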
    pub fn with_api_key_and_base_url(
        api_key: impl Into<String>,
        base_url: impl Into<String>,
    ) -> Self {
        Self::with_config(OpenAIConfig {
            api_key: api_key.into(),
            base_url: base_url.into(),
            ..Default::default()
        })
    }

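    /// Translates a [`CompletionConfig`] into OpenAI request parameters,
    /// accounting for per-model capabilities (e.g. reasoning models take
    /// `max_completion_tokens` instead of `max_tokens`).
    ///
    /// Returns the parameter map and whether the model supports tools.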
    fn adapt_parameters_for_model(
        &self,
        model: &str,
        config: &CompletionConfig,
    ) -> (HashMap<String, Value>, bool) {
        let registry = get_model_registry();
        let capabilities = registry.get_model_capabilities(model);

        let mut params = HashMap::new();

        debug!(
            model = model,
            model_type = ?capabilities.model_type,
            supports_tools = capabilities.supports_tools,
            supports_streaming = capabilities.supports_streaming,
            "Adapting parameters for model"
        );

        let max_tokens = if config.max_tokens > 0 {
            config.max_tokens
        } else if let Some(np) = config.num_predict {
            np as usize
        } else {
            16384
        };

        if capabilities.model_type == ModelType::Reasoning {
            params.insert("max_completion_tokens".to_string(), serde_json::json!(max_tokens));
        } else {
            params.insert("max_tokens".to_string(), serde_json::json!(max_tokens));
        }

        if capabilities.supports_temperature(config.temperature) {
            params.insert("temperature".to_string(), serde_json::json!(config.temperature));
        } else if capabilities.supported_temperatures.as_ref().is_some_and(|t| t.is_empty()) {
            warn!(
                model = model,
                requested_temperature = config.temperature,
                "Model does not support temperature parameter at all"
            );
        } else {
            warn!(
                model = model,
                requested_temperature = config.temperature,
                default_temperature = 1.0,
                "Model does not support requested temperature, using default"
            );
            params.insert("temperature".to_string(), serde_json::json!(1.0));
        }

        if let Some(top_p) = config.top_p {
            params.insert("top_p".to_string(), serde_json::json!(top_p));
        }

        if let Some(reasoning_effort) = config.reasoning_effort {
            if capabilities.model_type == ModelType::Reasoning {
                use crate::llm::gateway::ReasoningEffort;
                let effort_str = match reasoning_effort {
                    ReasoningEffort::Low => "low",
                    ReasoningEffort::Medium => "medium",
                    ReasoningEffort::High => "high",
                };
                params.insert("reasoning_effort".to_string(), serde_json::json!(effort_str));
            } else {
                warn!(
                    model = model,
                    "reasoning_effort specified but model is not a reasoning model, ignoring"
                );
            }
        }

        (params, capabilities.supports_tools)
    }

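    /// Splits `text` into chunks that fit within `chunk_size` tokens, using a
    /// rough heuristic of ~4 characters per token.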
    fn chunk_text(&self, text: &str, chunk_size: usize) -> Vec<String> {
        let chars: Vec<char> = text.chars().collect();

        // Rough heuristic: ~4 characters per token on average.
        let avg_chars_per_token = 4;
        let max_chars = chunk_size * avg_chars_per_token;

        if chars.len() <= max_chars {
            return vec![text.to_string()];
        }

        let mut chunks = Vec::new();
        let mut start = 0;

        while start < chars.len() {
            let end = std::cmp::min(start + max_chars, chars.len());
            let chunk: String = chars[start..end].iter().collect();
            chunks.push(chunk);
            start = end;
        }

        chunks
    }

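    /// Combines per-chunk embeddings into a single vector: each dimension is
    /// the weighted average of the chunk embeddings, and the result is
    /// re-normalized to unit length.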
    fn weighted_average_embeddings(&self, embeddings: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
        if embeddings.is_empty() {
            return vec![];
        }

        let dimension = embeddings[0].len();
        let total_weight: f32 = weights.iter().sum();

        let average: Vec<f32> = (0..dimension)
            .map(|dim_idx| {
                embeddings
                    .iter()
                    .zip(weights.iter())
                    .map(|(embedding, &weight)| {
                        embedding.get(dim_idx).unwrap_or(&0.0) * (weight / total_weight)
                    })
                    .sum()
            })
            .collect();

        // Normalize the averaged vector back to unit length.
        let norm: f32 = average.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            average.iter().map(|x| x / norm).collect()
        } else {
            average
        }
    }
}

impl Default for OpenAIGateway {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl LlmGateway for OpenAIGateway {
    async fn complete(
        &self,
        model: &str,
        messages: &[LlmMessage],
        tools: Option<&[Box<dyn LlmTool>]>,
        config: &CompletionConfig,
    ) -> Result<LlmGatewayResponse> {
        info!("Delegating to OpenAI for completion");
        debug!("Model: {}, Message count: {}", model, messages.len());

        let openai_messages = adapt_messages_to_openai(messages)?;
        let (adapted_params, supports_tools) = self.adapt_parameters_for_model(model, config);

        let mut body = serde_json::json!({
            "model": model,
            "messages": openai_messages,
        });

        for (key, value) in adapted_params {
            body[key] = value;
        }

        if let Some(tools) = tools {
            if supports_tools {
                let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
                body["tools"] = serde_json::to_value(tool_defs)?;
            } else {
                warn!(model = model, "Model does not support tools, ignoring tool configuration");
            }
        }

        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(MojenticError::GatewayError(format!(
                "OpenAI API error: {} - {}",
                status, error_text
            )));
        }

        let response_body: Value = response.json().await?;

        let content = response_body["choices"][0]["message"]["content"]
            .as_str()
            .map(String::from);

        let tool_calls =
            if let Some(calls) = response_body["choices"][0]["message"]["tool_calls"].as_array() {
                convert_tool_calls(calls)
            } else {
                vec![]
            };

        Ok(LlmGatewayResponse {
            content,
            object: None,
            tool_calls,
            thinking: None,
        })
    }

    async fn complete_json(
        &self,
        model: &str,
        messages: &[LlmMessage],
        schema: Value,
        config: &CompletionConfig,
    ) -> Result<Value> {
        info!("Requesting structured output from OpenAI");

        let openai_messages = adapt_messages_to_openai(messages)?;
        let (adapted_params, _) = self.adapt_parameters_for_model(model, config);

        let mut body = serde_json::json!({
            "model": model,
            "messages": openai_messages,
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "response",
                    "schema": schema
                }
            }
        });

        for (key, value) in adapted_params {
            body[key] = value;
        }

        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(MojenticError::GatewayError(format!(
                "OpenAI API error: {} - {}",
                status, error_text
            )));
        }

        let response_body: Value = response.json().await?;
        let content = response_body["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| MojenticError::GatewayError("No content in response".to_string()))?;

        let json_value: Value = serde_json::from_str(content)?;

        Ok(json_value)
    }

    async fn get_available_models(&self) -> Result<Vec<String>> {
        debug!("Fetching available OpenAI models");

        let response = self
            .client
            .get(format!("{}/models", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(MojenticError::GatewayError(format!(
                "Failed to get models: {}",
                response.status()
            )));
        }

        let body: Value = response.json().await?;

        let mut models = body["data"]
            .as_array()
            .ok_or_else(|| MojenticError::GatewayError("Invalid response format".to_string()))?
            .iter()
            .filter_map(|m| m["id"].as_str().map(String::from))
            .collect::<Vec<_>>();

        models.sort();
        Ok(models)
    }

    async fn calculate_embeddings(&self, text: &str, model: Option<&str>) -> Result<Vec<f32>> {
        let model = model.unwrap_or("text-embedding-3-large");
        debug!("Calculating embeddings with model: {}", model);

        let chunks = self.chunk_text(text, 8191);

        if chunks.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::new();
        let mut weights = Vec::new();

        for chunk in &chunks {
            let body = serde_json::json!({
                "model": model,
                "input": chunk
            });

            let response = self
                .client
                .post(format!("{}/embeddings", self.config.base_url))
                .header("Authorization", format!("Bearer {}", self.config.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(MojenticError::GatewayError(format!(
                    "Embeddings API error: {}",
                    response.status()
                )));
            }

            let response_body: Value = response.json().await?;

            let embedding: Vec<f32> = response_body["data"][0]["embedding"]
                .as_array()
                .ok_or_else(|| {
                    MojenticError::GatewayError("Invalid embeddings response".to_string())
                })?
                .iter()
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect();

            // Weight each chunk by its character length rather than by the
            // embedding dimension (which is identical for every chunk), so
            // longer chunks contribute proportionally to the combined vector.
            weights.push(chunk.chars().count() as f32);
            all_embeddings.push(embedding);
        }

        if all_embeddings.len() == 1 {
            return Ok(all_embeddings.remove(0));
        }

        Ok(self.weighted_average_embeddings(&all_embeddings, &weights))
    }

    fn complete_stream<'a>(
        &'a self,
        model: &'a str,
        messages: &'a [LlmMessage],
        tools: Option<&'a [Box<dyn LlmTool>]>,
        config: &'a CompletionConfig,
    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
        Box::pin(async_stream::stream! {
            info!("Starting OpenAI streaming completion");
            debug!("Model: {}, Message count: {}", model, messages.len());

            let registry = get_model_registry();
            let capabilities = registry.get_model_capabilities(model);
            if !capabilities.supports_streaming {
                yield Err(MojenticError::GatewayError(format!(
                    "Model {} does not support streaming",
                    model
                )));
                return;
            }

            let openai_messages = match adapt_messages_to_openai(messages) {
                Ok(msgs) => msgs,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let (adapted_params, supports_tools) = self.adapt_parameters_for_model(model, config);

            let mut body = serde_json::json!({
                "model": model,
                "messages": openai_messages,
                "stream": true
            });

            for (key, value) in adapted_params {
                body[key] = value;
            }

            if let Some(tools) = tools {
                if supports_tools {
                    let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
                    if let Ok(tools_value) = serde_json::to_value(tool_defs) {
                        body["tools"] = tools_value;
                    }
                } else {
                    warn!(model = model, "Model does not support tools, ignoring tool configuration");
                }
            }

            let response = match self
                .client
                .post(format!("{}/chat/completions", self.config.base_url))
                .header("Authorization", format!("Bearer {}", self.config.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
            {
                Ok(r) => r,
                Err(e) => {
                    yield Err(e.into());
                    return;
                }
            };

            if !response.status().is_success() {
                yield Err(MojenticError::GatewayError(format!(
                    "OpenAI API error: {}",
                    response.status()
                )));
                return;
            }

            let mut stream = response.bytes_stream();
            let mut buffer = String::new();

            // Tool calls arrive fragmented across many deltas, keyed by index;
            // accumulate them until the stream signals completion.
            let mut tool_calls_accumulator: HashMap<usize, ToolCallAccumulator> = HashMap::new();

            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(bytes) => {
                        if let Ok(text) = std::str::from_utf8(&bytes) {
                            buffer.push_str(text);

                            // The response body is Server-Sent Events: one "data: ..."
                            // payload per line, terminated by "data: [DONE]".
                            while let Some(line_end) = buffer.find('\n') {
                                let line = buffer[..line_end].trim().to_string();
                                buffer = buffer[line_end + 1..].to_string();

                                if line.is_empty() {
                                    continue;
                                }

                                let Some(data) = line.strip_prefix("data: ") else {
                                    continue;
                                };

                                if data == "[DONE]" {
                                    if !tool_calls_accumulator.is_empty() {
                                        let complete_tool_calls = build_complete_tool_calls(&tool_calls_accumulator);
                                        if !complete_tool_calls.is_empty() {
                                            yield Ok(StreamChunk::ToolCalls(complete_tool_calls));
                                        }
                                    }
                                    continue;
                                }

                                match serde_json::from_str::<Value>(data) {
                                    Ok(json) => {
                                        if let Some(choices) = json["choices"].as_array() {
                                            if choices.is_empty() {
                                                continue;
                                            }

                                            let delta = &choices[0]["delta"];
                                            let finish_reason = choices[0]["finish_reason"].as_str();

                                            if let Some(content) = delta["content"].as_str() {
                                                if !content.is_empty() {
                                                    yield Ok(StreamChunk::Content(content.to_string()));
                                                }
                                            }

                                            if let Some(tool_calls) = delta["tool_calls"].as_array() {
                                                for tc in tool_calls {
                                                    if let Some(index) = tc["index"].as_u64() {
                                                        let index = index as usize;

                                                        let acc = tool_calls_accumulator.entry(index).or_insert_with(|| ToolCallAccumulator {
                                                            id: None,
                                                            name: None,
                                                            arguments: String::new(),
                                                        });

                                                        if let Some(id) = tc["id"].as_str() {
                                                            acc.id = Some(id.to_string());
                                                        }

                                                        if let Some(name) = tc["function"]["name"].as_str() {
                                                            acc.name = Some(name.to_string());
                                                        }

                                                        if let Some(args) = tc["function"]["arguments"].as_str() {
                                                            acc.arguments.push_str(args);
                                                        }
                                                    }
                                                }
                                            }

                                            if finish_reason == Some("tool_calls") && !tool_calls_accumulator.is_empty() {
                                                let complete_tool_calls = build_complete_tool_calls(&tool_calls_accumulator);
                                                if !complete_tool_calls.is_empty() {
                                                    yield Ok(StreamChunk::ToolCalls(complete_tool_calls));
                                                }
                                                tool_calls_accumulator.clear();
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        warn!("Failed to parse streaming chunk: {}", e);
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(e.into());
                        return;
                    }
                }
            }
        })
    }
}

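/// Accumulates one streamed tool call: OpenAI delivers the id, name, and
/// argument JSON in fragments across multiple deltas.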
struct ToolCallAccumulator {
    id: Option<String>,
    name: Option<String>,
    arguments: String,
}

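/// Converts accumulated fragments into [`LlmToolCall`]s, ordered by stream
/// index; entries that never received a function name are dropped.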
fn build_complete_tool_calls(
    accumulators: &HashMap<usize, ToolCallAccumulator>,
) -> Vec<LlmToolCall> {
    let mut indices: Vec<_> = accumulators.keys().collect();
    indices.sort();

    indices
        .iter()
        .filter_map(|&&index| {
            let acc = accumulators.get(&index)?;
            let name = acc.name.clone()?;

            // Malformed argument JSON degrades to an empty argument map rather
            // than failing the whole call.
            let arguments: HashMap<String, Value> =
                serde_json::from_str(&acc.arguments).unwrap_or_default();

            Some(LlmToolCall {
                id: acc.id.clone(),
                name,
                arguments,
            })
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

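    // NOTE: the env-var tests below mutate process-wide state and can race
    // when the harness runs them in parallel with other config tests.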
    #[test]
    fn test_openai_config_default() {
        std::env::remove_var("OPENAI_API_KEY");
        std::env::remove_var("OPENAI_API_ENDPOINT");
        let config = OpenAIConfig::default();
        assert_eq!(config.api_key, "");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert!(config.timeout.is_none());
    }

    #[test]
    fn test_openai_config_from_env() {
        std::env::set_var("OPENAI_API_KEY", "test-key");
        std::env::set_var("OPENAI_API_ENDPOINT", "https://custom.openai.com");
        let config = OpenAIConfig::default();
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.base_url, "https://custom.openai.com");
        std::env::remove_var("OPENAI_API_KEY");
        std::env::remove_var("OPENAI_API_ENDPOINT");
    }

    #[test]
    fn test_gateway_new() {
        let gateway = OpenAIGateway::new();
        assert_eq!(gateway.config.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_gateway_with_api_key() {
        let gateway = OpenAIGateway::with_api_key("my-api-key");
        assert_eq!(gateway.config.api_key, "my-api-key");
    }

    #[test]
    fn test_gateway_with_api_key_and_base_url() {
        let gateway = OpenAIGateway::with_api_key_and_base_url("key", "https://custom.com");
        assert_eq!(gateway.config.api_key, "key");
        assert_eq!(gateway.config.base_url, "https://custom.com");
    }

    #[test]
    fn test_gateway_default() {
        let gateway = OpenAIGateway::default();
        assert_eq!(gateway.config.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_chunk_text_short() {
        let gateway = OpenAIGateway::new();
        let chunks = gateway.chunk_text("Hello world", 100);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "Hello world");
    }

    #[test]
    fn test_chunk_text_long() {
        let gateway = OpenAIGateway::new();
        let long_text = "a".repeat(50000);
        let chunks = gateway.chunk_text(&long_text, 100);
        assert!(chunks.len() > 1);
    }

    #[test]
    fn test_weighted_average_embeddings_single() {
        let gateway = OpenAIGateway::new();
        let embeddings = vec![vec![1.0, 2.0, 3.0]];
        let weights = vec![1.0];
        let result = gateway.weighted_average_embeddings(&embeddings, &weights);

        // A single embedding should come back L2-normalized.
        let norm = (1.0_f32 + 4.0 + 9.0).sqrt();
        assert!((result[0] - 1.0 / norm).abs() < 0.001);
        assert!((result[1] - 2.0 / norm).abs() < 0.001);
        assert!((result[2] - 3.0 / norm).abs() < 0.001);
    }

    #[test]
    fn test_weighted_average_embeddings_multiple() {
        let gateway = OpenAIGateway::new();
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let weights = vec![1.0, 1.0];
        let result = gateway.weighted_average_embeddings(&embeddings, &weights);

        // Equal weights average to (0.5, 0.5), which normalizes to 1/sqrt(2) each.
        let expected = 1.0 / (2.0_f32).sqrt();
        assert!((result[0] - expected).abs() < 0.001);
        assert!((result[1] - expected).abs() < 0.001);
    }

    #[test]
    fn test_weighted_average_embeddings_empty() {
        let gateway = OpenAIGateway::new();
        let embeddings: Vec<Vec<f32>> = vec![];
        let weights: Vec<f32> = vec![];
        let result = gateway.weighted_average_embeddings(&embeddings, &weights);
        assert!(result.is_empty());
    }

    #[test]
    fn test_build_complete_tool_calls() {
        let mut accumulators = HashMap::new();
        accumulators.insert(
            0,
            ToolCallAccumulator {
                id: Some("call_123".to_string()),
                name: Some("get_weather".to_string()),
                arguments: r#"{"location": "NYC"}"#.to_string(),
            },
        );
        accumulators.insert(
            1,
            ToolCallAccumulator {
                id: Some("call_456".to_string()),
                name: Some("search".to_string()),
                arguments: r#"{"query": "test"}"#.to_string(),
            },
        );

        let result = build_complete_tool_calls(&accumulators);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].id, Some("call_123".to_string()));
        assert_eq!(result[0].name, "get_weather");
        assert_eq!(result[1].id, Some("call_456".to_string()));
        assert_eq!(result[1].name, "search");
    }

    #[test]
    fn test_build_complete_tool_calls_missing_name() {
        let mut accumulators = HashMap::new();
        accumulators.insert(
            0,
            ToolCallAccumulator {
                id: Some("call_123".to_string()),
                name: None,
                arguments: r#"{}"#.to_string(),
            },
        );

        let result = build_complete_tool_calls(&accumulators);
        assert!(result.is_empty());
    }

    #[test]
    fn test_adapt_parameters_chat_model() {
        let gateway = OpenAIGateway::new();
        let config = CompletionConfig {
            temperature: 0.7,
            max_tokens: 1000,
            ..Default::default()
        };

        let (params, supports_tools) = gateway.adapt_parameters_for_model("gpt-4", &config);

        assert!(params.contains_key("max_tokens"));
        assert!(!params.contains_key("max_completion_tokens"));
        assert!(supports_tools);
    }

    #[test]
    fn test_adapt_parameters_reasoning_model() {
        let gateway = OpenAIGateway::new();
        let config = CompletionConfig {
            temperature: 0.7,
            max_tokens: 1000,
            ..Default::default()
        };

        let (params, supports_tools) = gateway.adapt_parameters_for_model("o1", &config);

        assert!(!params.contains_key("max_tokens"));
        assert!(params.contains_key("max_completion_tokens"));
        assert!(supports_tools);
    }

    #[tokio::test]
    async fn test_complete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .with_status(200)
            .with_body(r#"{"choices":[{"message":{"role":"assistant","content":"Hello!"}}]}"#)
            .create();

        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
        let messages = vec![LlmMessage::user("Hi")];
        let config = CompletionConfig::default();

        let result = gateway.complete("gpt-4", &messages, None, &config).await;

        mock.assert();
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.content, Some("Hello!".to_string()));
    }

    #[tokio::test]
    async fn test_complete_with_tool_calls() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .with_status(200)
            .with_body(r#"{"choices":[{"message":{"role":"assistant","content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\": \"NYC\"}"}}]}}]}"#)
            .create();

        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
        let messages = vec![LlmMessage::user("Weather?")];
        let config = CompletionConfig::default();

        let result = gateway.complete("gpt-4", &messages, None, &config).await;

        mock.assert();
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.tool_calls.len(), 1);
        assert_eq!(response.tool_calls[0].name, "get_weather");
    }

    #[tokio::test]
    async fn test_complete_error() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .with_status(401)
            .with_body("Unauthorized")
            .create();

        let gateway = OpenAIGateway::with_api_key_and_base_url("bad-key", server.url());
        let messages = vec![LlmMessage::user("Hi")];
        let config = CompletionConfig::default();

        let result = gateway.complete("gpt-4", &messages, None, &config).await;

        mock.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_complete_json() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .with_status(200)
            .with_body(
                r#"{"choices":[{"message":{"content":"{\"name\":\"test\",\"value\":42}"}}]}"#,
            )
            .create();

        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
        let messages = vec![LlmMessage::user("Generate JSON")];
        let schema = serde_json::json!({"type": "object"});
        let config = CompletionConfig::default();

        let result = gateway.complete_json("gpt-4", &messages, schema, &config).await;

        mock.assert();
        assert!(result.is_ok());
        let json = result.unwrap();
        assert_eq!(json["name"], "test");
        assert_eq!(json["value"], 42);
    }

    #[tokio::test]
    async fn test_get_available_models() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models")
            .with_status(200)
            .with_body(r#"{"data":[{"id":"gpt-4"},{"id":"gpt-3.5-turbo"}]}"#)
            .create();

        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
        let result = gateway.get_available_models().await;

        mock.assert();
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 2);
        assert_eq!(models[0], "gpt-3.5-turbo");
        assert_eq!(models[1], "gpt-4");
    }

    #[tokio::test]
    async fn test_calculate_embeddings() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/embeddings")
            .with_status(200)
            .with_body(r#"{"data":[{"embedding":[0.1,0.2,0.3,0.4]}]}"#)
            .create();

        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
        let result = gateway.calculate_embeddings("test text", None).await;

        mock.assert();
        assert!(result.is_ok());
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 4);
    }

    #[tokio::test]
    async fn test_calculate_embeddings_custom_model() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/embeddings")
            .match_body(mockito::Matcher::JsonString(
                r#"{"model":"text-embedding-3-small","input":"test"}"#.to_string(),
            ))
            .with_status(200)
            .with_body(r#"{"data":[{"embedding":[0.5,0.6]}]}"#)
            .create();

        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
        let result = gateway.calculate_embeddings("test", Some("text-embedding-3-small")).await;

        mock.assert();
        assert!(result.is_ok());
    }
}