mojentic/llm/gateways/
openai_messages_adapter.rs

1//! Adapter for converting LLM messages to OpenAI format.
2
3use crate::error::Result;
4use crate::llm::models::{LlmMessage, LlmToolCall, MessageRole};
5use base64::Engine;
6use serde_json::Value;
7use std::path::Path;
8use tracing::warn;
9
/// OpenAI chat message wire format.
///
/// NOTE(review): this struct (and the related `OpenAIContent`/`OpenAIToolCall`
/// types below) appears unused within this file — `adapt_messages_to_openai`
/// builds `serde_json::Value`s directly. Confirm whether other modules rely on
/// these types before removing them.
#[derive(Debug, Clone)]
pub struct OpenAIMessage {
    // Chat role string ("system", "user", "assistant", or "tool" — the roles
    // the adapter below emits).
    pub role: String,
    // Either plain text or a multimodal parts array.
    pub content: OpenAIContent,
    // Only meaningful on assistant messages that invoke tools.
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
    // Only meaningful on tool-result messages.
    pub tool_call_id: Option<String>,
}
18
/// OpenAI content format (text or multimodal).
#[derive(Debug, Clone)]
pub enum OpenAIContent {
    /// Plain text content.
    Text(String),
    /// Multimodal content: an ordered sequence of text and image parts.
    Parts(Vec<OpenAIContentPart>),
}
25
/// A part of multimodal content.
#[derive(Debug, Clone)]
pub enum OpenAIContentPart {
    /// A text fragment.
    Text { text: String },
    /// An image reference. The adapter in this file produces base64 `data:`
    /// URLs here (see `encode_image_as_base64`); remote URLs may also be valid.
    ImageUrl { url: String },
}
32
/// OpenAI tool call format.
#[derive(Debug, Clone)]
pub struct OpenAIToolCall {
    // Provider-assigned call id; echoed back as `tool_call_id` on the
    // corresponding tool-result message.
    pub id: String,
    // Call type; the adapter below always emits "function".
    pub r#type: String,
    // The function name and its JSON-encoded arguments.
    pub function: OpenAIToolCallFunction,
}
40
/// OpenAI tool call function.
#[derive(Debug, Clone)]
pub struct OpenAIToolCallFunction {
    // Name of the tool/function to invoke.
    pub name: String,
    // Arguments as a JSON-encoded string (OpenAI's wire representation,
    // mirrored by the `serde_json::to_string` call in the adapter below).
    pub arguments: String,
}
47
/// Determine the MIME image subtype from a file path's extension.
///
/// Matching is case-insensitive. Unrecognized or missing extensions fall back
/// to `"jpeg"` so a data URL can always be produced.
fn get_image_type(file_path: &str) -> &'static str {
    let extension = Path::new(file_path)
        .extension()
        .map(|ext| ext.to_string_lossy().to_lowercase())
        .unwrap_or_default();

    match extension.as_str() {
        "png" => "png",
        "gif" => "gif",
        "webp" => "webp",
        // "jpg"/"jpeg" and anything unknown all map to jpeg.
        _ => "jpeg",
    }
}
64
65/// Read and encode an image file as base64.
66fn encode_image_as_base64(file_path: &str) -> Result<String> {
67    let bytes = std::fs::read(file_path)?;
68    let base64_data = base64::engine::general_purpose::STANDARD.encode(&bytes);
69    let image_type = get_image_type(file_path);
70    Ok(format!("data:image/{};base64,{}", image_type, base64_data))
71}
72
73/// Adapt LLM messages to OpenAI format.
74pub fn adapt_messages_to_openai(messages: &[LlmMessage]) -> Result<Vec<Value>> {
75    let mut result = Vec::new();
76
77    for msg in messages {
78        let openai_msg = match msg.role {
79            MessageRole::System => {
80                serde_json::json!({
81                    "role": "system",
82                    "content": msg.content.as_deref().unwrap_or("")
83                })
84            }
85            MessageRole::User => {
86                // Check for images
87                if let Some(ref image_paths) = msg.image_paths {
88                    if !image_paths.is_empty() {
89                        let mut content_parts = Vec::new();
90
91                        // Add text content
92                        if let Some(ref text) = msg.content {
93                            if !text.is_empty() {
94                                content_parts.push(serde_json::json!({
95                                    "type": "text",
96                                    "text": text
97                                }));
98                            }
99                        }
100
101                        // Add images
102                        for path in image_paths {
103                            match encode_image_as_base64(path) {
104                                Ok(data_url) => {
105                                    content_parts.push(serde_json::json!({
106                                        "type": "image_url",
107                                        "image_url": {
108                                            "url": data_url
109                                        }
110                                    }));
111                                }
112                                Err(e) => {
113                                    warn!(path = path, error = %e, "Failed to encode image");
114                                }
115                            }
116                        }
117
118                        serde_json::json!({
119                            "role": "user",
120                            "content": content_parts
121                        })
122                    } else {
123                        serde_json::json!({
124                            "role": "user",
125                            "content": msg.content.as_deref().unwrap_or("")
126                        })
127                    }
128                } else {
129                    serde_json::json!({
130                        "role": "user",
131                        "content": msg.content.as_deref().unwrap_or("")
132                    })
133                }
134            }
135            MessageRole::Assistant => {
136                let mut assistant_msg = serde_json::json!({
137                    "role": "assistant"
138                });
139
140                if let Some(ref content) = msg.content {
141                    assistant_msg["content"] = serde_json::json!(content);
142                }
143
144                // Add tool calls if present
145                if let Some(ref tool_calls) = msg.tool_calls {
146                    let formatted_calls: Vec<Value> = tool_calls
147                        .iter()
148                        .map(|tc| {
149                            serde_json::json!({
150                                "id": tc.id.as_deref().unwrap_or(""),
151                                "type": "function",
152                                "function": {
153                                    "name": tc.name,
154                                    "arguments": serde_json::to_string(&tc.arguments).unwrap_or_default()
155                                }
156                            })
157                        })
158                        .collect();
159                    assistant_msg["tool_calls"] = serde_json::json!(formatted_calls);
160                }
161
162                assistant_msg
163            }
164            MessageRole::Tool => {
165                // Tool messages need tool_call_id - use the first tool call id if available
166                let tool_call_id = msg
167                    .tool_calls
168                    .as_ref()
169                    .and_then(|tcs| tcs.first())
170                    .and_then(|tc| tc.id.clone())
171                    .unwrap_or_default();
172
173                serde_json::json!({
174                    "role": "tool",
175                    "content": msg.content.as_deref().unwrap_or(""),
176                    "tool_call_id": tool_call_id
177                })
178            }
179        };
180
181        result.push(openai_msg);
182    }
183
184    Ok(result)
185}
186
187/// Convert tool calls from OpenAI format to internal format.
188pub fn convert_tool_calls(tool_calls: &[Value]) -> Vec<LlmToolCall> {
189    tool_calls
190        .iter()
191        .filter_map(|tc| {
192            let id = tc["id"].as_str().map(String::from);
193            let name = tc["function"]["name"].as_str()?.to_string();
194            let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
195
196            // Parse arguments as JSON object
197            let arguments: std::collections::HashMap<String, Value> =
198                serde_json::from_str(args_str).unwrap_or_default();
199
200            Some(LlmToolCall {
201                id,
202                name,
203                arguments,
204            })
205        })
206        .collect()
207}
208
// Unit tests: extension sniffing, per-role message adaptation (including
// multimodal image attachment), and OpenAI -> internal tool-call conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_get_image_type_jpg() {
        assert_eq!(get_image_type("/path/to/image.jpg"), "jpeg");
        assert_eq!(get_image_type("/path/to/image.jpeg"), "jpeg");
    }

    #[test]
    fn test_get_image_type_png() {
        assert_eq!(get_image_type("/path/to/image.png"), "png");
    }

    #[test]
    fn test_get_image_type_gif() {
        assert_eq!(get_image_type("/path/to/image.gif"), "gif");
    }

    #[test]
    fn test_get_image_type_webp() {
        assert_eq!(get_image_type("/path/to/image.webp"), "webp");
    }

    // Unknown extensions intentionally fall back to jpeg.
    #[test]
    fn test_get_image_type_unknown() {
        assert_eq!(get_image_type("/path/to/image.unknown"), "jpeg");
    }

    #[test]
    fn test_adapt_system_message() {
        let messages = vec![LlmMessage::system("You are helpful")];

        let result = adapt_messages_to_openai(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "system");
        assert_eq!(result[0]["content"], "You are helpful");
    }

    #[test]
    fn test_adapt_user_message() {
        let messages = vec![LlmMessage::user("Hello")];

        let result = adapt_messages_to_openai(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "user");
        assert_eq!(result[0]["content"], "Hello");
    }

    #[test]
    fn test_adapt_assistant_message() {
        let messages = vec![LlmMessage::assistant("Hi there")];

        let result = adapt_messages_to_openai(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "assistant");
        assert_eq!(result[0]["content"], "Hi there");
    }

    // A user message with images must become a multimodal parts array:
    // text part first, then one image_url part per image as a data URL.
    #[test]
    fn test_adapt_user_message_with_images() {
        // Create a temporary image file
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"fake image data").unwrap();
        let path = temp_file.path().to_string_lossy().to_string();

        let messages =
            vec![LlmMessage::user("Describe this image").with_images(vec![path.clone()])];

        let result = adapt_messages_to_openai(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "user");

        let content = &result[0]["content"];
        assert!(content.is_array());

        let parts = content.as_array().unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[0]["text"], "Describe this image");
        assert_eq!(parts[1]["type"], "image_url");
        // Temp file has no extension, so the data URL defaults to jpeg.
        assert!(parts[1]["image_url"]["url"]
            .as_str()
            .unwrap()
            .starts_with("data:image/jpeg;base64,"));
    }

    #[test]
    fn test_adapt_assistant_with_tool_calls() {
        let tool_call = LlmToolCall {
            id: Some("call_123".to_string()),
            name: "get_weather".to_string(),
            arguments: {
                let mut map = HashMap::new();
                map.insert("location".to_string(), serde_json::json!("NYC"));
                map
            },
        };

        let messages = vec![LlmMessage {
            role: MessageRole::Assistant,
            content: None,
            tool_calls: Some(vec![tool_call]),
            image_paths: None,
        }];

        let result = adapt_messages_to_openai(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "assistant");

        let tool_calls = &result[0]["tool_calls"];
        assert!(tool_calls.is_array());

        let calls = tool_calls.as_array().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0]["id"], "call_123");
        assert_eq!(calls[0]["type"], "function");
        assert_eq!(calls[0]["function"]["name"], "get_weather");
    }

    // Tool messages take their tool_call_id from the first attached call.
    #[test]
    fn test_adapt_tool_message() {
        let messages = vec![LlmMessage {
            role: MessageRole::Tool,
            content: Some("Weather result: 72F".to_string()),
            tool_calls: Some(vec![LlmToolCall {
                id: Some("call_123".to_string()),
                name: "get_weather".to_string(),
                arguments: HashMap::new(),
            }]),
            image_paths: None,
        }];

        let result = adapt_messages_to_openai(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "tool");
        assert_eq!(result[0]["content"], "Weather result: 72F");
        assert_eq!(result[0]["tool_call_id"], "call_123");
    }

    // The JSON-string "arguments" field must be parsed into a map.
    #[test]
    fn test_convert_tool_calls() {
        let tool_calls = vec![serde_json::json!({
            "id": "call_abc",
            "type": "function",
            "function": {
                "name": "search",
                "arguments": "{\"query\": \"test\"}"
            }
        })];

        let result = convert_tool_calls(&tool_calls);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, Some("call_abc".to_string()));
        assert_eq!(result[0].name, "search");
        assert_eq!(result[0].arguments.get("query"), Some(&serde_json::json!("test")));
    }

    #[test]
    fn test_convert_tool_calls_empty_args() {
        let tool_calls = vec![serde_json::json!({
            "id": "call_xyz",
            "type": "function",
            "function": {
                "name": "no_args_tool",
                "arguments": "{}"
            }
        })];

        let result = convert_tool_calls(&tool_calls);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "no_args_tool");
        assert!(result[0].arguments.is_empty());
    }
}