mojentic/llm/gateways/
openai.rs

1//! OpenAI Gateway for LLM interactions.
2//!
3//! This module provides a gateway for interacting with OpenAI's API,
4//! including chat completions, streaming, and embeddings.
5
6use crate::error::{MojenticError, Result};
7use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
8use crate::llm::gateways::openai_messages_adapter::{adapt_messages_to_openai, convert_tool_calls};
9use crate::llm::gateways::openai_model_registry::{get_model_registry, ModelType};
10use crate::llm::models::{LlmGatewayResponse, LlmMessage, LlmToolCall};
11use crate::llm::tools::LlmTool;
12use async_trait::async_trait;
13use futures::stream::{Stream, StreamExt};
14use reqwest::Client;
15use serde_json::Value;
16use std::collections::HashMap;
17use std::pin::Pin;
18use tracing::{debug, info, warn};
19
/// Configuration for connecting to OpenAI API.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// API key, sent as a `Bearer` token in the `Authorization` header of every request.
    pub api_key: String,
    /// Base URL of the API (no trailing slash), e.g. `https://api.openai.com/v1`;
    /// endpoint paths such as `/chat/completions` are appended to it.
    pub base_url: String,
    /// Optional request timeout applied to the underlying HTTP client;
    /// `None` leaves the client's default behavior.
    pub timeout: Option<std::time::Duration>,
}
27
28impl Default for OpenAIConfig {
29    fn default() -> Self {
30        Self {
31            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
32            base_url: std::env::var("OPENAI_API_ENDPOINT")
33                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
34            timeout: None,
35        }
36    }
37}
38
/// Gateway for OpenAI LLM service.
///
/// This gateway provides access to OpenAI models through their API,
/// supporting text generation, structured output, tool calling, and embeddings.
pub struct OpenAIGateway {
    /// Shared HTTP client used for all API requests.
    client: Client,
    /// Connection settings: API key, base URL, and optional timeout.
    config: OpenAIConfig,
}
47
48impl OpenAIGateway {
49    /// Create a new OpenAI gateway with default configuration.
50    pub fn new() -> Self {
51        Self::with_config(OpenAIConfig::default())
52    }
53
54    /// Create a new OpenAI gateway with custom configuration.
55    pub fn with_config(config: OpenAIConfig) -> Self {
56        let mut client_builder = Client::builder();
57
58        if let Some(timeout) = config.timeout {
59            client_builder = client_builder.timeout(timeout);
60        }
61
62        let client = client_builder.build().unwrap();
63
64        Self { client, config }
65    }
66
67    /// Create gateway with custom API key.
68    pub fn with_api_key(api_key: impl Into<String>) -> Self {
69        Self::with_config(OpenAIConfig {
70            api_key: api_key.into(),
71            ..Default::default()
72        })
73    }
74
75    /// Create gateway with custom API key and base URL.
76    pub fn with_api_key_and_base_url(
77        api_key: impl Into<String>,
78        base_url: impl Into<String>,
79    ) -> Self {
80        Self::with_config(OpenAIConfig {
81            api_key: api_key.into(),
82            base_url: base_url.into(),
83            ..Default::default()
84        })
85    }
86
87    /// Adapt parameters based on model type and capabilities.
88    fn adapt_parameters_for_model(
89        &self,
90        model: &str,
91        config: &CompletionConfig,
92    ) -> (HashMap<String, Value>, bool) {
93        let registry = get_model_registry();
94        let capabilities = registry.get_model_capabilities(model);
95
96        let mut params = HashMap::new();
97
98        debug!(
99            model = model,
100            model_type = ?capabilities.model_type,
101            supports_tools = capabilities.supports_tools,
102            supports_streaming = capabilities.supports_streaming,
103            "Adapting parameters for model"
104        );
105
106        // Handle token limit parameter conversion
107        let max_tokens = if config.max_tokens > 0 {
108            config.max_tokens
109        } else if let Some(np) = config.num_predict {
110            np as usize
111        } else {
112            16384
113        };
114
115        if capabilities.model_type == ModelType::Reasoning {
116            params.insert("max_completion_tokens".to_string(), serde_json::json!(max_tokens));
117        } else {
118            params.insert("max_tokens".to_string(), serde_json::json!(max_tokens));
119        }
120
121        // Handle temperature restrictions
122        if capabilities.supports_temperature(config.temperature) {
123            params.insert("temperature".to_string(), serde_json::json!(config.temperature));
124        } else if capabilities.supported_temperatures.as_ref().is_some_and(|t| t.is_empty()) {
125            // Model doesn't support temperature at all - don't add it
126            warn!(
127                model = model,
128                requested_temperature = config.temperature,
129                "Model does not support temperature parameter at all"
130            );
131        } else {
132            // Use default temperature
133            warn!(
134                model = model,
135                requested_temperature = config.temperature,
136                default_temperature = 1.0,
137                "Model does not support requested temperature, using default"
138            );
139            params.insert("temperature".to_string(), serde_json::json!(1.0));
140        }
141
142        // Add optional sampling parameters
143        if let Some(top_p) = config.top_p {
144            params.insert("top_p".to_string(), serde_json::json!(top_p));
145        }
146
147        // Handle reasoning effort for reasoning models
148        if let Some(reasoning_effort) = config.reasoning_effort {
149            if capabilities.model_type == ModelType::Reasoning {
150                use crate::llm::gateway::ReasoningEffort;
151                let effort_str = match reasoning_effort {
152                    ReasoningEffort::Low => "low",
153                    ReasoningEffort::Medium => "medium",
154                    ReasoningEffort::High => "high",
155                };
156                params.insert("reasoning_effort".to_string(), serde_json::json!(effort_str));
157            } else {
158                warn!(
159                    model = model,
160                    "reasoning_effort specified but model is not a reasoning model, ignoring"
161                );
162            }
163        }
164
165        (params, capabilities.supports_tools)
166    }
167
168    /// Chunk tokens for embedding calculation.
169    fn chunk_text(&self, text: &str, chunk_size: usize) -> Vec<String> {
170        // Simple character-based chunking as a fallback
171        // In production, you'd use a proper tokenizer
172        let chars: Vec<char> = text.chars().collect();
173        let avg_chars_per_token = 4; // Rough estimate
174        let max_chars = chunk_size * avg_chars_per_token;
175
176        if chars.len() <= max_chars {
177            return vec![text.to_string()];
178        }
179
180        let mut chunks = Vec::new();
181        let mut start = 0;
182
183        while start < chars.len() {
184            let end = std::cmp::min(start + max_chars, chars.len());
185            let chunk: String = chars[start..end].iter().collect();
186            chunks.push(chunk);
187            start = end;
188        }
189
190        chunks
191    }
192
193    /// Calculate weighted average of embeddings.
194    fn weighted_average_embeddings(&self, embeddings: &[Vec<f32>], weights: &[f32]) -> Vec<f32> {
195        if embeddings.is_empty() {
196            return vec![];
197        }
198
199        let dimension = embeddings[0].len();
200        let total_weight: f32 = weights.iter().sum();
201
202        // Build weighted sum for each dimension
203        let average: Vec<f32> = (0..dimension)
204            .map(|dim_idx| {
205                embeddings
206                    .iter()
207                    .zip(weights.iter())
208                    .map(|(embedding, &weight)| {
209                        embedding.get(dim_idx).unwrap_or(&0.0) * (weight / total_weight)
210                    })
211                    .sum()
212            })
213            .collect();
214
215        // Normalize
216        let norm: f32 = average.iter().map(|x| x * x).sum::<f32>().sqrt();
217        if norm > 0.0 {
218            average.iter().map(|x| x / norm).collect()
219        } else {
220            average
221        }
222    }
223}
224
225impl Default for OpenAIGateway {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231#[async_trait]
232impl LlmGateway for OpenAIGateway {
233    async fn complete(
234        &self,
235        model: &str,
236        messages: &[LlmMessage],
237        tools: Option<&[Box<dyn LlmTool>]>,
238        config: &CompletionConfig,
239    ) -> Result<LlmGatewayResponse> {
240        info!("Delegating to OpenAI for completion");
241        debug!("Model: {}, Message count: {}", model, messages.len());
242
243        let openai_messages = adapt_messages_to_openai(messages)?;
244        let (adapted_params, supports_tools) = self.adapt_parameters_for_model(model, config);
245
246        let mut body = serde_json::json!({
247            "model": model,
248            "messages": openai_messages,
249        });
250
251        // Add adapted parameters
252        for (key, value) in adapted_params {
253            body[key] = value;
254        }
255
256        // Add tools if provided and supported
257        if let Some(tools) = tools {
258            if supports_tools {
259                let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
260                body["tools"] = serde_json::to_value(tool_defs)?;
261            } else {
262                warn!(model = model, "Model does not support tools, ignoring tool configuration");
263            }
264        }
265
266        // Make API request
267        let response = self
268            .client
269            .post(format!("{}/chat/completions", self.config.base_url))
270            .header("Authorization", format!("Bearer {}", self.config.api_key))
271            .header("Content-Type", "application/json")
272            .json(&body)
273            .send()
274            .await?;
275
276        if !response.status().is_success() {
277            let status = response.status();
278            let error_text = response.text().await.unwrap_or_default();
279            return Err(MojenticError::GatewayError(format!(
280                "OpenAI API error: {} - {}",
281                status, error_text
282            )));
283        }
284
285        let response_body: Value = response.json().await?;
286
287        // Parse content
288        let content = response_body["choices"][0]["message"]["content"].as_str().map(String::from);
289
290        // Parse tool calls if present
291        let tool_calls =
292            if let Some(calls) = response_body["choices"][0]["message"]["tool_calls"].as_array() {
293                convert_tool_calls(calls)
294            } else {
295                vec![]
296            };
297
298        Ok(LlmGatewayResponse {
299            content,
300            object: None,
301            tool_calls,
302            thinking: None,
303        })
304    }
305
306    async fn complete_json(
307        &self,
308        model: &str,
309        messages: &[LlmMessage],
310        schema: Value,
311        config: &CompletionConfig,
312    ) -> Result<Value> {
313        info!("Requesting structured output from OpenAI");
314
315        let openai_messages = adapt_messages_to_openai(messages)?;
316        let (adapted_params, _) = self.adapt_parameters_for_model(model, config);
317
318        let mut body = serde_json::json!({
319            "model": model,
320            "messages": openai_messages,
321            "response_format": {
322                "type": "json_schema",
323                "json_schema": {
324                    "name": "response",
325                    "schema": schema
326                }
327            }
328        });
329
330        // Add adapted parameters
331        for (key, value) in adapted_params {
332            body[key] = value;
333        }
334
335        let response = self
336            .client
337            .post(format!("{}/chat/completions", self.config.base_url))
338            .header("Authorization", format!("Bearer {}", self.config.api_key))
339            .header("Content-Type", "application/json")
340            .json(&body)
341            .send()
342            .await?;
343
344        if !response.status().is_success() {
345            let status = response.status();
346            let error_text = response.text().await.unwrap_or_default();
347            return Err(MojenticError::GatewayError(format!(
348                "OpenAI API error: {} - {}",
349                status, error_text
350            )));
351        }
352
353        let response_body: Value = response.json().await?;
354        let content = response_body["choices"][0]["message"]["content"]
355            .as_str()
356            .ok_or_else(|| MojenticError::GatewayError("No content in response".to_string()))?;
357
358        // Parse the JSON response
359        let json_value: Value = serde_json::from_str(content)?;
360
361        Ok(json_value)
362    }
363
364    async fn get_available_models(&self) -> Result<Vec<String>> {
365        debug!("Fetching available OpenAI models");
366
367        let response = self
368            .client
369            .get(format!("{}/models", self.config.base_url))
370            .header("Authorization", format!("Bearer {}", self.config.api_key))
371            .send()
372            .await?;
373
374        if !response.status().is_success() {
375            return Err(MojenticError::GatewayError(format!(
376                "Failed to get models: {}",
377                response.status()
378            )));
379        }
380
381        let body: Value = response.json().await?;
382
383        let mut models = body["data"]
384            .as_array()
385            .ok_or_else(|| MojenticError::GatewayError("Invalid response format".to_string()))?
386            .iter()
387            .filter_map(|m| m["id"].as_str().map(String::from))
388            .collect::<Vec<_>>();
389
390        models.sort();
391        Ok(models)
392    }
393
394    async fn calculate_embeddings(&self, text: &str, model: Option<&str>) -> Result<Vec<f32>> {
395        let model = model.unwrap_or("text-embedding-3-large");
396        debug!("Calculating embeddings with model: {}", model);
397
398        // Chunk the text to handle token limits
399        let chunks = self.chunk_text(text, 8191);
400
401        if chunks.is_empty() {
402            return Ok(vec![]);
403        }
404
405        let mut all_embeddings = Vec::new();
406        let mut weights = Vec::new();
407
408        for chunk in &chunks {
409            let body = serde_json::json!({
410                "model": model,
411                "input": chunk
412            });
413
414            let response = self
415                .client
416                .post(format!("{}/embeddings", self.config.base_url))
417                .header("Authorization", format!("Bearer {}", self.config.api_key))
418                .header("Content-Type", "application/json")
419                .json(&body)
420                .send()
421                .await?;
422
423            if !response.status().is_success() {
424                return Err(MojenticError::GatewayError(format!(
425                    "Embeddings API error: {}",
426                    response.status()
427                )));
428            }
429
430            let response_body: Value = response.json().await?;
431
432            let embedding: Vec<f32> = response_body["data"][0]["embedding"]
433                .as_array()
434                .ok_or_else(|| {
435                    MojenticError::GatewayError("Invalid embeddings response".to_string())
436                })?
437                .iter()
438                .filter_map(|v| v.as_f64().map(|f| f as f32))
439                .collect();
440
441            weights.push(embedding.len() as f32);
442            all_embeddings.push(embedding);
443        }
444
445        // If only one chunk, return it directly
446        if all_embeddings.len() == 1 {
447            return Ok(all_embeddings.remove(0));
448        }
449
450        // Calculate weighted average
451        Ok(self.weighted_average_embeddings(&all_embeddings, &weights))
452    }
453
    /// Stream a chat completion as incremental [`StreamChunk`] items.
    ///
    /// Content deltas are yielded as soon as they arrive. Tool-call deltas
    /// arrive fragmented across SSE chunks, so they are accumulated per
    /// index and yielded as one `StreamChunk::ToolCalls` when the stream
    /// reports `finish_reason == "tool_calls"` or at the `[DONE]` sentinel.
    /// Errors are surfaced as stream items rather than a `Result` return.
    fn complete_stream<'a>(
        &'a self,
        model: &'a str,
        messages: &'a [LlmMessage],
        tools: Option<&'a [Box<dyn LlmTool>]>,
        config: &'a CompletionConfig,
    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
        Box::pin(async_stream::stream! {
            info!("Starting OpenAI streaming completion");
            debug!("Model: {}, Message count: {}", model, messages.len());

            // Check if model supports streaming
            let registry = get_model_registry();
            let capabilities = registry.get_model_capabilities(model);
            if !capabilities.supports_streaming {
                yield Err(MojenticError::GatewayError(format!(
                    "Model {} does not support streaming",
                    model
                )));
                return;
            }

            // Message adaptation can fail; surface the error as the only
            // stream item and end the stream.
            let openai_messages = match adapt_messages_to_openai(messages) {
                Ok(msgs) => msgs,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let (adapted_params, supports_tools) = self.adapt_parameters_for_model(model, config);

            let mut body = serde_json::json!({
                "model": model,
                "messages": openai_messages,
                "stream": true
            });

            // Add adapted parameters
            for (key, value) in adapted_params {
                body[key] = value;
            }

            // Add tools if provided and supported
            if let Some(tools) = tools {
                if supports_tools {
                    let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
                    if let Ok(tools_value) = serde_json::to_value(tool_defs) {
                        body["tools"] = tools_value;
                    }
                }
            }

            // Make streaming API request
            let response = match self
                .client
                .post(format!("{}/chat/completions", self.config.base_url))
                .header("Authorization", format!("Bearer {}", self.config.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
            {
                Ok(r) => r,
                Err(e) => {
                    yield Err(e.into());
                    return;
                }
            };

            if !response.status().is_success() {
                yield Err(MojenticError::GatewayError(format!(
                    "OpenAI API error: {}",
                    response.status()
                )));
                return;
            }

            // Process SSE stream
            let mut stream = response.bytes_stream();
            let mut buffer = String::new();

            // Accumulate tool calls as they stream in
            let mut tool_calls_accumulator: HashMap<usize, ToolCallAccumulator> = HashMap::new();

            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(bytes) => {
                        // NOTE(review): a network chunk that splits a UTF-8
                        // character across byte boundaries fails this check and
                        // is dropped entirely — TODO confirm this is acceptable.
                        if let Ok(text) = std::str::from_utf8(&bytes) {
                            buffer.push_str(text);

                            // Process complete SSE lines
                            while let Some(line_end) = buffer.find('\n') {
                                let line = buffer[..line_end].trim().to_string();
                                buffer = buffer[line_end + 1..].to_string();

                                // SSE data lines look like `data: {json}`; skip anything else.
                                if line.is_empty() || !line.starts_with("data: ") {
                                    continue;
                                }

                                let data = line.strip_prefix("data: ").unwrap();

                                if data == "[DONE]" {
                                    // Final chunk - yield accumulated tool calls if any
                                    if !tool_calls_accumulator.is_empty() {
                                        let complete_tool_calls = build_complete_tool_calls(&tool_calls_accumulator);
                                        if !complete_tool_calls.is_empty() {
                                            yield Ok(StreamChunk::ToolCalls(complete_tool_calls));
                                        }
                                    }
                                    continue;
                                }

                                // Parse JSON data
                                match serde_json::from_str::<Value>(data) {
                                    Ok(json) => {
                                        if let Some(choices) = json["choices"].as_array() {
                                            if choices.is_empty() {
                                                continue;
                                            }

                                            let delta = &choices[0]["delta"];
                                            let finish_reason = choices[0]["finish_reason"].as_str();

                                            // Yield content chunks
                                            if let Some(content) = delta["content"].as_str() {
                                                if !content.is_empty() {
                                                    yield Ok(StreamChunk::Content(content.to_string()));
                                                }
                                            }

                                            // Accumulate tool call chunks
                                            if let Some(tool_calls) = delta["tool_calls"].as_array() {
                                                for tc in tool_calls {
                                                    if let Some(index) = tc["index"].as_u64() {
                                                        let index = index as usize;

                                                        // Initialize accumulator if needed
                                                        let acc = tool_calls_accumulator.entry(index).or_insert_with(|| ToolCallAccumulator {
                                                            id: None,
                                                            name: None,
                                                            arguments: String::new(),
                                                        });

                                                        // First chunk has id
                                                        if let Some(id) = tc["id"].as_str() {
                                                            acc.id = Some(id.to_string());
                                                        }

                                                        // First chunk has function name
                                                        if let Some(name) = tc["function"]["name"].as_str() {
                                                            acc.name = Some(name.to_string());
                                                        }

                                                        // All chunks may have argument fragments
                                                        if let Some(args) = tc["function"]["arguments"].as_str() {
                                                            acc.arguments.push_str(args);
                                                        }
                                                    }
                                                }
                                            }

                                            // When stream completes with tool_calls, yield accumulated tool calls
                                            if finish_reason == Some("tool_calls") && !tool_calls_accumulator.is_empty() {
                                                let complete_tool_calls = build_complete_tool_calls(&tool_calls_accumulator);
                                                if !complete_tool_calls.is_empty() {
                                                    yield Ok(StreamChunk::ToolCalls(complete_tool_calls));
                                                }
                                                // Clear so the `[DONE]` handler does not yield them twice.
                                                tool_calls_accumulator.clear();
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        // Malformed SSE payloads are logged and skipped, not fatal.
                                        warn!("Failed to parse streaming chunk: {}", e);
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        // Transport error mid-stream ends the stream.
                        yield Err(e.into());
                        return;
                    }
                }
            }
        })
    }
641}
642
/// Accumulator for streaming tool calls.
///
/// OpenAI streams a single tool call across several SSE deltas: early
/// deltas carry the `id` and function name, later deltas append JSON
/// argument fragments. One accumulator collects the pieces for one
/// tool-call index.
struct ToolCallAccumulator {
    /// Tool call id, set by the first delta that carries one.
    id: Option<String>,
    /// Function name, set by the first delta that carries one.
    name: Option<String>,
    /// Concatenation of all argument fragments seen so far (a JSON string).
    arguments: String,
}
649
650/// Build complete tool calls from accumulators.
651fn build_complete_tool_calls(
652    accumulators: &HashMap<usize, ToolCallAccumulator>,
653) -> Vec<LlmToolCall> {
654    let mut indices: Vec<_> = accumulators.keys().collect();
655    indices.sort();
656
657    indices
658        .iter()
659        .filter_map(|&&index| {
660            let acc = accumulators.get(&index)?;
661            let name = acc.name.clone()?;
662
663            // Parse arguments
664            let arguments: HashMap<String, Value> =
665                serde_json::from_str(&acc.arguments).unwrap_or_default();
666
667            Some(LlmToolCall {
668                id: acc.id.clone(),
669                name,
670                arguments,
671            })
672        })
673        .collect()
674}
675
676#[cfg(test)]
677mod tests {
678    use super::*;
679
680    #[test]
681    fn test_openai_config_default() {
682        std::env::remove_var("OPENAI_API_KEY");
683        std::env::remove_var("OPENAI_API_ENDPOINT");
684        let config = OpenAIConfig::default();
685        assert_eq!(config.api_key, "");
686        assert_eq!(config.base_url, "https://api.openai.com/v1");
687        assert!(config.timeout.is_none());
688    }
689
690    #[test]
691    fn test_openai_config_from_env() {
692        std::env::set_var("OPENAI_API_KEY", "test-key");
693        std::env::set_var("OPENAI_API_ENDPOINT", "https://custom.openai.com");
694        let config = OpenAIConfig::default();
695        assert_eq!(config.api_key, "test-key");
696        assert_eq!(config.base_url, "https://custom.openai.com");
697        std::env::remove_var("OPENAI_API_KEY");
698        std::env::remove_var("OPENAI_API_ENDPOINT");
699    }
700
701    #[test]
702    fn test_gateway_new() {
703        let gateway = OpenAIGateway::new();
704        assert_eq!(gateway.config.base_url, "https://api.openai.com/v1");
705    }
706
707    #[test]
708    fn test_gateway_with_api_key() {
709        let gateway = OpenAIGateway::with_api_key("my-api-key");
710        assert_eq!(gateway.config.api_key, "my-api-key");
711    }
712
713    #[test]
714    fn test_gateway_with_api_key_and_base_url() {
715        let gateway = OpenAIGateway::with_api_key_and_base_url("key", "https://custom.com");
716        assert_eq!(gateway.config.api_key, "key");
717        assert_eq!(gateway.config.base_url, "https://custom.com");
718    }
719
720    #[test]
721    fn test_gateway_default() {
722        let gateway = OpenAIGateway::default();
723        assert_eq!(gateway.config.base_url, "https://api.openai.com/v1");
724    }
725
726    #[test]
727    fn test_chunk_text_short() {
728        let gateway = OpenAIGateway::new();
729        let chunks = gateway.chunk_text("Hello world", 100);
730        assert_eq!(chunks.len(), 1);
731        assert_eq!(chunks[0], "Hello world");
732    }
733
734    #[test]
735    fn test_chunk_text_long() {
736        let gateway = OpenAIGateway::new();
737        let long_text = "a".repeat(50000);
738        let chunks = gateway.chunk_text(&long_text, 100);
739        assert!(chunks.len() > 1);
740    }
741
742    #[test]
743    fn test_weighted_average_embeddings_single() {
744        let gateway = OpenAIGateway::new();
745        let embeddings = vec![vec![1.0, 2.0, 3.0]];
746        let weights = vec![1.0];
747        let result = gateway.weighted_average_embeddings(&embeddings, &weights);
748
749        // Normalized [1, 2, 3] / sqrt(14)
750        let norm = (1.0_f32 + 4.0 + 9.0).sqrt();
751        assert!((result[0] - 1.0 / norm).abs() < 0.001);
752        assert!((result[1] - 2.0 / norm).abs() < 0.001);
753        assert!((result[2] - 3.0 / norm).abs() < 0.001);
754    }
755
756    #[test]
757    fn test_weighted_average_embeddings_multiple() {
758        let gateway = OpenAIGateway::new();
759        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
760        let weights = vec![1.0, 1.0];
761        let result = gateway.weighted_average_embeddings(&embeddings, &weights);
762
763        // Equal weights, average is [0.5, 0.5], normalized to [1/sqrt(2), 1/sqrt(2)]
764        let expected = 1.0 / (2.0_f32).sqrt();
765        assert!((result[0] - expected).abs() < 0.001);
766        assert!((result[1] - expected).abs() < 0.001);
767    }
768
769    #[test]
770    fn test_weighted_average_embeddings_empty() {
771        let gateway = OpenAIGateway::new();
772        let embeddings: Vec<Vec<f32>> = vec![];
773        let weights: Vec<f32> = vec![];
774        let result = gateway.weighted_average_embeddings(&embeddings, &weights);
775        assert!(result.is_empty());
776    }
777
778    #[test]
779    fn test_build_complete_tool_calls() {
780        let mut accumulators = HashMap::new();
781        accumulators.insert(
782            0,
783            ToolCallAccumulator {
784                id: Some("call_123".to_string()),
785                name: Some("get_weather".to_string()),
786                arguments: r#"{"location": "NYC"}"#.to_string(),
787            },
788        );
789        accumulators.insert(
790            1,
791            ToolCallAccumulator {
792                id: Some("call_456".to_string()),
793                name: Some("search".to_string()),
794                arguments: r#"{"query": "test"}"#.to_string(),
795            },
796        );
797
798        let result = build_complete_tool_calls(&accumulators);
799
800        assert_eq!(result.len(), 2);
801        assert_eq!(result[0].id, Some("call_123".to_string()));
802        assert_eq!(result[0].name, "get_weather");
803        assert_eq!(result[1].id, Some("call_456".to_string()));
804        assert_eq!(result[1].name, "search");
805    }
806
807    #[test]
808    fn test_build_complete_tool_calls_missing_name() {
809        let mut accumulators = HashMap::new();
810        accumulators.insert(
811            0,
812            ToolCallAccumulator {
813                id: Some("call_123".to_string()),
814                name: None, // Missing name
815                arguments: r#"{}"#.to_string(),
816            },
817        );
818
819        let result = build_complete_tool_calls(&accumulators);
820        assert!(result.is_empty()); // Should be filtered out
821    }
822
823    #[test]
824    fn test_adapt_parameters_chat_model() {
825        let gateway = OpenAIGateway::new();
826        let config = CompletionConfig {
827            temperature: 0.7,
828            max_tokens: 1000,
829            ..Default::default()
830        };
831
832        let (params, supports_tools) = gateway.adapt_parameters_for_model("gpt-4", &config);
833
834        assert!(params.contains_key("max_tokens"));
835        assert!(!params.contains_key("max_completion_tokens"));
836        assert!(supports_tools);
837    }
838
839    #[test]
840    fn test_adapt_parameters_reasoning_model() {
841        let gateway = OpenAIGateway::new();
842        let config = CompletionConfig {
843            temperature: 0.7,
844            max_tokens: 1000,
845            ..Default::default()
846        };
847
848        let (params, supports_tools) = gateway.adapt_parameters_for_model("o1", &config);
849
850        assert!(!params.contains_key("max_tokens"));
851        assert!(params.contains_key("max_completion_tokens"));
852        assert!(supports_tools); // o1 now supports tools (audit 2026-02-04)
853    }
854
855    #[tokio::test]
856    async fn test_complete_success() {
857        let mut server = mockito::Server::new_async().await;
858        let mock = server
859            .mock("POST", "/chat/completions")
860            .with_status(200)
861            .with_body(r#"{"choices":[{"message":{"role":"assistant","content":"Hello!"}}]}"#)
862            .create();
863
864        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
865        let messages = vec![LlmMessage::user("Hi")];
866        let config = CompletionConfig::default();
867
868        let result = gateway.complete("gpt-4", &messages, None, &config).await;
869
870        mock.assert();
871        assert!(result.is_ok());
872        let response = result.unwrap();
873        assert_eq!(response.content, Some("Hello!".to_string()));
874    }
875
876    #[tokio::test]
877    async fn test_complete_with_tool_calls() {
878        let mut server = mockito::Server::new_async().await;
879        let mock = server
880            .mock("POST", "/chat/completions")
881            .with_status(200)
882            .with_body(r#"{"choices":[{"message":{"role":"assistant","content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"location\": \"NYC\"}"}}]}}]}"#)
883            .create();
884
885        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
886        let messages = vec![LlmMessage::user("Weather?")];
887        let config = CompletionConfig::default();
888
889        let result = gateway.complete("gpt-4", &messages, None, &config).await;
890
891        mock.assert();
892        assert!(result.is_ok());
893        let response = result.unwrap();
894        assert_eq!(response.tool_calls.len(), 1);
895        assert_eq!(response.tool_calls[0].name, "get_weather");
896    }
897
898    #[tokio::test]
899    async fn test_complete_error() {
900        let mut server = mockito::Server::new_async().await;
901        let mock = server
902            .mock("POST", "/chat/completions")
903            .with_status(401)
904            .with_body("Unauthorized")
905            .create();
906
907        let gateway = OpenAIGateway::with_api_key_and_base_url("bad-key", server.url());
908        let messages = vec![LlmMessage::user("Hi")];
909        let config = CompletionConfig::default();
910
911        let result = gateway.complete("gpt-4", &messages, None, &config).await;
912
913        mock.assert();
914        assert!(result.is_err());
915    }
916
917    #[tokio::test]
918    async fn test_complete_json() {
919        let mut server = mockito::Server::new_async().await;
920        let mock = server
921            .mock("POST", "/chat/completions")
922            .with_status(200)
923            .with_body(
924                r#"{"choices":[{"message":{"content":"{\"name\":\"test\",\"value\":42}"}}]}"#,
925            )
926            .create();
927
928        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
929        let messages = vec![LlmMessage::user("Generate JSON")];
930        let schema = serde_json::json!({"type": "object"});
931        let config = CompletionConfig::default();
932
933        let result = gateway.complete_json("gpt-4", &messages, schema, &config).await;
934
935        mock.assert();
936        assert!(result.is_ok());
937        let json = result.unwrap();
938        assert_eq!(json["name"], "test");
939        assert_eq!(json["value"], 42);
940    }
941
942    #[tokio::test]
943    async fn test_get_available_models() {
944        let mut server = mockito::Server::new_async().await;
945        let mock = server
946            .mock("GET", "/models")
947            .with_status(200)
948            .with_body(r#"{"data":[{"id":"gpt-4"},{"id":"gpt-3.5-turbo"}]}"#)
949            .create();
950
951        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
952        let result = gateway.get_available_models().await;
953
954        mock.assert();
955        assert!(result.is_ok());
956        let models = result.unwrap();
957        assert_eq!(models.len(), 2);
958        // Should be sorted
959        assert_eq!(models[0], "gpt-3.5-turbo");
960        assert_eq!(models[1], "gpt-4");
961    }
962
963    #[tokio::test]
964    async fn test_calculate_embeddings() {
965        let mut server = mockito::Server::new_async().await;
966        let mock = server
967            .mock("POST", "/embeddings")
968            .with_status(200)
969            .with_body(r#"{"data":[{"embedding":[0.1,0.2,0.3,0.4]}]}"#)
970            .create();
971
972        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
973        let result = gateway.calculate_embeddings("test text", None).await;
974
975        mock.assert();
976        assert!(result.is_ok());
977        let embeddings = result.unwrap();
978        assert_eq!(embeddings.len(), 4);
979    }
980
981    #[tokio::test]
982    async fn test_calculate_embeddings_custom_model() {
983        let mut server = mockito::Server::new_async().await;
984        let mock = server
985            .mock("POST", "/embeddings")
986            .match_body(mockito::Matcher::JsonString(
987                r#"{"model":"text-embedding-3-small","input":"test"}"#.to_string(),
988            ))
989            .with_status(200)
990            .with_body(r#"{"data":[{"embedding":[0.5,0.6]}]}"#)
991            .create();
992
993        let gateway = OpenAIGateway::with_api_key_and_base_url("test-key", server.url());
994        let result = gateway.calculate_embeddings("test", Some("text-embedding-3-small")).await;
995
996        mock.assert();
997        assert!(result.is_ok());
998    }
999}