1use crate::error::Result;
2use crate::llm::models::{LlmGatewayResponse, LlmMessage};
3use crate::llm::tools::LlmTool;
4use async_trait::async_trait;
5use futures::stream::Stream;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::pin::Pin;
9
10#[derive(Debug, Clone)]
12pub enum ResponseFormat {
13 Text,
15 JsonObject { schema: Option<Value> },
17}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
21#[serde(rename_all = "lowercase")]
22pub enum ReasoningEffort {
23 Low,
24 Medium,
25 High,
26}
27
28#[derive(Debug, Clone)]
30pub struct CompletionConfig {
31 pub temperature: f32,
32 pub num_ctx: usize,
33 pub max_tokens: usize,
34 pub num_predict: Option<i32>,
35 pub top_p: Option<f32>,
36 pub top_k: Option<u32>,
37 pub response_format: Option<ResponseFormat>,
38 pub reasoning_effort: Option<ReasoningEffort>,
39}
40
41impl Default for CompletionConfig {
42 fn default() -> Self {
43 Self {
44 temperature: 1.0,
45 num_ctx: 32768,
46 max_tokens: 16384,
47 num_predict: None,
48 top_p: None,
49 top_k: None,
50 response_format: None,
51 reasoning_effort: None,
52 }
53 }
54}
55
56#[async_trait]
58pub trait LlmGateway: Send + Sync {
59 async fn complete(
61 &self,
62 model: &str,
63 messages: &[LlmMessage],
64 tools: Option<&[Box<dyn LlmTool>]>,
65 config: &CompletionConfig,
66 ) -> Result<LlmGatewayResponse>;
67
68 async fn complete_json(
70 &self,
71 model: &str,
72 messages: &[LlmMessage],
73 schema: Value,
74 config: &CompletionConfig,
75 ) -> Result<Value>;
76
77 async fn get_available_models(&self) -> Result<Vec<String>>;
79
80 async fn calculate_embeddings(&self, text: &str, model: Option<&str>) -> Result<Vec<f32>>;
82
83 fn complete_stream<'a>(
88 &'a self,
89 model: &'a str,
90 messages: &'a [LlmMessage],
91 tools: Option<&'a [Box<dyn LlmTool>]>,
92 config: &'a CompletionConfig,
93 ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>>;
94}
95
96#[derive(Debug, Clone)]
98pub enum StreamChunk {
99 Content(String),
101 ToolCalls(Vec<crate::llm::models::LlmToolCall>),
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must produce the baseline values.
    ///
    /// The struct is destructured exhaustively on purpose: adding a field to
    /// `CompletionConfig` breaks this test until its default is asserted too.
    #[test]
    fn test_completion_config_default() {
        let CompletionConfig {
            temperature,
            num_ctx,
            max_tokens,
            num_predict,
            top_p,
            top_k,
            response_format,
            reasoning_effort,
        } = CompletionConfig::default();

        assert_eq!(temperature, 1.0);
        assert_eq!(num_ctx, 32768);
        assert_eq!(max_tokens, 16384);
        assert_eq!(num_predict, None);
        assert_eq!(top_p, None);
        assert_eq!(top_k, None);
        assert!(response_format.is_none());
        assert!(reasoning_effort.is_none());
    }

    /// Every explicitly supplied field must be stored as given.
    #[test]
    fn test_completion_config_custom() {
        let config = CompletionConfig {
            temperature: 0.5,
            num_ctx: 2048,
            max_tokens: 1024,
            num_predict: Some(100),
            top_p: Some(0.9),
            top_k: Some(40),
            response_format: Some(ResponseFormat::Text),
            reasoning_effort: None,
        };

        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.num_ctx, 2048);
        assert_eq!(config.max_tokens, 1024);
        assert_eq!(config.num_predict, Some(100));
        assert_eq!(config.top_p, Some(0.9));
        assert_eq!(config.top_k, Some(40));
        assert!(matches!(config.response_format, Some(ResponseFormat::Text)));
        assert_eq!(config.reasoning_effort, None);
    }

    /// Cloning must carry over every field, including the optional
    /// `response_format` and `reasoning_effort`.
    #[test]
    fn test_completion_config_clone() {
        let original = CompletionConfig {
            temperature: 0.7,
            num_ctx: 4096,
            max_tokens: 2048,
            num_predict: Some(50),
            top_p: Some(0.95),
            top_k: Some(50),
            response_format: Some(ResponseFormat::JsonObject { schema: None }),
            reasoning_effort: Some(ReasoningEffort::Medium),
        };

        let cloned = original.clone();

        assert_eq!(original.temperature, cloned.temperature);
        assert_eq!(original.num_ctx, cloned.num_ctx);
        assert_eq!(original.max_tokens, cloned.max_tokens);
        assert_eq!(original.num_predict, cloned.num_predict);
        assert_eq!(original.top_p, cloned.top_p);
        assert_eq!(original.top_k, cloned.top_k);
        assert!(matches!(
            cloned.response_format,
            Some(ResponseFormat::JsonObject { schema: None })
        ));
        assert_eq!(original.reasoning_effort, cloned.reasoning_effort);
    }

    #[test]
    fn test_response_format_text() {
        let format = ResponseFormat::Text;
        assert!(matches!(format, ResponseFormat::Text));
    }

    #[test]
    fn test_response_format_json_no_schema() {
        let format = ResponseFormat::JsonObject { schema: None };
        assert!(matches!(format, ResponseFormat::JsonObject { schema: None }));
    }

    /// A schema handed to `JsonObject` must be stored unchanged.
    #[test]
    fn test_response_format_json_with_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });
        let format = ResponseFormat::JsonObject {
            schema: Some(schema.clone()),
        };

        match format {
            ResponseFormat::JsonObject { schema: Some(stored) } => {
                assert_eq!(stored, schema);
            }
            // Report the offending value instead of silently swallowing it.
            other => panic!("Expected JsonObject with schema, got {:?}", other),
        }
    }

    #[test]
    fn test_completion_config_with_all_sampling_params() {
        let config = CompletionConfig {
            temperature: 0.8,
            num_ctx: 8192,
            max_tokens: 4096,
            num_predict: Some(2000),
            top_p: Some(0.92),
            top_k: Some(60),
            response_format: Some(ResponseFormat::JsonObject {
                schema: Some(serde_json::json!({"type": "object"})),
            }),
            ..CompletionConfig::default()
        };

        assert_eq!(config.temperature, 0.8);
        assert_eq!(config.top_p, Some(0.92));
        assert_eq!(config.top_k, Some(60));
        assert!(config.response_format.is_some());
    }

    /// Each variant serializes to its lowercase name.
    #[test]
    fn test_reasoning_effort_serialization() {
        let cases = [
            (ReasoningEffort::Low, "\"low\""),
            (ReasoningEffort::Medium, "\"medium\""),
            (ReasoningEffort::High, "\"high\""),
        ];

        for (effort, expected) in &cases {
            assert_eq!(serde_json::to_string(effort).unwrap(), *expected);
        }
    }

    /// Each lowercase name deserializes to its variant; unknown names fail.
    #[test]
    fn test_reasoning_effort_deserialization() {
        let cases = [
            ("\"low\"", ReasoningEffort::Low),
            ("\"medium\"", ReasoningEffort::Medium),
            ("\"high\"", ReasoningEffort::High),
        ];

        for (json, expected) in &cases {
            assert_eq!(
                serde_json::from_str::<ReasoningEffort>(*json).unwrap(),
                *expected
            );
        }

        assert!(serde_json::from_str::<ReasoningEffort>("\"extreme\"").is_err());
    }

    #[test]
    fn test_completion_config_with_reasoning_effort() {
        // Only the field under test is overridden; the rest stay at defaults.
        let config = CompletionConfig {
            reasoning_effort: Some(ReasoningEffort::High),
            ..CompletionConfig::default()
        };

        assert_eq!(config.reasoning_effort, Some(ReasoningEffort::High));
    }
}