mojentic/llm/
gateway.rs

1use crate::error::Result;
2use crate::llm::models::{LlmGatewayResponse, LlmMessage};
3use crate::llm::tools::LlmTool;
4use async_trait::async_trait;
5use futures::stream::Stream;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::pin::Pin;
9
10/// Format specification for LLM responses
11#[derive(Debug, Clone)]
12pub enum ResponseFormat {
13    /// Plain text response
14    Text,
15    /// JSON object response with optional schema
16    JsonObject { schema: Option<Value> },
17}
18
/// Reasoning effort level for models that support extended thinking
///
/// Serializes to/from the lowercase strings `"low"`, `"medium"`, `"high"`
/// via `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    /// Least extended thinking.
    Low,
    /// Moderate extended thinking.
    Medium,
    /// Most extended thinking.
    High,
}
27
/// Configuration for LLM completion
///
/// Bundles the sampling and output-format options passed to every
/// [`LlmGateway`] request. Build with `Default::default()` and override
/// individual fields via struct-update syntax.
#[derive(Debug, Clone)]
pub struct CompletionConfig {
    /// Sampling temperature (defaults to 1.0).
    pub temperature: f32,
    /// Context window size in tokens (defaults to 32768).
    /// NOTE(review): name suggests Ollama-style `num_ctx`; presumably only
    /// honored by providers exposing that knob — confirm per gateway impl.
    pub num_ctx: usize,
    /// Maximum tokens to generate (defaults to 16384).
    pub max_tokens: usize,
    /// Prediction cap (Ollama-style `num_predict`); `None` keeps the
    /// provider's default.
    pub num_predict: Option<i32>,
    /// Nucleus-sampling cutoff; `None` keeps the provider's default.
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff; `None` keeps the provider's default.
    pub top_k: Option<u32>,
    /// Requested response format; `None` means provider default (plain text).
    pub response_format: Option<ResponseFormat>,
    /// Extended-thinking effort for models that support it; `None` disables
    /// or defers to the provider default.
    pub reasoning_effort: Option<ReasoningEffort>,
}
40
41impl Default for CompletionConfig {
42    fn default() -> Self {
43        Self {
44            temperature: 1.0,
45            num_ctx: 32768,
46            max_tokens: 16384,
47            num_predict: None,
48            top_p: None,
49            top_k: None,
50            response_format: None,
51            reasoning_effort: None,
52        }
53    }
54}
55
/// Abstract interface for LLM providers
///
/// Implemented once per concrete provider. The trait is `Send + Sync` so a
/// gateway can be shared across threads and async tasks.
#[async_trait]
pub trait LlmGateway: Send + Sync {
    /// Complete an LLM request with text response
    ///
    /// When `tools` is supplied, they are made available to the model;
    /// presumably the response may then carry tool calls — confirm against
    /// `LlmGatewayResponse` in `models`.
    ///
    /// # Errors
    /// Returns an error if the underlying provider call fails.
    async fn complete(
        &self,
        model: &str,
        messages: &[LlmMessage],
        tools: Option<&[Box<dyn LlmTool>]>,
        config: &CompletionConfig,
    ) -> Result<LlmGatewayResponse>;

    /// Complete an LLM request with structured JSON response
    ///
    /// `schema` is the JSON schema the returned value is expected to satisfy.
    ///
    /// # Errors
    /// Returns an error if the underlying provider call fails.
    async fn complete_json(
        &self,
        model: &str,
        messages: &[LlmMessage],
        schema: Value,
        config: &CompletionConfig,
    ) -> Result<Value>;

    /// Get list of available models
    ///
    /// # Errors
    /// Returns an error if the provider cannot be queried.
    async fn get_available_models(&self) -> Result<Vec<String>>;

    /// Calculate embeddings for text
    ///
    /// `model` selects the embedding model; `None` presumably falls back to a
    /// provider-side default — confirm with implementations.
    ///
    /// # Errors
    /// Returns an error if the underlying provider call fails.
    async fn calculate_embeddings(&self, text: &str, model: Option<&str>) -> Result<Vec<f32>>;

