mojentic/llm/gateways/
ollama.rs

1use crate::error::{MojenticError, Result};
2use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
3use crate::llm::models::{LlmGatewayResponse, LlmMessage, LlmToolCall, MessageRole};
4use crate::llm::tools::LlmTool;
5use async_trait::async_trait;
6use futures::stream::{Stream, StreamExt};
7use reqwest::Client;
8use serde_json::Value;
9use std::collections::HashMap;
10use std::pin::Pin;
11use tracing::{debug, info, warn};
12
/// Configuration for connecting to Ollama server
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Base URL of the Ollama server, e.g. "http://localhost:11434".
    pub host: String,
    /// Optional request timeout for the HTTP client; no timeout when `None`.
    pub timeout: Option<std::time::Duration>,
    /// Additional headers to send with requests to the server.
    pub headers: HashMap<String, String>,
}
20
21impl Default for OllamaConfig {
22    fn default() -> Self {
23        Self {
24            host: std::env::var("OLLAMA_HOST")
25                .unwrap_or_else(|_| "http://localhost:11434".to_string()),
26            timeout: None,
27            headers: HashMap::new(),
28        }
29    }
30}
31
/// Gateway for Ollama local LLM service
///
/// This gateway provides access to local LLM models through Ollama,
/// supporting text generation, structured output, tool calling, and embeddings.
pub struct OllamaGateway {
    // Shared HTTP client, reused across requests (connection pooling).
    client: Client,
    // Host/timeout/header settings this gateway was constructed with.
    config: OllamaConfig,
}
40
41impl OllamaGateway {
42    /// Create a new Ollama gateway with default configuration
43    pub fn new() -> Self {
44        Self::with_config(OllamaConfig::default())
45    }
46
47    /// Create a new Ollama gateway with custom configuration
48    pub fn with_config(config: OllamaConfig) -> Self {
49        let mut client_builder = Client::builder();
50
51        if let Some(timeout) = config.timeout {
52            client_builder = client_builder.timeout(timeout);
53        }
54
55        let client = client_builder.build().unwrap();
56
57        Self { client, config }
58    }
59
60    /// Create gateway with custom host
61    pub fn with_host(host: impl Into<String>) -> Self {
62        Self::with_config(OllamaConfig {
63            host: host.into(),
64            ..Default::default()
65        })
66    }
67
68    /// Pull a model from Ollama library
69    pub async fn pull_model(&self, model: &str) -> Result<()> {
70        info!("Pulling Ollama model: {}", model);
71
72        let response = self
73            .client
74            .post(format!("{}/api/pull", self.config.host))
75            .json(&serde_json::json!({
76                "name": model
77            }))
78            .send()
79            .await?;
80
81        if !response.status().is_success() {
82            return Err(MojenticError::GatewayError(format!(
83                "Failed to pull model {}: {}",
84                model,
85                response.status()
86            )));
87        }
88
89        Ok(())
90    }
91}
92
93impl Default for OllamaGateway {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99#[async_trait]
100impl LlmGateway for OllamaGateway {
101    async fn complete(
102        &self,
103        model: &str,
104        messages: &[LlmMessage],
105        tools: Option<&[Box<dyn LlmTool>]>,
106        config: &CompletionConfig,
107    ) -> Result<LlmGatewayResponse> {
108        info!("Delegating to Ollama for completion");
109        debug!("Model: {}, Message count: {}", model, messages.len());
110
111        let ollama_messages = adapt_messages_to_ollama(messages)?;
112        let options = extract_ollama_options(config);
113
114        let mut body = serde_json::json!({
115            "model": model,
116            "messages": ollama_messages,
117            "options": options,
118            "stream": false
119        });
120
121        // Add tools if provided
122        if let Some(tools) = tools {
123            let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
124            body["tools"] = serde_json::to_value(tool_defs)?;
125        }
126
127        // Add reasoning effort if specified (Ollama uses "think" parameter)
128        if config.reasoning_effort.is_some() {
129            body["think"] = serde_json::json!(true);
130        }
131
132        // Add response format if specified
133        add_response_format(&mut body, config);
134
135        // Make API request
136        let response = self
137            .client
138            .post(format!("{}/api/chat", self.config.host))
139            .json(&body)
140            .send()
141            .await?;
142
143        if !response.status().is_success() {
144            return Err(MojenticError::GatewayError(format!(
145                "Ollama API error: {}",
146                response.status()
147            )));
148        }
149
150        let response_body: Value = response.json().await?;
151
152        // Parse content
153        let content = response_body["message"]["content"].as_str().map(String::from);
154
155        // Parse thinking if present (from reasoning models)
156        let thinking = response_body["message"]["thinking"].as_str().map(String::from);
157
158        // Parse tool calls if present
159        let tool_calls = if let Some(calls) = response_body["message"]["tool_calls"].as_array() {
160            calls
161                .iter()
162                .filter_map(|call| {
163                    let name = call["function"]["name"].as_str()?.to_string();
164                    let args = call["function"]["arguments"].as_object()?;
165
166                    let arguments: HashMap<String, Value> =
167                        args.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
168
169                    Some(LlmToolCall {
170                        id: call["id"].as_str().map(String::from),
171                        name,
172                        arguments,
173                    })
174                })
175                .collect()
176        } else {
177            vec![]
178        };
179
180        Ok(LlmGatewayResponse {
181            content,
182            object: None,
183            tool_calls,
184            thinking,
185        })
186    }
187
188    async fn complete_json(
189        &self,
190        model: &str,
191        messages: &[LlmMessage],
192        schema: Value,
193        config: &CompletionConfig,
194    ) -> Result<Value> {
195        info!("Requesting structured output from Ollama");
196
197        let ollama_messages = adapt_messages_to_ollama(messages)?;
198        let options = extract_ollama_options(config);
199
200        let body = serde_json::json!({
201            "model": model,
202            "messages": ollama_messages,
203            "options": options,
204            "format": schema,
205            "stream": false
206        });
207
208        let response = self
209            .client
210            .post(format!("{}/api/chat", self.config.host))
211            .json(&body)
212            .send()
213            .await?;
214
215        if !response.status().is_success() {
216            return Err(MojenticError::GatewayError(format!(
217                "Ollama API error: {}",
218                response.status()
219            )));
220        }
221
222        let response_body: Value = response.json().await?;
223        let content = response_body["message"]["content"]
224            .as_str()
225            .ok_or_else(|| MojenticError::GatewayError("No content in response".to_string()))?;
226
227        // Parse the JSON response
228        let json_value: Value = serde_json::from_str(content)?;
229
230        Ok(json_value)
231    }
232
233    async fn get_available_models(&self) -> Result<Vec<String>> {
234        debug!("Fetching available Ollama models");
235
236        let response = self.client.get(format!("{}/api/tags", self.config.host)).send().await?;
237
238        if !response.status().is_success() {
239            return Err(MojenticError::GatewayError(format!(
240                "Failed to get models: {}",
241                response.status()
242            )));
243        }
244
245        let body: Value = response.json().await?;
246
247        let models = body["models"]
248            .as_array()
249            .ok_or_else(|| MojenticError::GatewayError("Invalid response format".to_string()))?
250            .iter()
251            .filter_map(|m| m["name"].as_str().map(String::from))
252            .collect::<Vec<_>>();
253
254        Ok(models)
255    }
256
257    async fn calculate_embeddings(&self, text: &str, model: Option<&str>) -> Result<Vec<f32>> {
258        let model = model.unwrap_or("mxbai-embed-large");
259        debug!("Calculating embeddings with model: {}", model);
260
261        let body = serde_json::json!({
262            "model": model,
263            "prompt": text
264        });
265
266        let response = self
267            .client
268            .post(format!("{}/api/embeddings", self.config.host))
269            .json(&body)
270            .send()
271            .await?;
272
273        if !response.status().is_success() {
274            return Err(MojenticError::GatewayError(format!(
275                "Embeddings API error: {}",
276                response.status()
277            )));
278        }
279
280        let response_body: Value = response.json().await?;
281
282        let embeddings = response_body["embedding"]
283            .as_array()
284            .ok_or_else(|| MojenticError::GatewayError("Invalid embeddings response".to_string()))?
285            .iter()
286            .filter_map(|v| v.as_f64().map(|f| f as f32))
287            .collect();
288
289        Ok(embeddings)
290    }
291
292    fn complete_stream<'a>(
293        &'a self,
294        model: &'a str,
295        messages: &'a [LlmMessage],
296        tools: Option<&'a [Box<dyn LlmTool>]>,
297        config: &'a CompletionConfig,
298    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
299        Box::pin(async_stream::stream! {
300            info!("Starting Ollama streaming completion");
301            debug!("Model: {}, Message count: {}", model, messages.len());
302
303            let ollama_messages = match adapt_messages_to_ollama(messages) {
304                Ok(msgs) => msgs,
305                Err(e) => {
306                    yield Err(e);
307                    return;
308                }
309            };
310
311            let options = extract_ollama_options(config);
312
313            let mut body = serde_json::json!({
314                "model": model,
315                "messages": ollama_messages,
316                "options": options,
317                "stream": true
318            });
319
320            // Add tools if provided
321            if let Some(tools) = tools {
322                let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
323                if let Ok(tools_value) = serde_json::to_value(tool_defs) {
324                    body["tools"] = tools_value;
325                }
326            }
327
328            // Add reasoning effort if specified (Ollama uses "think" parameter)
329            if config.reasoning_effort.is_some() {
330                body["think"] = serde_json::json!(true);
331            }
332
333            // Add response format if specified
334            add_response_format(&mut body, config);
335
336            // Make streaming API request
337            let response = match self
338                .client
339                .post(format!("{}/api/chat", self.config.host))
340                .json(&body)
341                .send()
342                .await
343            {
344                Ok(r) => r,
345                Err(e) => {
346                    yield Err(e.into());
347                    return;
348                }
349            };
350
351            if !response.status().is_success() {
352                yield Err(MojenticError::GatewayError(format!(
353                    "Ollama API error: {}",
354                    response.status()
355                )));
356                return;
357            }
358
359            // Process byte stream
360            let mut stream = response.bytes_stream();
361            let mut buffer = String::new();
362            let mut accumulated_tool_calls: Vec<LlmToolCall> = Vec::new();
363
364            while let Some(chunk_result) = stream.next().await {
365                match chunk_result {
366                    Ok(bytes) => {
367                        // Append to buffer
368                        if let Ok(text) = std::str::from_utf8(&bytes) {
369                            buffer.push_str(text);
370
371                            // Process complete JSON lines (newline-delimited)
372                            while let Some(newline_pos) = buffer.find('\n') {
373                                let line = buffer[..newline_pos].trim().to_string();
374                                buffer = buffer[newline_pos + 1..].to_string();
375
376                                if line.is_empty() {
377                                    continue;
378                                }
379
380                                // Parse JSON line
381                                match serde_json::from_str::<Value>(&line) {
382                                    Ok(json) => {
383                                        // Check if streaming is done
384                                        if json["done"].as_bool().unwrap_or(false) {
385                                            // Final chunk - yield accumulated tool calls if any
386                                            if !accumulated_tool_calls.is_empty() {
387                                                yield Ok(StreamChunk::ToolCalls(accumulated_tool_calls.clone()));
388                                            }
389                                            continue;
390                                        }
391
392                                        // Extract content
393                                        if let Some(message) = json["message"].as_object() {
394                                            if let Some(content) = message["content"].as_str() {
395                                                if !content.is_empty() {
396                                                    yield Ok(StreamChunk::Content(content.to_string()));
397                                                }
398                                            }
399
400                                            // Extract tool calls
401                                            if let Some(calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
402                                                for call in calls {
403                                                    if let Some(function) = call.get("function").and_then(|v| v.as_object()) {
404                                                        if let (Some(name), Some(args)) = (
405                                                            function.get("name").and_then(|v| v.as_str()),
406                                                            function.get("arguments").and_then(|v| v.as_object()),
407                                                        ) {
408                                                            let arguments: HashMap<String, Value> = args
409                                                                .iter()
410                                                                .map(|(k, v)| (k.clone(), v.clone()))
411                                                                .collect();
412
413                                                            let tool_call = LlmToolCall {
414                                                                id: call.get("id").and_then(|v| v.as_str()).map(String::from),
415                                                                name: name.to_string(),
416                                                                arguments,
417                                                            };
418
419                                                            accumulated_tool_calls.push(tool_call);
420                                                        }
421                                                    }
422                                                }
423                                            }
424                                        }
425                                    }
426                                    Err(e) => {
427                                        warn!("Failed to parse streaming chunk: {}", e);
428                                    }
429                                }
430                            }
431                        }
432                    }
433                    Err(e) => {
434                        yield Err(e.into());
435                        return;
436                    }
437                }
438            }
439        })
440    }
441}
442
443// Message adapter for Ollama format
444fn adapt_messages_to_ollama(messages: &[LlmMessage]) -> Result<Vec<Value>> {
445    messages
446        .iter()
447        .map(|msg| {
448            let mut ollama_msg = serde_json::json!({
449                "role": match msg.role {
450                    MessageRole::System => "system",
451                    MessageRole::User => "user",
452                    MessageRole::Assistant => "assistant",
453                    MessageRole::Tool => "tool",
454                },
455                "content": msg.content.as_deref().unwrap_or("")
456            });
457
458            // Add images for user messages - Ollama requires base64-encoded images
459            if let Some(image_paths) = &msg.image_paths {
460                let encoded_images: Result<Vec<String>> = image_paths
461                    .iter()
462                    .map(|path| {
463                        std::fs::read(path)
464                            .map_err(|e| {
465                                MojenticError::GatewayError(format!(
466                                    "Failed to read image file {}: {}",
467                                    path, e
468                                ))
469                            })
470                            .map(|bytes| {
471                                base64::Engine::encode(
472                                    &base64::engine::general_purpose::STANDARD,
473                                    bytes,
474                                )
475                            })
476                    })
477                    .collect();
478
479                ollama_msg["images"] = serde_json::to_value(encoded_images?)?;
480            }
481
482            // Add tool calls for assistant messages
483            if let Some(tool_calls) = &msg.tool_calls {
484                let calls: Vec<_> = tool_calls
485                    .iter()
486                    .map(|tc| {
487                        serde_json::json!({
488                            "type": "function",
489                            "function": {
490                                "name": tc.name,
491                                "arguments": tc.arguments
492                            }
493                        })
494                    })
495                    .collect();
496                ollama_msg["tool_calls"] = serde_json::to_value(calls)?;
497            }
498
499            Ok(ollama_msg)
500        })
501        .collect()
502}
503
504// Extract Ollama-specific options from config
505fn extract_ollama_options(config: &CompletionConfig) -> Value {
506    let mut options = serde_json::json!({
507        "temperature": config.temperature,
508        "num_ctx": config.num_ctx,
509    });
510
511    if let Some(num_predict) = config.num_predict {
512        if num_predict > 0 {
513            options["num_predict"] = serde_json::json!(num_predict);
514        }
515    } else if config.max_tokens > 0 {
516        options["num_predict"] = serde_json::json!(config.max_tokens);
517    }
518
519    if let Some(top_p) = config.top_p {
520        options["top_p"] = serde_json::json!(top_p);
521    }
522
523    if let Some(top_k) = config.top_k {
524        options["top_k"] = serde_json::json!(top_k);
525    }
526
527    options
528}
529
530// Add response format to request body if specified
531fn add_response_format(body: &mut Value, config: &CompletionConfig) {
532    use crate::llm::gateway::ResponseFormat;
533
534    if let Some(response_format) = &config.response_format {
535        match response_format {
536            ResponseFormat::JsonObject { schema: Some(s) } => {
537                body["format"] = s.clone();
538            }
539            ResponseFormat::JsonObject { schema: None } => {
540                body["format"] = serde_json::json!("json");
541            }
542            ResponseFormat::Text => {
543                // Text is the default, no need to add format field
544            }
545        }
546    }
547}
548
549#[cfg(test)]
550mod tests {
551    use super::*;
552
553    #[test]
554    fn test_ollama_config_default() {
555        std::env::remove_var("OLLAMA_HOST");
556        let config = OllamaConfig::default();
557        assert_eq!(config.host, "http://localhost:11434");
558        assert!(config.timeout.is_none());
559        assert!(config.headers.is_empty());
560    }
561
562    #[test]
563    fn test_ollama_config_from_env() {
564        std::env::set_var("OLLAMA_HOST", "http://custom:8080");
565        let config = OllamaConfig::default();
566        assert_eq!(config.host, "http://custom:8080");
567        std::env::remove_var("OLLAMA_HOST");
568    }
569
570    #[test]
571    fn test_ollama_config_custom() {
572        let mut headers = HashMap::new();
573        headers.insert("X-Custom".to_string(), "value".to_string());
574
575        let config = OllamaConfig {
576            host: "http://test:9999".to_string(),
577            timeout: Some(std::time::Duration::from_secs(30)),
578            headers,
579        };
580
581        assert_eq!(config.host, "http://test:9999");
582        assert_eq!(config.timeout, Some(std::time::Duration::from_secs(30)));
583        assert_eq!(config.headers.get("X-Custom"), Some(&"value".to_string()));
584    }
585
586    #[test]
587    fn test_gateway_new() {
588        let gateway = OllamaGateway::new();
589        assert_eq!(gateway.config.host, "http://localhost:11434");
590    }
591
592    #[test]
593    fn test_gateway_with_host() {
594        let gateway = OllamaGateway::with_host("http://example.com:8080");
595        assert_eq!(gateway.config.host, "http://example.com:8080");
596    }
597
598    #[test]
599    fn test_gateway_with_config() {
600        let config = OllamaConfig {
601            host: "http://custom:5000".to_string(),
602            timeout: Some(std::time::Duration::from_secs(60)),
603            headers: HashMap::new(),
604        };
605
606        let gateway = OllamaGateway::with_config(config);
607        assert_eq!(gateway.config.host, "http://custom:5000");
608    }
609
610    #[test]
611    fn test_gateway_default() {
612        let gateway = OllamaGateway::default();
613        assert_eq!(gateway.config.host, "http://localhost:11434");
614    }
615
616    #[test]
617    fn test_adapt_messages_to_ollama_simple() {
618        let messages = vec![
619            LlmMessage::system("You are helpful"),
620            LlmMessage::user("Hello"),
621            LlmMessage::assistant("Hi there"),
622        ];
623
624        let result = adapt_messages_to_ollama(&messages).unwrap();
625
626        assert_eq!(result.len(), 3);
627        assert_eq!(result[0]["role"], "system");
628        assert_eq!(result[0]["content"], "You are helpful");
629        assert_eq!(result[1]["role"], "user");
630        assert_eq!(result[1]["content"], "Hello");
631        assert_eq!(result[2]["role"], "assistant");
632        assert_eq!(result[2]["content"], "Hi there");
633    }
634
635    #[test]
636    fn test_adapt_messages_with_images() {
637        use std::io::Write;
638        use tempfile::NamedTempFile;
639
640        // Create temporary test image files
641        let mut temp_file1 = NamedTempFile::new().unwrap();
642        let mut temp_file2 = NamedTempFile::new().unwrap();
643        temp_file1.write_all(b"fake_image_data_1").unwrap();
644        temp_file2.write_all(b"fake_image_data_2").unwrap();
645
646        // Get paths as strings
647        let path1 = temp_file1.path().to_string_lossy().to_string();
648        let path2 = temp_file2.path().to_string_lossy().to_string();
649
650        // Expected base64 encodings
651        let expected_base64_1 = base64::Engine::encode(
652            &base64::engine::general_purpose::STANDARD,
653            b"fake_image_data_1",
654        );
655        let expected_base64_2 = base64::Engine::encode(
656            &base64::engine::general_purpose::STANDARD,
657            b"fake_image_data_2",
658        );
659
660        let messages = vec![LlmMessage::user("Describe this").with_images(vec![path1, path2])];
661
662        let result = adapt_messages_to_ollama(&messages).unwrap();
663
664        assert_eq!(result.len(), 1);
665        assert_eq!(result[0]["role"], "user");
666        // Images should be base64-encoded
667        assert_eq!(result[0]["images"][0], expected_base64_1);
668        assert_eq!(result[0]["images"][1], expected_base64_2);
669    }
670
671    #[test]
672    fn test_adapt_messages_with_tool_calls() {
673        let tool_call = LlmToolCall {
674            id: Some("call_123".to_string()),
675            name: "test_function".to_string(),
676            arguments: {
677                let mut map = HashMap::new();
678                map.insert("arg1".to_string(), serde_json::json!("value1"));
679                map
680            },
681        };
682
683        let messages = vec![LlmMessage {
684            role: MessageRole::Assistant,
685            content: None,
686            tool_calls: Some(vec![tool_call]),
687            image_paths: None,
688        }];
689
690        let result = adapt_messages_to_ollama(&messages).unwrap();
691
692        assert_eq!(result.len(), 1);
693        assert_eq!(result[0]["role"], "assistant");
694        assert_eq!(result[0]["tool_calls"][0]["type"], "function");
695        assert_eq!(result[0]["tool_calls"][0]["function"]["name"], "test_function");
696    }
697
698    #[test]
699    fn test_adapt_messages_empty_content() {
700        let messages = vec![LlmMessage {
701            role: MessageRole::User,
702            content: None,
703            tool_calls: None,
704            image_paths: None,
705        }];
706
707        let result = adapt_messages_to_ollama(&messages).unwrap();
708
709        assert_eq!(result.len(), 1);
710        assert_eq!(result[0]["content"], "");
711    }
712
713    #[test]
714    fn test_adapt_messages_tool_role() {
715        let messages = vec![LlmMessage {
716            role: MessageRole::Tool,
717            content: Some("Tool result".to_string()),
718            tool_calls: None,
719            image_paths: None,
720        }];
721
722        let result = adapt_messages_to_ollama(&messages).unwrap();
723
724        assert_eq!(result.len(), 1);
725        assert_eq!(result[0]["role"], "tool");
726        assert_eq!(result[0]["content"], "Tool result");
727    }
728
729    #[test]
730    fn test_extract_ollama_options_basic() {
731        let config = CompletionConfig {
732            temperature: 0.7,
733            num_ctx: 4096,
734            max_tokens: 2048,
735            num_predict: None,
736            top_p: None,
737            top_k: None,
738            response_format: None,
739            reasoning_effort: None,
740        };
741
742        let options = extract_ollama_options(&config);
743
744        // Use as_f64 for floating point comparison
745        assert!((options["temperature"].as_f64().unwrap() - 0.7).abs() < 0.01);
746        assert_eq!(options["num_ctx"], 4096);
747        // max_tokens should be used as num_predict when num_predict is None
748        assert_eq!(options["num_predict"], 2048);
749    }
750
751    #[test]
752    fn test_extract_ollama_options_with_num_predict() {
753        let config = CompletionConfig {
754            temperature: 0.5,
755            num_ctx: 2048,
756            max_tokens: 1000,
757            num_predict: Some(500),
758            top_p: None,
759            top_k: None,
760            response_format: None,
761            reasoning_effort: None,
762        };
763
764        let options = extract_ollama_options(&config);
765
766        assert!((options["temperature"].as_f64().unwrap() - 0.5).abs() < 0.01);
767        assert_eq!(options["num_ctx"], 2048);
768        // num_predict takes precedence over max_tokens
769        assert_eq!(options["num_predict"], 500);
770    }
771
772    #[test]
773    fn test_extract_ollama_options_zero_num_predict() {
774        let config = CompletionConfig {
775            temperature: 1.0,
776            num_ctx: 8192,
777            max_tokens: 4096,
778            num_predict: Some(0),
779            top_p: None,
780            top_k: None,
781            response_format: None,
782            reasoning_effort: None,
783        };
784
785        let options = extract_ollama_options(&config);
786
787        assert!((options["temperature"].as_f64().unwrap() - 1.0).abs() < 0.01);
788        assert_eq!(options["num_ctx"], 8192);
789        // When num_predict is Some(0) (not > 0), num_predict field is not added
790        // (the else-if branch only runs when num_predict is None)
791        assert!(options.get("num_predict").is_none() || options["num_predict"].is_null());
792    }
793
794    #[test]
795    fn test_extract_ollama_options_zero_max_tokens() {
796        let config = CompletionConfig {
797            temperature: 0.8,
798            num_ctx: 1024,
799            max_tokens: 0,
800            num_predict: None,
801            top_p: None,
802            top_k: None,
803            response_format: None,
804            reasoning_effort: None,
805        };
806
807        let options = extract_ollama_options(&config);
808
809        assert!((options["temperature"].as_f64().unwrap() - 0.8).abs() < 0.01);
810        assert_eq!(options["num_ctx"], 1024);
811        // When max_tokens is 0 and num_predict is None, num_predict field is not added
812        // Check that it's either missing or null (not set in the options object)
813        let num_predict = options.get("num_predict");
814        assert!(num_predict.is_none() || num_predict.unwrap().is_null());
815    }
816
817    #[test]
818    fn test_extract_ollama_options_with_top_p() {
819        let config = CompletionConfig {
820            temperature: 0.7,
821            num_ctx: 4096,
822            max_tokens: 2048,
823            num_predict: None,
824            top_p: Some(0.9),
825            top_k: None,
826            response_format: None,
827            reasoning_effort: None,
828        };
829
830        let options = extract_ollama_options(&config);
831
832        assert!((options["temperature"].as_f64().unwrap() - 0.7).abs() < 0.01);
833        assert_eq!(options["num_ctx"], 4096);
834        assert!((options["top_p"].as_f64().unwrap() - 0.9).abs() < 0.01);
835        assert!(options.get("top_k").is_none());
836    }
837
838    #[test]
839    fn test_extract_ollama_options_with_top_k() {
840        let config = CompletionConfig {
841            temperature: 0.8,
842            num_ctx: 2048,
843            max_tokens: 1024,
844            num_predict: None,
845            top_p: None,
846            top_k: Some(40),
847            response_format: None,
848            reasoning_effort: None,
849        };
850
851        let options = extract_ollama_options(&config);
852
853        assert!((options["temperature"].as_f64().unwrap() - 0.8).abs() < 0.01);
854        assert_eq!(options["top_k"], 40);
855        assert!(options.get("top_p").is_none());
856    }
857
858    #[test]
859    fn test_extract_ollama_options_with_all_sampling_params() {
860        let config = CompletionConfig {
861            temperature: 0.6,
862            num_ctx: 8192,
863            max_tokens: 4096,
864            num_predict: Some(2000),
865            top_p: Some(0.95),
866            top_k: Some(50),
867            response_format: None,
868            reasoning_effort: None,
869        };
870
871        let options = extract_ollama_options(&config);
872
873        assert!((options["temperature"].as_f64().unwrap() - 0.6).abs() < 0.01);
874        assert_eq!(options["num_ctx"], 8192);
875        assert_eq!(options["num_predict"], 2000);
876        assert!((options["top_p"].as_f64().unwrap() - 0.95).abs() < 0.01);
877        assert_eq!(options["top_k"], 50);
878    }
879
880    #[test]
881    fn test_add_response_format_text() {
882        use crate::llm::gateway::ResponseFormat;
883
884        let config = CompletionConfig {
885            temperature: 0.7,
886            num_ctx: 4096,
887            max_tokens: 2048,
888            num_predict: None,
889            top_p: None,
890            top_k: None,
891            response_format: Some(ResponseFormat::Text),
892            reasoning_effort: None,
893        };
894
895        let mut body = serde_json::json!({
896            "model": "test",
897            "messages": []
898        });
899
900        add_response_format(&mut body, &config);
901
902        // Text format shouldn't add a format field
903        assert!(body.get("format").is_none());
904    }
905
906    #[test]
907    fn test_add_response_format_json_no_schema() {
908        use crate::llm::gateway::ResponseFormat;
909
910        let config = CompletionConfig {
911            temperature: 0.7,
912            num_ctx: 4096,
913            max_tokens: 2048,
914            num_predict: None,
915            top_p: None,
916            top_k: None,
917            response_format: Some(ResponseFormat::JsonObject { schema: None }),
918            reasoning_effort: None,
919        };
920
921        let mut body = serde_json::json!({
922            "model": "test",
923            "messages": []
924        });
925
926        add_response_format(&mut body, &config);
927
928        assert_eq!(body["format"], "json");
929    }
930
931    #[test]
932    fn test_add_response_format_json_with_schema() {
933        use crate::llm::gateway::ResponseFormat;
934
935        let schema = serde_json::json!({
936            "type": "object",
937            "properties": {
938                "name": {"type": "string"},
939                "age": {"type": "number"}
940            }
941        });
942
943        let config = CompletionConfig {
944            temperature: 0.7,
945            num_ctx: 4096,
946            max_tokens: 2048,
947            num_predict: None,
948            top_p: None,
949            top_k: None,
950            response_format: Some(ResponseFormat::JsonObject {
951                schema: Some(schema.clone()),
952            }),
953            reasoning_effort: None,
954        };
955
956        let mut body = serde_json::json!({
957            "model": "test",
958            "messages": []
959        });
960
961        add_response_format(&mut body, &config);
962
963        assert_eq!(body["format"], schema);
964    }
965
966    #[test]
967    fn test_add_response_format_none() {
968        let config = CompletionConfig {
969            temperature: 0.7,
970            num_ctx: 4096,
971            max_tokens: 2048,
972            num_predict: None,
973            top_p: None,
974            top_k: None,
975            response_format: None,
976            reasoning_effort: None,
977        };
978
979        let mut body = serde_json::json!({
980            "model": "test",
981            "messages": []
982        });
983
984        add_response_format(&mut body, &config);
985
986        // No format should be added when response_format is None
987        assert!(body.get("format").is_none());
988    }
989
990    #[tokio::test]
991    async fn test_pull_model_success() {
992        let mut server = mockito::Server::new_async().await;
993        let mock = server
994            .mock("POST", "/api/pull")
995            .with_status(200)
996            .with_body(r#"{"status":"success"}"#)
997            .create();
998
999        let gateway = OllamaGateway::with_host(server.url());
1000        let result = gateway.pull_model("llama2").await;
1001
1002        mock.assert();
1003        assert!(result.is_ok());
1004    }
1005
1006    #[tokio::test]
1007    async fn test_pull_model_failure() {
1008        let mut server = mockito::Server::new_async().await;
1009        let mock = server.mock("POST", "/api/pull").with_status(404).create();
1010
1011        let gateway = OllamaGateway::with_host(server.url());
1012        let result = gateway.pull_model("nonexistent").await;
1013
1014        mock.assert();
1015        assert!(result.is_err());
1016    }
1017
1018    #[tokio::test]
1019    async fn test_complete_simple() {
1020        let mut server = mockito::Server::new_async().await;
1021        let mock = server
1022            .mock("POST", "/api/chat")
1023            .with_status(200)
1024            .with_body(r#"{"message":{"role":"assistant","content":"Hello!"}}"#)
1025            .create();
1026
1027        let gateway = OllamaGateway::with_host(server.url());
1028        let messages = vec![LlmMessage::user("Hi")];
1029        let config = CompletionConfig::default();
1030
1031        let result = gateway.complete("llama2", &messages, None, &config).await;
1032
1033        mock.assert();
1034        assert!(result.is_ok());
1035        let response = result.unwrap();
1036        assert_eq!(response.content, Some("Hello!".to_string()));
1037        assert_eq!(response.thinking, None);
1038    }
1039
1040    #[tokio::test]
1041    async fn test_complete_with_tools() {
1042        let mut server = mockito::Server::new_async().await;
1043        let mock = server
1044            .mock("POST", "/api/chat")
1045            .match_body(mockito::Matcher::JsonString(
1046                r#"{"model":"llama2","messages":[{"role":"user","content":"Hi"}],"options":{"temperature":1.0,"num_ctx":32768,"num_predict":16384},"stream":false,"tools":[{"type":"function","function":{"name":"test_tool","description":"A test","parameters":{}}}]}"#.to_string()
1047            ))
1048            .with_status(200)
1049            .with_body(r#"{"message":{"role":"assistant","content":"Result"}}"#)
1050            .create();
1051
1052        let gateway = OllamaGateway::with_host(server.url());
1053        let messages = vec![LlmMessage::user("Hi")];
1054        let config = CompletionConfig::default();
1055
1056        use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};
1057
1058        #[derive(Clone)]
1059        struct MockTool;
1060        impl LlmTool for MockTool {
1061            fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
1062                Ok(serde_json::json!({}))
1063            }
1064            fn descriptor(&self) -> ToolDescriptor {
1065                ToolDescriptor {
1066                    r#type: "function".to_string(),
1067                    function: FunctionDescriptor {
1068                        name: "test_tool".to_string(),
1069                        description: "A test".to_string(),
1070                        parameters: serde_json::json!({}),
1071                    },
1072                }
1073            }
1074            fn clone_box(&self) -> Box<dyn LlmTool> {
1075                Box::new(self.clone())
1076            }
1077        }
1078
1079        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool)];
1080        let result = gateway.complete("llama2", &messages, Some(&tools), &config).await;
1081
1082        mock.assert();
1083        assert!(result.is_ok());
1084    }
1085
1086    #[tokio::test]
1087    async fn test_complete_error() {
1088        let mut server = mockito::Server::new_async().await;
1089        let mock = server.mock("POST", "/api/chat").with_status(500).create();
1090
1091        let gateway = OllamaGateway::with_host(server.url());
1092        let messages = vec![LlmMessage::user("Hi")];
1093        let config = CompletionConfig::default();
1094
1095        let result = gateway.complete("llama2", &messages, None, &config).await;
1096
1097        mock.assert();
1098        assert!(result.is_err());
1099    }
1100
1101    #[tokio::test]
1102    async fn test_complete_json() {
1103        let mut server = mockito::Server::new_async().await;
1104        let mock = server
1105            .mock("POST", "/api/chat")
1106            .with_status(200)
1107            .with_body(r#"{"message":{"content":"{\"name\":\"test\",\"value\":42}"}}"#)
1108            .create();
1109
1110        let gateway = OllamaGateway::with_host(server.url());
1111        let messages = vec![LlmMessage::user("Generate JSON")];
1112        let schema = serde_json::json!({"type": "object"});
1113        let config = CompletionConfig::default();
1114
1115        let result = gateway.complete_json("llama2", &messages, schema, &config).await;
1116
1117        mock.assert();
1118        assert!(result.is_ok());
1119        let json = result.unwrap();
1120        assert_eq!(json["name"], "test");
1121        assert_eq!(json["value"], 42);
1122    }
1123
1124    #[tokio::test]
1125    async fn test_get_available_models() {
1126        let mut server = mockito::Server::new_async().await;
1127        let mock = server
1128            .mock("GET", "/api/tags")
1129            .with_status(200)
1130            .with_body(r#"{"models":[{"name":"llama2"},{"name":"mistral"}]}"#)
1131            .create();
1132
1133        let gateway = OllamaGateway::with_host(server.url());
1134        let result = gateway.get_available_models().await;
1135
1136        mock.assert();
1137        assert!(result.is_ok());
1138        let models = result.unwrap();
1139        assert_eq!(models.len(), 2);
1140        assert!(models.contains(&"llama2".to_string()));
1141        assert!(models.contains(&"mistral".to_string()));
1142    }
1143
1144    #[tokio::test]
1145    async fn test_calculate_embeddings() {
1146        let mut server = mockito::Server::new_async().await;
1147        let mock = server
1148            .mock("POST", "/api/embeddings")
1149            .with_status(200)
1150            .with_body(r#"{"embedding":[0.1,0.2,0.3,0.4]}"#)
1151            .create();
1152
1153        let gateway = OllamaGateway::with_host(server.url());
1154        let result = gateway.calculate_embeddings("test text", None).await;
1155
1156        mock.assert();
1157        assert!(result.is_ok());
1158        let embeddings = result.unwrap();
1159        assert_eq!(embeddings.len(), 4);
1160        assert_eq!(embeddings[0], 0.1);
1161        assert_eq!(embeddings[3], 0.4);
1162    }
1163
1164    #[tokio::test]
1165    async fn test_calculate_embeddings_custom_model() {
1166        let mut server = mockito::Server::new_async().await;
1167        let mock = server
1168            .mock("POST", "/api/embeddings")
1169            .match_body(mockito::Matcher::JsonString(
1170                r#"{"model":"custom-embed","prompt":"test"}"#.to_string(),
1171            ))
1172            .with_status(200)
1173            .with_body(r#"{"embedding":[0.5,0.6]}"#)
1174            .create();
1175
1176        let gateway = OllamaGateway::with_host(server.url());
1177        let result = gateway.calculate_embeddings("test", Some("custom-embed")).await;
1178
1179        mock.assert();
1180        assert!(result.is_ok());
1181    }
1182
1183    #[tokio::test]
1184    async fn test_complete_with_reasoning_effort() {
1185        use crate::llm::gateway::ReasoningEffort;
1186
1187        let mut server = mockito::Server::new_async().await;
1188        let mock = server
1189            .mock("POST", "/api/chat")
1190            .match_body(mockito::Matcher::PartialJson(
1191                serde_json::json!({"think": true}),
1192            ))
1193            .with_status(200)
1194            .with_body(r#"{"message":{"role":"assistant","content":"Response","thinking":"Internal reasoning..."}}"#)
1195            .create();
1196
1197        let gateway = OllamaGateway::with_host(server.url());
1198        let messages = vec![LlmMessage::user("Test")];
1199        let config = CompletionConfig {
1200            reasoning_effort: Some(ReasoningEffort::High),
1201            ..Default::default()
1202        };
1203
1204        let result = gateway.complete("qwen3:32b", &messages, None, &config).await;
1205
1206        mock.assert();
1207        assert!(result.is_ok());
1208        let response = result.unwrap();
1209        assert_eq!(response.content, Some("Response".to_string()));
1210        assert_eq!(response.thinking, Some("Internal reasoning...".to_string()));
1211    }
1212}