1use crate::error::Result;
4use crate::llm::models::{LlmMessage, LlmToolCall, MessageRole};
5use base64::Engine;
6use serde_json::Value;
7use std::path::Path;
8use tracing::warn;
9
/// Typed representation of an OpenAI chat message.
///
/// NOTE(review): no function in this module constructs or reads this type —
/// `adapt_messages_to_openai` builds `serde_json::Value` objects directly. Confirm it is
/// used elsewhere before extending it.
#[derive(Debug, Clone)]
pub struct OpenAIMessage {
    /// Role string; presumably one of the values `adapt_messages_to_openai` emits
    /// (`"system"`, `"user"`, `"assistant"`, `"tool"`) — TODO confirm.
    pub role: String,
    /// Message body: plain text or a list of multi-part content entries.
    pub content: OpenAIContent,
    /// Tool invocations requested by the model (assistant messages), if any.
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
    /// Id of the tool call this message answers (tool-result messages), if any.
    pub tool_call_id: Option<String>,
}
18
/// Content of an [`OpenAIMessage`]: either a single text string or a list of parts.
#[derive(Debug, Clone)]
pub enum OpenAIContent {
    /// Plain text content.
    Text(String),
    /// Multi-part content (text and/or images).
    Parts(Vec<OpenAIContentPart>),
}
25
/// One entry of a multi-part [`OpenAIContent::Parts`] payload.
#[derive(Debug, Clone)]
pub enum OpenAIContentPart {
    /// A text segment.
    Text { text: String },
    /// An image reference. NOTE(review): `adapt_messages_to_openai` nests the URL under an
    /// `"image_url"` object in the JSON it builds; this flat shape does not encode that nesting.
    ImageUrl { url: String },
}
32
/// A tool invocation as it appears in an OpenAI assistant message.
#[derive(Debug, Clone)]
pub struct OpenAIToolCall {
    /// Identifier that the matching tool-result message refers back to.
    pub id: String,
    /// Call kind; `adapt_messages_to_openai` always emits `"function"` for this field.
    pub r#type: String,
    /// Name and encoded arguments of the function being called.
    pub function: OpenAIToolCallFunction,
}
40
/// Function name and arguments of an [`OpenAIToolCall`].
#[derive(Debug, Clone)]
pub struct OpenAIToolCallFunction {
    /// Name of the function/tool to invoke.
    pub name: String,
    /// Arguments as a JSON-encoded string (the form `adapt_messages_to_openai` produces and
    /// `convert_tool_calls` parses), not a structured object.
    pub arguments: String,
}
47
/// Maps the (case-insensitive) extension of `file_path` to the subtype used in an
/// `image/<subtype>` MIME string.
///
/// Recognised extensions are `jpg`/`jpeg`, `png`, `gif` and `webp`. Anything else —
/// including a missing or non-UTF-8 extension — falls back to `"jpeg"`.
fn get_image_type(file_path: &str) -> &'static str {
    // (extension, MIME subtype) pairs; both jpg spellings share one subtype.
    const SUBTYPES: [(&str, &str); 5] = [
        ("jpg", "jpeg"),
        ("jpeg", "jpeg"),
        ("png", "png"),
        ("gif", "gif"),
        ("webp", "webp"),
    ];

    let extension = match Path::new(file_path).extension().and_then(|raw| raw.to_str()) {
        Some(raw) => raw.to_lowercase(),
        None => String::new(),
    };

    SUBTYPES
        .iter()
        .find(|(known, _)| *known == extension)
        .map(|&(_, subtype)| subtype)
        .unwrap_or("jpeg")
}
64
65fn encode_image_as_base64(file_path: &str) -> Result<String> {
67 let bytes = std::fs::read(file_path)?;
68 let base64_data = base64::engine::general_purpose::STANDARD.encode(&bytes);
69 let image_type = get_image_type(file_path);
70 Ok(format!("data:image/{};base64,{}", image_type, base64_data))
71}
72
73pub fn adapt_messages_to_openai(messages: &[LlmMessage]) -> Result<Vec<Value>> {
75 let mut result = Vec::new();
76
77 for msg in messages {
78 let openai_msg = match msg.role {
79 MessageRole::System => {
80 serde_json::json!({
81 "role": "system",
82 "content": msg.content.as_deref().unwrap_or("")
83 })
84 }
85 MessageRole::User => {
86 if let Some(ref image_paths) = msg.image_paths {
88 if !image_paths.is_empty() {
89 let mut content_parts = Vec::new();
90
91 if let Some(ref text) = msg.content {
93 if !text.is_empty() {
94 content_parts.push(serde_json::json!({
95 "type": "text",
96 "text": text
97 }));
98 }
99 }
100
101 for path in image_paths {
103 match encode_image_as_base64(path) {
104 Ok(data_url) => {
105 content_parts.push(serde_json::json!({
106 "type": "image_url",
107 "image_url": {
108 "url": data_url
109 }
110 }));
111 }
112 Err(e) => {
113 warn!(path = path, error = %e, "Failed to encode image");
114 }
115 }
116 }
117
118 serde_json::json!({
119 "role": "user",
120 "content": content_parts
121 })
122 } else {
123 serde_json::json!({
124 "role": "user",
125 "content": msg.content.as_deref().unwrap_or("")
126 })
127 }
128 } else {
129 serde_json::json!({
130 "role": "user",
131 "content": msg.content.as_deref().unwrap_or("")
132 })
133 }
134 }
135 MessageRole::Assistant => {
136 let mut assistant_msg = serde_json::json!({
137 "role": "assistant"
138 });
139
140 if let Some(ref content) = msg.content {
141 assistant_msg["content"] = serde_json::json!(content);
142 }
143
144 if let Some(ref tool_calls) = msg.tool_calls {
146 let formatted_calls: Vec<Value> = tool_calls
147 .iter()
148 .map(|tc| {
149 serde_json::json!({
150 "id": tc.id.as_deref().unwrap_or(""),
151 "type": "function",
152 "function": {
153 "name": tc.name,
154 "arguments": serde_json::to_string(&tc.arguments).unwrap_or_default()
155 }
156 })
157 })
158 .collect();
159 assistant_msg["tool_calls"] = serde_json::json!(formatted_calls);
160 }
161
162 assistant_msg
163 }
164 MessageRole::Tool => {
165 let tool_call_id = msg
167 .tool_calls
168 .as_ref()
169 .and_then(|tcs| tcs.first())
170 .and_then(|tc| tc.id.clone())
171 .unwrap_or_default();
172
173 serde_json::json!({
174 "role": "tool",
175 "content": msg.content.as_deref().unwrap_or(""),
176 "tool_call_id": tool_call_id
177 })
178 }
179 };
180
181 result.push(openai_msg);
182 }
183
184 Ok(result)
185}
186
187pub fn convert_tool_calls(tool_calls: &[Value]) -> Vec<LlmToolCall> {
189 tool_calls
190 .iter()
191 .filter_map(|tc| {
192 let id = tc["id"].as_str().map(String::from);
193 let name = tc["function"]["name"].as_str()?.to_string();
194 let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
195
196 let arguments: std::collections::HashMap<String, Value> =
198 serde_json::from_str(args_str).unwrap_or_default();
199
200 Some(LlmToolCall {
201 id,
202 name,
203 arguments,
204 })
205 })
206 .collect()
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Adapts `messages`, panicking if the adapter reports an error.
    fn adapt(messages: &[LlmMessage]) -> Vec<Value> {
        adapt_messages_to_openai(messages).expect("adapt_messages_to_openai failed")
    }

    /// Asserts that `msg` carries the given role and plain-text content.
    fn assert_text_message(msg: &Value, role: &str, content: &str) {
        assert_eq!(msg["role"], role);
        assert_eq!(msg["content"], content);
    }

    /// Builds a single OpenAI-style tool call JSON entry with string-encoded arguments.
    fn raw_tool_call(id: &str, name: &str, arguments: &str) -> Value {
        serde_json::json!({
            "id": id,
            "type": "function",
            "function": {
                "name": name,
                "arguments": arguments
            }
        })
    }

    #[test]
    fn test_get_image_type_jpg() {
        for path in &["/path/to/image.jpg", "/path/to/image.jpeg"] {
            assert_eq!(get_image_type(path), "jpeg");
        }
    }

    #[test]
    fn test_get_image_type_png() {
        let subtype = get_image_type("/path/to/image.png");
        assert_eq!(subtype, "png");
    }

    #[test]
    fn test_get_image_type_gif() {
        let subtype = get_image_type("/path/to/image.gif");
        assert_eq!(subtype, "gif");
    }

    #[test]
    fn test_get_image_type_webp() {
        let subtype = get_image_type("/path/to/image.webp");
        assert_eq!(subtype, "webp");
    }

    #[test]
    fn test_get_image_type_unknown() {
        // Unrecognised extensions fall back to jpeg.
        let subtype = get_image_type("/path/to/image.unknown");
        assert_eq!(subtype, "jpeg");
    }

    #[test]
    fn test_adapt_system_message() {
        let adapted = adapt(&[LlmMessage::system("You are helpful")]);

        assert_eq!(adapted.len(), 1);
        assert_text_message(&adapted[0], "system", "You are helpful");
    }

    #[test]
    fn test_adapt_user_message() {
        let adapted = adapt(&[LlmMessage::user("Hello")]);

        assert_eq!(adapted.len(), 1);
        assert_text_message(&adapted[0], "user", "Hello");
    }

    #[test]
    fn test_adapt_assistant_message() {
        let adapted = adapt(&[LlmMessage::assistant("Hi there")]);

        assert_eq!(adapted.len(), 1);
        assert_text_message(&adapted[0], "assistant", "Hi there");
    }

    #[test]
    fn test_adapt_user_message_with_images() {
        // The temp file has no extension, so the data URL falls back to the jpeg subtype.
        let mut image = NamedTempFile::new().unwrap();
        image.write_all(b"fake image data").unwrap();
        let image_path = image.path().to_string_lossy().into_owned();

        let message = LlmMessage::user("Describe this image").with_images(vec![image_path]);
        let adapted = adapt(&[message]);

        assert_eq!(adapted.len(), 1);
        assert_eq!(adapted[0]["role"], "user");

        let parts = adapted[0]["content"]
            .as_array()
            .expect("content should be an array of parts");
        assert_eq!(parts.len(), 2);

        let (text_part, image_part) = (&parts[0], &parts[1]);
        assert_eq!(text_part["type"], "text");
        assert_eq!(text_part["text"], "Describe this image");
        assert_eq!(image_part["type"], "image_url");

        let url = image_part["image_url"]["url"]
            .as_str()
            .expect("image url should be a string");
        assert!(url.starts_with("data:image/jpeg;base64,"));
    }

    #[test]
    fn test_adapt_assistant_with_tool_calls() {
        let arguments: HashMap<String, Value> =
            vec![("location".to_string(), serde_json::json!("NYC"))]
                .into_iter()
                .collect();

        let message = LlmMessage {
            role: MessageRole::Assistant,
            content: None,
            tool_calls: Some(vec![LlmToolCall {
                id: Some("call_123".to_string()),
                name: "get_weather".to_string(),
                arguments,
            }]),
            image_paths: None,
        };

        let adapted = adapt(&[message]);

        assert_eq!(adapted.len(), 1);
        assert_eq!(adapted[0]["role"], "assistant");

        let calls = adapted[0]["tool_calls"]
            .as_array()
            .expect("tool_calls should be an array");
        assert_eq!(calls.len(), 1);

        let call = &calls[0];
        assert_eq!(call["id"], "call_123");
        assert_eq!(call["type"], "function");
        assert_eq!(call["function"]["name"], "get_weather");
    }

    #[test]
    fn test_adapt_tool_message() {
        let answered_call = LlmToolCall {
            id: Some("call_123".to_string()),
            name: "get_weather".to_string(),
            arguments: HashMap::new(),
        };
        let message = LlmMessage {
            role: MessageRole::Tool,
            content: Some("Weather result: 72F".to_string()),
            tool_calls: Some(vec![answered_call]),
            image_paths: None,
        };

        let adapted = adapt(&[message]);

        assert_eq!(adapted.len(), 1);
        assert_text_message(&adapted[0], "tool", "Weather result: 72F");
        assert_eq!(adapted[0]["tool_call_id"], "call_123");
    }

    #[test]
    fn test_convert_tool_calls() {
        let raw = [raw_tool_call("call_abc", "search", "{\"query\": \"test\"}")];

        let converted = convert_tool_calls(&raw);

        assert_eq!(converted.len(), 1);
        let call = &converted[0];
        assert_eq!(call.id, Some("call_abc".to_string()));
        assert_eq!(call.name, "search");
        assert_eq!(call.arguments.get("query"), Some(&serde_json::json!("test")));
    }

    #[test]
    fn test_convert_tool_calls_empty_args() {
        let raw = [raw_tool_call("call_xyz", "no_args_tool", "{}")];

        let converted = convert_tool_calls(&raw);

        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].name, "no_args_tool");
        assert!(converted[0].arguments.is_empty());
    }
}