    /// Stream LLM responses chunk by chunk
    ///
    /// Returns a stream that yields either content chunks or tool calls.
    /// Tool calls will be accumulated and yielded when complete.
    ///
    /// Not an `async fn`: it returns the stream itself rather than a single
    /// awaited value. The stream borrows every argument for `'a`, so all
    /// arguments must outlive the returned stream.
    fn complete_stream<'a>(
        &'a self,
        model: &'a str,
        messages: &'a [LlmMessage],
        tools: Option<&'a [Box<dyn LlmTool>]>,
        config: &'a CompletionConfig,
    ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>>;
}
95
/// Streaming response chunk
///
/// The item type yielded by [`LlmGateway::complete_stream`].
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// Content text chunk
    Content(String),
    /// Complete tool calls (accumulated from stream)
    ToolCalls(Vec<crate::llm::models::LlmToolCall>),
}
104
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_completion_config_default() {
        let config = CompletionConfig::default();

        assert_eq!(config.temperature, 1.0);
        assert_eq!(config.num_ctx, 32768);
        assert_eq!(config.max_tokens, 16384);
        assert_eq!(config.num_predict, None);
        assert_eq!(config.top_p, None);
        assert_eq!(config.top_k, None);
        assert!(config.response_format.is_none());
        assert!(config.reasoning_effort.is_none());
    }

    #[test]
    fn test_completion_config_custom() {
        let config = CompletionConfig {
            temperature: 0.5,
            num_ctx: 2048,
            max_tokens: 1024,
            num_predict: Some(100),
            top_p: Some(0.9),
            top_k: Some(40),
            response_format: Some(ResponseFormat::Text),
            reasoning_effort: None,
        };

        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.num_ctx, 2048);
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.num_predict, Some(100));
        assert_eq!(config.top_p, Some(0.9));
        assert_eq!(config.top_k, Some(40));
        assert!(matches!(config.response_format, Some(ResponseFormat::Text)));
    }

    #[test]
    fn test_completion_config_clone() {
        let config1 = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: Some(50),
            top_p: Some(0.95),
            top_k: Some(50),
            response_format: Some(ResponseFormat::JsonObject { schema: None }),
            reasoning_effort: None,
        };

        let config2 = config1.clone();

        assert_eq!(config1.temperature, config2.temperature);
        assert_eq!(config1.num_ctx, config2.num_ctx);
        assert_eq!(config1.max_tokens, config2.max_tokens);
        assert_eq!(config1.num_predict, config2.num_predict);
        assert_eq!(config1.top_p, config2.top_p);
        assert_eq!(config1.top_k, config2.top_k);
        // Previously unchecked: reasoning_effort is PartialEq, so verify it
        // survives the clone too.
        assert_eq!(config1.reasoning_effort, config2.reasoning_effort);
    }

    #[test]
    fn test_response_format_text() {
        let format = ResponseFormat::Text;
        assert!(matches!(format, ResponseFormat::Text));
    }

    #[test]
    fn test_response_format_json_no_schema() {
        let format = ResponseFormat::JsonObject { schema: None };
        assert!(matches!(format, ResponseFormat::JsonObject { schema: None }));
    }

    #[test]
    fn test_response_format_json_with_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let format = ResponseFormat::JsonObject {
            schema: Some(schema.clone()),
        };

        match format {
            ResponseFormat::JsonObject { schema: Some(s) } => {
                assert_eq!(s, schema);
            }
            _ => panic!("Expected JsonObject with schema"),
        }
    }

    #[test]
    fn test_completion_config_with_all_sampling_params() {
        let config = CompletionConfig {
            temperature: 0.8,
            num_ctx: 8192,
            max_tokens: 4096,
            num_predict: Some(2000),
            top_p: Some(0.92),
            top_k: Some(60),
            response_format: Some(ResponseFormat::JsonObject {
                schema: Some(serde_json::json!({"type": "object"})),
            }),
            reasoning_effort: None,
        };

        assert_eq!(config.temperature, 0.8);
        assert_eq!(config.top_p, Some(0.92));
        assert_eq!(config.top_k, Some(60));
        assert!(config.response_format.is_some());
    }

    #[test]
    fn test_reasoning_effort_serialization() {
        assert_eq!(serde_json::to_string(&ReasoningEffort::Low).unwrap(), "\"low\"");
        assert_eq!(serde_json::to_string(&ReasoningEffort::Medium).unwrap(), "\"medium\"");
        assert_eq!(serde_json::to_string(&ReasoningEffort::High).unwrap(), "\"high\"");
    }

    #[test]
    fn test_reasoning_effort_deserialization() {
        assert_eq!(
            serde_json::from_str::<ReasoningEffort>("\"low\"").unwrap(),
            ReasoningEffort::Low
        );
        assert_eq!(
            serde_json::from_str::<ReasoningEffort>("\"medium\"").unwrap(),
            ReasoningEffort::Medium
        );
        assert_eq!(
            serde_json::from_str::<ReasoningEffort>("\"high\"").unwrap(),
            ReasoningEffort::High
        );
    }

    #[test]
    fn test_completion_config_with_reasoning_effort() {
        // Struct-update syntax keeps this test in sync with the Default impl
        // instead of re-listing every default field by hand.
        let config = CompletionConfig {
            reasoning_effort: Some(ReasoningEffort::High),
            ..CompletionConfig::default()
        };

        assert_eq!(config.reasoning_effort, Some(ReasoningEffort::High));
    }
}