1use crate::error::{MojenticError, Result};
2use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
3use crate::llm::models::{LlmGatewayResponse, LlmMessage, LlmToolCall, MessageRole};
4use crate::llm::tools::LlmTool;
5use async_trait::async_trait;
6use futures::stream::{Stream, StreamExt};
7use reqwest::Client;
8use serde_json::Value;
9use std::collections::HashMap;
10use std::pin::Pin;
11use tracing::{debug, info, warn};
12
/// Connection settings for talking to an Ollama server.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Base URL of the Ollama server, e.g. "http://localhost:11434".
    pub host: String,
    /// Optional request timeout applied to the underlying HTTP client.
    pub timeout: Option<std::time::Duration>,
    /// Additional HTTP headers intended for requests to the server.
    pub headers: HashMap<String, String>,
}
20
21impl Default for OllamaConfig {
22 fn default() -> Self {
23 Self {
24 host: std::env::var("OLLAMA_HOST")
25 .unwrap_or_else(|_| "http://localhost:11434".to_string()),
26 timeout: None,
27 headers: HashMap::new(),
28 }
29 }
30}
31
/// LLM gateway backed by a local or remote Ollama server.
pub struct OllamaGateway {
    // Shared reqwest client reused for all API calls.
    client: Client,
    // Connection settings (host, timeout, headers).
    config: OllamaConfig,
}
40
41impl OllamaGateway {
42 pub fn new() -> Self {
44 Self::with_config(OllamaConfig::default())
45 }
46
47 pub fn with_config(config: OllamaConfig) -> Self {
49 let mut client_builder = Client::builder();
50
51 if let Some(timeout) = config.timeout {
52 client_builder = client_builder.timeout(timeout);
53 }
54
55 let client = client_builder.build().unwrap();
56
57 Self { client, config }
58 }
59
60 pub fn with_host(host: impl Into<String>) -> Self {
62 Self::with_config(OllamaConfig {
63 host: host.into(),
64 ..Default::default()
65 })
66 }
67
68 pub async fn pull_model(&self, model: &str) -> Result<()> {
70 info!("Pulling Ollama model: {}", model);
71
72 let response = self
73 .client
74 .post(format!("{}/api/pull", self.config.host))
75 .json(&serde_json::json!({
76 "name": model
77 }))
78 .send()
79 .await?;
80
81 if !response.status().is_success() {
82 return Err(MojenticError::GatewayError(format!(
83 "Failed to pull model {}: {}",
84 model,
85 response.status()
86 )));
87 }
88
89 Ok(())
90 }
91}
92
impl Default for OllamaGateway {
    /// Equivalent to [`OllamaGateway::new`]: uses `OLLAMA_HOST` from the
    /// environment or the localhost default.
    fn default() -> Self {
        Self::new()
    }
}
98
99#[async_trait]
100impl LlmGateway for OllamaGateway {
101 async fn complete(
102 &self,
103 model: &str,
104 messages: &[LlmMessage],
105 tools: Option<&[Box<dyn LlmTool>]>,
106 config: &CompletionConfig,
107 ) -> Result<LlmGatewayResponse> {
108 info!("Delegating to Ollama for completion");
109 debug!("Model: {}, Message count: {}", model, messages.len());
110
111 let ollama_messages = adapt_messages_to_ollama(messages)?;
112 let options = extract_ollama_options(config);
113
114 let mut body = serde_json::json!({
115 "model": model,
116 "messages": ollama_messages,
117 "options": options,
118 "stream": false
119 });
120
121 if let Some(tools) = tools {
123 let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
124 body["tools"] = serde_json::to_value(tool_defs)?;
125 }
126
127 if config.reasoning_effort.is_some() {
129 body["think"] = serde_json::json!(true);
130 }
131
132 add_response_format(&mut body, config);
134
135 let response = self
137 .client
138 .post(format!("{}/api/chat", self.config.host))
139 .json(&body)
140 .send()
141 .await?;
142
143 if !response.status().is_success() {
144 return Err(MojenticError::GatewayError(format!(
145 "Ollama API error: {}",
146 response.status()
147 )));
148 }
149
150 let response_body: Value = response.json().await?;
151
152 let content = response_body["message"]["content"].as_str().map(String::from);
154
155 let thinking = response_body["message"]["thinking"].as_str().map(String::from);
157
158 let tool_calls = if let Some(calls) = response_body["message"]["tool_calls"].as_array() {
160 calls
161 .iter()
162 .filter_map(|call| {
163 let name = call["function"]["name"].as_str()?.to_string();
164 let args = call["function"]["arguments"].as_object()?;
165
166 let arguments: HashMap<String, Value> =
167 args.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
168
169 Some(LlmToolCall {
170 id: call["id"].as_str().map(String::from),
171 name,
172 arguments,
173 })
174 })
175 .collect()
176 } else {
177 vec![]
178 };
179
180 Ok(LlmGatewayResponse {
181 content,
182 object: None,
183 tool_calls,
184 thinking,
185 })
186 }
187
188 async fn complete_json(
189 &self,
190 model: &str,
191 messages: &[LlmMessage],
192 schema: Value,
193 config: &CompletionConfig,
194 ) -> Result<Value> {
195 info!("Requesting structured output from Ollama");
196
197 let ollama_messages = adapt_messages_to_ollama(messages)?;
198 let options = extract_ollama_options(config);
199
200 let body = serde_json::json!({
201 "model": model,
202 "messages": ollama_messages,
203 "options": options,
204 "format": schema,
205 "stream": false
206 });
207
208 let response = self
209 .client
210 .post(format!("{}/api/chat", self.config.host))
211 .json(&body)
212 .send()
213 .await?;
214
215 if !response.status().is_success() {
216 return Err(MojenticError::GatewayError(format!(
217 "Ollama API error: {}",
218 response.status()
219 )));
220 }
221
222 let response_body: Value = response.json().await?;
223 let content = response_body["message"]["content"]
224 .as_str()
225 .ok_or_else(|| MojenticError::GatewayError("No content in response".to_string()))?;
226
227 let json_value: Value = serde_json::from_str(content)?;
229
230 Ok(json_value)
231 }
232
233 async fn get_available_models(&self) -> Result<Vec<String>> {
234 debug!("Fetching available Ollama models");
235
236 let response = self.client.get(format!("{}/api/tags", self.config.host)).send().await?;
237
238 if !response.status().is_success() {
239 return Err(MojenticError::GatewayError(format!(
240 "Failed to get models: {}",
241 response.status()
242 )));
243 }
244
245 let body: Value = response.json().await?;
246
247 let models = body["models"]
248 .as_array()
249 .ok_or_else(|| MojenticError::GatewayError("Invalid response format".to_string()))?
250 .iter()
251 .filter_map(|m| m["name"].as_str().map(String::from))
252 .collect::<Vec<_>>();
253
254 Ok(models)
255 }
256
257 async fn calculate_embeddings(&self, text: &str, model: Option<&str>) -> Result<Vec<f32>> {
258 let model = model.unwrap_or("mxbai-embed-large");
259 debug!("Calculating embeddings with model: {}", model);
260
261 let body = serde_json::json!({
262 "model": model,
263 "prompt": text
264 });
265
266 let response = self
267 .client
268 .post(format!("{}/api/embeddings", self.config.host))
269 .json(&body)
270 .send()
271 .await?;
272
273 if !response.status().is_success() {
274 return Err(MojenticError::GatewayError(format!(
275 "Embeddings API error: {}",
276 response.status()
277 )));
278 }
279
280 let response_body: Value = response.json().await?;
281
282 let embeddings = response_body["embedding"]
283 .as_array()
284 .ok_or_else(|| MojenticError::GatewayError("Invalid embeddings response".to_string()))?
285 .iter()
286 .filter_map(|v| v.as_f64().map(|f| f as f32))
287 .collect();
288
289 Ok(embeddings)
290 }
291
292 fn complete_stream<'a>(
293 &'a self,
294 model: &'a str,
295 messages: &'a [LlmMessage],
296 tools: Option<&'a [Box<dyn LlmTool>]>,
297 config: &'a CompletionConfig,
298 ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
299 Box::pin(async_stream::stream! {
300 info!("Starting Ollama streaming completion");
301 debug!("Model: {}, Message count: {}", model, messages.len());
302
303 let ollama_messages = match adapt_messages_to_ollama(messages) {
304 Ok(msgs) => msgs,
305 Err(e) => {
306 yield Err(e);
307 return;
308 }
309 };
310
311 let options = extract_ollama_options(config);
312
313 let mut body = serde_json::json!({
314 "model": model,
315 "messages": ollama_messages,
316 "options": options,
317 "stream": true
318 });
319
320 if let Some(tools) = tools {
322 let tool_defs: Vec<_> = tools.iter().map(|t| t.descriptor()).collect();
323 if let Ok(tools_value) = serde_json::to_value(tool_defs) {
324 body["tools"] = tools_value;
325 }
326 }
327
328 if config.reasoning_effort.is_some() {
330 body["think"] = serde_json::json!(true);
331 }
332
333 add_response_format(&mut body, config);
335
336 let response = match self
338 .client
339 .post(format!("{}/api/chat", self.config.host))
340 .json(&body)
341 .send()
342 .await
343 {
344 Ok(r) => r,
345 Err(e) => {
346 yield Err(e.into());
347 return;
348 }
349 };
350
351 if !response.status().is_success() {
352 yield Err(MojenticError::GatewayError(format!(
353 "Ollama API error: {}",
354 response.status()
355 )));
356 return;
357 }
358
359 let mut stream = response.bytes_stream();
361 let mut buffer = String::new();
362 let mut accumulated_tool_calls: Vec<LlmToolCall> = Vec::new();
363
364 while let Some(chunk_result) = stream.next().await {
365 match chunk_result {
366 Ok(bytes) => {
367 if let Ok(text) = std::str::from_utf8(&bytes) {
369 buffer.push_str(text);
370
371 while let Some(newline_pos) = buffer.find('\n') {
373 let line = buffer[..newline_pos].trim().to_string();
374 buffer = buffer[newline_pos + 1..].to_string();
375
376 if line.is_empty() {
377 continue;
378 }
379
380 match serde_json::from_str::<Value>(&line) {
382 Ok(json) => {
383 if json["done"].as_bool().unwrap_or(false) {
385 if !accumulated_tool_calls.is_empty() {
387 yield Ok(StreamChunk::ToolCalls(accumulated_tool_calls.clone()));
388 }
389 continue;
390 }
391
392 if let Some(message) = json["message"].as_object() {
394 if let Some(content) = message["content"].as_str() {
395 if !content.is_empty() {
396 yield Ok(StreamChunk::Content(content.to_string()));
397 }
398 }
399
400 if let Some(calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
402 for call in calls {
403 if let Some(function) = call.get("function").and_then(|v| v.as_object()) {
404 if let (Some(name), Some(args)) = (
405 function.get("name").and_then(|v| v.as_str()),
406 function.get("arguments").and_then(|v| v.as_object()),
407 ) {
408 let arguments: HashMap<String, Value> = args
409 .iter()
410 .map(|(k, v)| (k.clone(), v.clone()))
411 .collect();
412
413 let tool_call = LlmToolCall {
414 id: call.get("id").and_then(|v| v.as_str()).map(String::from),
415 name: name.to_string(),
416 arguments,
417 };
418
419 accumulated_tool_calls.push(tool_call);
420 }
421 }
422 }
423 }
424 }
425 }
426 Err(e) => {
427 warn!("Failed to parse streaming chunk: {}", e);
428 }
429 }
430 }
431 }
432 }
433 Err(e) => {
434 yield Err(e.into());
435 return;
436 }
437 }
438 }
439 })
440 }
441}
442
443fn adapt_messages_to_ollama(messages: &[LlmMessage]) -> Result<Vec<Value>> {
445 messages
446 .iter()
447 .map(|msg| {
448 let mut ollama_msg = serde_json::json!({
449 "role": match msg.role {
450 MessageRole::System => "system",
451 MessageRole::User => "user",
452 MessageRole::Assistant => "assistant",
453 MessageRole::Tool => "tool",
454 },
455 "content": msg.content.as_deref().unwrap_or("")
456 });
457
458 if let Some(image_paths) = &msg.image_paths {
460 let encoded_images: Result<Vec<String>> = image_paths
461 .iter()
462 .map(|path| {
463 std::fs::read(path)
464 .map_err(|e| {
465 MojenticError::GatewayError(format!(
466 "Failed to read image file {}: {}",
467 path, e
468 ))
469 })
470 .map(|bytes| {
471 base64::Engine::encode(
472 &base64::engine::general_purpose::STANDARD,
473 bytes,
474 )
475 })
476 })
477 .collect();
478
479 ollama_msg["images"] = serde_json::to_value(encoded_images?)?;
480 }
481
482 if let Some(tool_calls) = &msg.tool_calls {
484 let calls: Vec<_> = tool_calls
485 .iter()
486 .map(|tc| {
487 serde_json::json!({
488 "type": "function",
489 "function": {
490 "name": tc.name,
491 "arguments": tc.arguments
492 }
493 })
494 })
495 .collect();
496 ollama_msg["tool_calls"] = serde_json::to_value(calls)?;
497 }
498
499 Ok(ollama_msg)
500 })
501 .collect()
502}
503
504fn extract_ollama_options(config: &CompletionConfig) -> Value {
506 let mut options = serde_json::json!({
507 "temperature": config.temperature,
508 "num_ctx": config.num_ctx,
509 });
510
511 if let Some(num_predict) = config.num_predict {
512 if num_predict > 0 {
513 options["num_predict"] = serde_json::json!(num_predict);
514 }
515 } else if config.max_tokens > 0 {
516 options["num_predict"] = serde_json::json!(config.max_tokens);
517 }
518
519 if let Some(top_p) = config.top_p {
520 options["top_p"] = serde_json::json!(top_p);
521 }
522
523 if let Some(top_k) = config.top_k {
524 options["top_k"] = serde_json::json!(top_k);
525 }
526
527 options
528}
529
530fn add_response_format(body: &mut Value, config: &CompletionConfig) {
532 use crate::llm::gateway::ResponseFormat;
533
534 if let Some(response_format) = &config.response_format {
535 match response_format {
536 ResponseFormat::JsonObject { schema: Some(s) } => {
537 body["format"] = s.clone();
538 }
539 ResponseFormat::JsonObject { schema: None } => {
540 body["format"] = serde_json::json!("json");
541 }
542 ResponseFormat::Text => {
543 }
545 }
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that depend on the process-global `OLLAMA_HOST`
    /// environment variable. Cargo runs tests on multiple threads, so
    /// unsynchronized set/remove calls made these tests flaky: e.g.
    /// `test_ollama_config_from_env` setting the variable while
    /// `test_gateway_new` asserted the localhost default.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    #[test]
    fn test_ollama_config_default() {
        // A poisoned lock just means another env test panicked; the guard
        // is still usable for mutual exclusion.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::remove_var("OLLAMA_HOST");
        let config = OllamaConfig::default();
        assert_eq!(config.host, "http://localhost:11434");
        assert!(config.timeout.is_none());
        assert!(config.headers.is_empty());
    }

    #[test]
    fn test_ollama_config_from_env() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::set_var("OLLAMA_HOST", "http://custom:8080");
        let config = OllamaConfig::default();
        assert_eq!(config.host, "http://custom:8080");
        std::env::remove_var("OLLAMA_HOST");
    }

    #[test]
    fn test_ollama_config_custom() {
        let mut headers = HashMap::new();
        headers.insert("X-Custom".to_string(), "value".to_string());

        let config = OllamaConfig {
            host: "http://test:9999".to_string(),
            timeout: Some(std::time::Duration::from_secs(30)),
            headers,
        };

        assert_eq!(config.host, "http://test:9999");
        assert_eq!(config.timeout, Some(std::time::Duration::from_secs(30)));
        assert_eq!(config.headers.get("X-Custom"), Some(&"value".to_string()));
    }

    #[test]
    fn test_gateway_new() {
        // new() reads OLLAMA_HOST via OllamaConfig::default(), so this
        // test must also hold the env lock and clear the variable.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::remove_var("OLLAMA_HOST");
        let gateway = OllamaGateway::new();
        assert_eq!(gateway.config.host, "http://localhost:11434");
    }

    #[test]
    fn test_gateway_with_host() {
        let gateway = OllamaGateway::with_host("http://example.com:8080");
        assert_eq!(gateway.config.host, "http://example.com:8080");
    }

    #[test]
    fn test_gateway_with_config() {
        let config = OllamaConfig {
            host: "http://custom:5000".to_string(),
            timeout: Some(std::time::Duration::from_secs(60)),
            headers: HashMap::new(),
        };

        let gateway = OllamaGateway::with_config(config);
        assert_eq!(gateway.config.host, "http://custom:5000");
    }

    #[test]
    fn test_gateway_default() {
        // default() also reads OLLAMA_HOST — serialize with the env lock.
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        std::env::remove_var("OLLAMA_HOST");
        let gateway = OllamaGateway::default();
        assert_eq!(gateway.config.host, "http://localhost:11434");
    }

    #[test]
    fn test_adapt_messages_to_ollama_simple() {
        let messages = vec![
            LlmMessage::system("You are helpful"),
            LlmMessage::user("Hello"),
            LlmMessage::assistant("Hi there"),
        ];

        let result = adapt_messages_to_ollama(&messages).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result[0]["role"], "system");
        assert_eq!(result[0]["content"], "You are helpful");
        assert_eq!(result[1]["role"], "user");
        assert_eq!(result[1]["content"], "Hello");
        assert_eq!(result[2]["role"], "assistant");
        assert_eq!(result[2]["content"], "Hi there");
    }

    #[test]
    fn test_adapt_messages_with_images() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let mut temp_file1 = NamedTempFile::new().unwrap();
        let mut temp_file2 = NamedTempFile::new().unwrap();
        temp_file1.write_all(b"fake_image_data_1").unwrap();
        temp_file2.write_all(b"fake_image_data_2").unwrap();

        let path1 = temp_file1.path().to_string_lossy().to_string();
        let path2 = temp_file2.path().to_string_lossy().to_string();

        let expected_base64_1 = base64::Engine::encode(
            &base64::engine::general_purpose::STANDARD,
            b"fake_image_data_1",
        );
        let expected_base64_2 = base64::Engine::encode(
            &base64::engine::general_purpose::STANDARD,
            b"fake_image_data_2",
        );

        let messages = vec![LlmMessage::user("Describe this").with_images(vec![path1, path2])];

        let result = adapt_messages_to_ollama(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "user");
        assert_eq!(result[0]["images"][0], expected_base64_1);
        assert_eq!(result[0]["images"][1], expected_base64_2);
    }

    #[test]
    fn test_adapt_messages_with_tool_calls() {
        let tool_call = LlmToolCall {
            id: Some("call_123".to_string()),
            name: "test_function".to_string(),
            arguments: {
                let mut map = HashMap::new();
                map.insert("arg1".to_string(), serde_json::json!("value1"));
                map
            },
        };

        let messages = vec![LlmMessage {
            role: MessageRole::Assistant,
            content: None,
            tool_calls: Some(vec![tool_call]),
            image_paths: None,
        }];

        let result = adapt_messages_to_ollama(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "assistant");
        assert_eq!(result[0]["tool_calls"][0]["type"], "function");
        assert_eq!(result[0]["tool_calls"][0]["function"]["name"], "test_function");
    }

    #[test]
    fn test_adapt_messages_empty_content() {
        let messages = vec![LlmMessage {
            role: MessageRole::User,
            content: None,
            tool_calls: None,
            image_paths: None,
        }];

        let result = adapt_messages_to_ollama(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["content"], "");
    }

    #[test]
    fn test_adapt_messages_tool_role() {
        let messages = vec![LlmMessage {
            role: MessageRole::Tool,
            content: Some("Tool result".to_string()),
            tool_calls: None,
            image_paths: None,
        }];

        let result = adapt_messages_to_ollama(&messages).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["role"], "tool");
        assert_eq!(result[0]["content"], "Tool result");
    }

    #[test]
    fn test_extract_ollama_options_basic() {
        let config = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: None,
            top_p: None,
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let options = extract_ollama_options(&config);

        assert!((options["temperature"].as_f64().unwrap() - 0.7).abs() < 0.01);
        assert_eq!(options["num_ctx"], 4096);
        assert_eq!(options["num_predict"], 2048);
    }

    #[test]
    fn test_extract_ollama_options_with_num_predict() {
        let config = CompletionConfig {
            temperature: 0.5,
            num_ctx: 2048,
            max_tokens: 1000,
            num_predict: Some(500),
            top_p: None,
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let options = extract_ollama_options(&config);

        assert!((options["temperature"].as_f64().unwrap() - 0.5).abs() < 0.01);
        assert_eq!(options["num_ctx"], 2048);
        assert_eq!(options["num_predict"], 500);
    }

    #[test]
    fn test_extract_ollama_options_zero_num_predict() {
        let config = CompletionConfig {
            temperature: 1.0,
            num_ctx: 8192,
            max_tokens: 4096,
            num_predict: Some(0),
            top_p: None,
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let options = extract_ollama_options(&config);

        assert!((options["temperature"].as_f64().unwrap() - 1.0).abs() < 0.01);
        assert_eq!(options["num_ctx"], 8192);
        // An explicit zero suppresses the limit entirely, even though
        // max_tokens is positive.
        assert!(options.get("num_predict").is_none() || options["num_predict"].is_null());
    }

    #[test]
    fn test_extract_ollama_options_zero_max_tokens() {
        let config = CompletionConfig {
            temperature: 0.8,
            num_ctx: 1024,
            max_tokens: 0,
            num_predict: None,
            top_p: None,
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let options = extract_ollama_options(&config);

        assert!((options["temperature"].as_f64().unwrap() - 0.8).abs() < 0.01);
        assert_eq!(options["num_ctx"], 1024);
        let num_predict = options.get("num_predict");
        assert!(num_predict.is_none() || num_predict.unwrap().is_null());
    }

    #[test]
    fn test_extract_ollama_options_with_top_p() {
        let config = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: None,
            top_p: Some(0.9),
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let options = extract_ollama_options(&config);

        assert!((options["temperature"].as_f64().unwrap() - 0.7).abs() < 0.01);
        assert_eq!(options["num_ctx"], 4096);
        assert!((options["top_p"].as_f64().unwrap() - 0.9).abs() < 0.01);
        assert!(options.get("top_k").is_none());
    }

    #[test]
    fn test_extract_ollama_options_with_top_k() {
        let config = CompletionConfig {
            temperature: 0.8,
            num_ctx: 2048,
            max_tokens: 1024,
            num_predict: None,
            top_p: None,
            top_k: Some(40),
            response_format: None,
            reasoning_effort: None,
        };

        let options = extract_ollama_options(&config);

        assert!((options["temperature"].as_f64().unwrap() - 0.8).abs() < 0.01);
        assert_eq!(options["top_k"], 40);
        assert!(options.get("top_p").is_none());
    }

    #[test]
    fn test_extract_ollama_options_with_all_sampling_params() {
        let config = CompletionConfig {
            temperature: 0.6,
            num_ctx: 8192,
            max_tokens: 4096,
            num_predict: Some(2000),
            top_p: Some(0.95),
            top_k: Some(50),
            response_format: None,
            reasoning_effort: None,
        };

        let options = extract_ollama_options(&config);

        assert!((options["temperature"].as_f64().unwrap() - 0.6).abs() < 0.01);
        assert_eq!(options["num_ctx"], 8192);
        assert_eq!(options["num_predict"], 2000);
        assert!((options["top_p"].as_f64().unwrap() - 0.95).abs() < 0.01);
        assert_eq!(options["top_k"], 50);
    }

    #[test]
    fn test_add_response_format_text() {
        use crate::llm::gateway::ResponseFormat;

        let config = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: None,
            top_p: None,
            top_k: None,
            response_format: Some(ResponseFormat::Text),
            reasoning_effort: None,
        };

        let mut body = serde_json::json!({
            "model": "test",
            "messages": []
        });

        add_response_format(&mut body, &config);

        assert!(body.get("format").is_none());
    }

    #[test]
    fn test_add_response_format_json_no_schema() {
        use crate::llm::gateway::ResponseFormat;

        let config = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: None,
            top_p: None,
            top_k: None,
            response_format: Some(ResponseFormat::JsonObject { schema: None }),
            reasoning_effort: None,
        };

        let mut body = serde_json::json!({
            "model": "test",
            "messages": []
        });

        add_response_format(&mut body, &config);

        assert_eq!(body["format"], "json");
    }

    #[test]
    fn test_add_response_format_json_with_schema() {
        use crate::llm::gateway::ResponseFormat;

        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "number"}
            }
        });

        let config = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: None,
            top_p: None,
            top_k: None,
            response_format: Some(ResponseFormat::JsonObject {
                schema: Some(schema.clone()),
            }),
            reasoning_effort: None,
        };

        let mut body = serde_json::json!({
            "model": "test",
            "messages": []
        });

        add_response_format(&mut body, &config);

        assert_eq!(body["format"], schema);
    }

    #[test]
    fn test_add_response_format_none() {
        let config = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: None,
            top_p: None,
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let mut body = serde_json::json!({
            "model": "test",
            "messages": []
        });

        add_response_format(&mut body, &config);

        assert!(body.get("format").is_none());
    }

    #[tokio::test]
    async fn test_pull_model_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/pull")
            .with_status(200)
            .with_body(r#"{"status":"success"}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let result = gateway.pull_model("llama2").await;

        mock.assert();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_pull_model_failure() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("POST", "/api/pull").with_status(404).create();

        let gateway = OllamaGateway::with_host(server.url());
        let result = gateway.pull_model("nonexistent").await;

        mock.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_complete_simple() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_body(r#"{"message":{"role":"assistant","content":"Hello!"}}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let messages = vec![LlmMessage::user("Hi")];
        let config = CompletionConfig::default();

        let result = gateway.complete("llama2", &messages, None, &config).await;

        mock.assert();
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.content, Some("Hello!".to_string()));
        assert_eq!(response.thinking, None);
    }

    #[tokio::test]
    async fn test_complete_with_tools() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/chat")
            .match_body(mockito::Matcher::JsonString(
                r#"{"model":"llama2","messages":[{"role":"user","content":"Hi"}],"options":{"temperature":1.0,"num_ctx":32768,"num_predict":16384},"stream":false,"tools":[{"type":"function","function":{"name":"test_tool","description":"A test","parameters":{}}}]}"#.to_string()
            ))
            .with_status(200)
            .with_body(r#"{"message":{"role":"assistant","content":"Result"}}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let messages = vec![LlmMessage::user("Hi")];
        let config = CompletionConfig::default();

        use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};

        // Minimal tool implementation used only to exercise the tools
        // serialization path in the request body.
        #[derive(Clone)]
        struct MockTool;
        impl LlmTool for MockTool {
            fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
                Ok(serde_json::json!({}))
            }
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor {
                    r#type: "function".to_string(),
                    function: FunctionDescriptor {
                        name: "test_tool".to_string(),
                        description: "A test".to_string(),
                        parameters: serde_json::json!({}),
                    },
                }
            }
            fn clone_box(&self) -> Box<dyn LlmTool> {
                Box::new(self.clone())
            }
        }

        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool)];
        let result = gateway.complete("llama2", &messages, Some(&tools), &config).await;

        mock.assert();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_complete_error() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("POST", "/api/chat").with_status(500).create();

        let gateway = OllamaGateway::with_host(server.url());
        let messages = vec![LlmMessage::user("Hi")];
        let config = CompletionConfig::default();

        let result = gateway.complete("llama2", &messages, None, &config).await;

        mock.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_complete_json() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/chat")
            .with_status(200)
            .with_body(r#"{"message":{"content":"{\"name\":\"test\",\"value\":42}"}}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let messages = vec![LlmMessage::user("Generate JSON")];
        let schema = serde_json::json!({"type": "object"});
        let config = CompletionConfig::default();

        let result = gateway.complete_json("llama2", &messages, schema, &config).await;

        mock.assert();
        assert!(result.is_ok());
        let json = result.unwrap();
        assert_eq!(json["name"], "test");
        assert_eq!(json["value"], 42);
    }

    #[tokio::test]
    async fn test_get_available_models() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/api/tags")
            .with_status(200)
            .with_body(r#"{"models":[{"name":"llama2"},{"name":"mistral"}]}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let result = gateway.get_available_models().await;

        mock.assert();
        assert!(result.is_ok());
        let models = result.unwrap();
        assert_eq!(models.len(), 2);
        assert!(models.contains(&"llama2".to_string()));
        assert!(models.contains(&"mistral".to_string()));
    }

    #[tokio::test]
    async fn test_calculate_embeddings() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/embeddings")
            .with_status(200)
            .with_body(r#"{"embedding":[0.1,0.2,0.3,0.4]}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let result = gateway.calculate_embeddings("test text", None).await;

        mock.assert();
        assert!(result.is_ok());
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 4);
        assert_eq!(embeddings[0], 0.1);
        assert_eq!(embeddings[3], 0.4);
    }

    #[tokio::test]
    async fn test_calculate_embeddings_custom_model() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/embeddings")
            .match_body(mockito::Matcher::JsonString(
                r#"{"model":"custom-embed","prompt":"test"}"#.to_string(),
            ))
            .with_status(200)
            .with_body(r#"{"embedding":[0.5,0.6]}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let result = gateway.calculate_embeddings("test", Some("custom-embed")).await;

        mock.assert();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_complete_with_reasoning_effort() {
        use crate::llm::gateway::ReasoningEffort;

        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/api/chat")
            .match_body(mockito::Matcher::PartialJson(
                serde_json::json!({"think": true}),
            ))
            .with_status(200)
            .with_body(r#"{"message":{"role":"assistant","content":"Response","thinking":"Internal reasoning..."}}"#)
            .create();

        let gateway = OllamaGateway::with_host(server.url());
        let messages = vec![LlmMessage::user("Test")];
        let config = CompletionConfig {
            reasoning_effort: Some(ReasoningEffort::High),
            ..Default::default()
        };

        let result = gateway.complete("qwen3:32b", &messages, None, &config).await;

        mock.assert();
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.content, Some("Response".to_string()));
        assert_eq!(response.thinking, Some("Internal reasoning...".to_string()));
    }
}