1use crate::error::Result;
2use crate::llm::broker::LlmBroker;
3use crate::llm::models::{LlmMessage, MessageRole};
4use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9pub struct ToolWrapper {
16 broker: Arc<LlmBroker>,
17 tools: Vec<Box<dyn LlmTool>>,
18 behaviour: String,
19 name: String,
20 description: String,
21}
22
23impl ToolWrapper {
24 pub fn new(
33 broker: Arc<LlmBroker>,
34 tools: Vec<Box<dyn LlmTool>>,
35 behaviour: impl Into<String>,
36 name: impl Into<String>,
37 description: impl Into<String>,
38 ) -> Self {
39 Self {
40 broker,
41 tools,
42 behaviour: behaviour.into(),
43 name: name.into(),
44 description: description.into(),
45 }
46 }
47
48 fn create_initial_messages(&self) -> Vec<LlmMessage> {
50 vec![LlmMessage {
51 role: MessageRole::System,
52 content: Some(self.behaviour.clone()),
53 tool_calls: None,
54 image_paths: None,
55 }]
56 }
57}
58
59impl LlmTool for ToolWrapper {
60 fn run(&self, args: &HashMap<String, Value>) -> Result<Value> {
61 let input = args.get("input").and_then(|v| v.as_str()).ok_or_else(|| {
63 crate::error::MojenticError::ToolError("Missing 'input' parameter".to_string())
64 })?;
65
66 let mut messages = self.create_initial_messages();
68
69 messages.push(LlmMessage {
71 role: MessageRole::User,
72 content: Some(input.to_string()),
73 tool_calls: None,
74 image_paths: None,
75 });
76
77 let response = tokio::task::block_in_place(|| {
80 tokio::runtime::Handle::current().block_on(async {
81 self.broker.generate(&messages, Some(&self.tools), None, None).await
82 })
83 })?;
84
85 Ok(json!(response))
86 }
87
88 fn descriptor(&self) -> ToolDescriptor {
89 ToolDescriptor {
90 r#type: "function".to_string(),
91 function: FunctionDescriptor {
92 name: self.name.clone(),
93 description: self.description.clone(),
94 parameters: json!({
95 "type": "object",
96 "properties": {
97 "input": {
98 "type": "string",
99 "description": "Instructions for this agent."
100 }
101 },
102 "required": ["input"],
103 "additionalProperties": false
104 }),
105 },
106 }
107 }
108
109 fn clone_box(&self) -> Box<dyn LlmTool> {
110 Box::new(ToolWrapper {
111 broker: self.broker.clone(),
112 tools: self.tools.iter().map(|t| t.clone_box()).collect(),
113 behaviour: self.behaviour.clone(),
114 name: self.name.clone(),
115 description: self.description.clone(),
116 })
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
    use crate::llm::models::LlmGatewayResponse;
    use futures::stream::{self, Stream};
    use std::pin::Pin;

    /// Gateway double that verifies the system message it receives and answers
    /// every completion with a fixed piece of text.
    struct MockGateway {
        expected_behaviour: String,
        response: String,
    }

    impl MockGateway {
        fn new(expected_behaviour: String, response: String) -> Self {
            Self { expected_behaviour, response }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        /// Asserts that the conversation opens with the expected system
        /// message, then returns the canned response with no tool calls.
        async fn complete(
            &self,
            _model: &str,
            messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            assert!(messages.len() >= 2, "Expected at least 2 messages (system + user)");

            let system_message = &messages[0];
            assert_eq!(system_message.role, MessageRole::System, "First message should be system");
            assert_eq!(
                system_message.content.as_ref().unwrap(),
                &self.expected_behaviour,
                "System message should match behaviour"
            );

            let reply = LlmGatewayResponse {
                content: Some(self.response.clone()),
                object: None,
                tool_calls: vec![],
                thinking: None,
            };
            Ok(reply)
        }

        /// Stub: always yields an empty JSON object.
        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        /// Stub: reports no models.
        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec![])
        }

        /// Stub: returns an empty embedding.
        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![])
        }

        /// Stub: produces a stream with no chunks.
        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![]))
        }
    }

    /// Builds a broker for "test-model" backed by a `MockGateway` configured
    /// with the given expected system message and canned response.
    fn mock_broker(expected_behaviour: &str, response: &str) -> Arc<LlmBroker> {
        let gateway = Arc::new(MockGateway::new(
            expected_behaviour.to_string(),
            response.to_string(),
        ));
        Arc::new(LlmBroker::new("test-model", gateway, None))
    }

    /// Builds the argument map for a tool call whose `"input"` is `text`.
    fn input_args(text: &str) -> HashMap<String, Value> {
        std::iter::once(("input".to_string(), json!(text))).collect()
    }

    #[tokio::test]
    async fn test_tool_wrapper_descriptor() {
        let wrapper = ToolWrapper::new(
            mock_broker("You are a test agent", "test response"),
            Vec::<Box<dyn LlmTool>>::new(),
            "You are a test agent",
            "test_agent",
            "A test agent for unit testing",
        );

        let descriptor = wrapper.descriptor();
        assert_eq!(descriptor.r#type, "function");

        let function = descriptor.function;
        assert_eq!(function.name, "test_agent");
        assert_eq!(function.description, "A test agent for unit testing");

        let params = function.parameters;
        assert_eq!(params["type"], "object");

        let input_schema = &params["properties"]["input"];
        assert!(input_schema.is_object());
        assert_eq!(input_schema["type"], "string");
        assert_eq!(input_schema["description"], "Instructions for this agent.");

        assert_eq!(params["required"], json!(["input"]));
        assert_eq!(params["additionalProperties"], false);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tool_wrapper_execution() {
        let wrapper = ToolWrapper::new(
            mock_broker("You are a helpful assistant", "I can help with that!"),
            Vec::<Box<dyn LlmTool>>::new(),
            "You are a helpful assistant",
            "assistant",
            "A helpful assistant",
        );

        let args = input_args("Help me with something");
        let result = wrapper.run(&args).unwrap();

        assert_eq!(result, json!("I can help with that!"));
    }

    #[tokio::test]
    async fn test_tool_wrapper_missing_input() {
        let wrapper = ToolWrapper::new(
            mock_broker("You are a test agent", "test"),
            Vec::<Box<dyn LlmTool>>::new(),
            "You are a test agent",
            "test_agent",
            "A test agent",
        );

        // No "input" key at all: the wrapper must fail before contacting the broker.
        let result = wrapper.run(&HashMap::new());

        assert!(result.is_err());
        if let Err(crate::error::MojenticError::ToolError(message)) = result {
            assert_eq!(message, "Missing 'input' parameter");
        } else {
            panic!("Expected ToolError");
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tool_wrapper_with_tools() {
        /// Minimal tool that always reports a fixed result.
        struct MockTool;

        impl LlmTool for MockTool {
            fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
                Ok(json!({"result": "tool executed"}))
            }

            fn descriptor(&self) -> ToolDescriptor {
                let function = FunctionDescriptor {
                    name: "mock_tool".to_string(),
                    description: "A mock tool".to_string(),
                    parameters: json!({}),
                };
                ToolDescriptor { r#type: "function".to_string(), function }
            }

            fn clone_box(&self) -> Box<dyn LlmTool> {
                Box::new(MockTool)
            }
        }

        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool)];
        let wrapper = ToolWrapper::new(
            mock_broker("You are an agent with tools", "Task completed using tools"),
            tools,
            "You are an agent with tools",
            "tool_agent",
            "An agent that has access to tools",
        );

        let args = input_args("Use your tools");
        let result = wrapper.run(&args).unwrap();

        assert_eq!(result, json!("Task completed using tools"));
    }

    #[tokio::test]
    async fn test_tool_wrapper_matches() {
        let wrapper = ToolWrapper::new(
            mock_broker("test", "test"),
            Vec::<Box<dyn LlmTool>>::new(),
            "You are a test agent",
            "my_agent",
            "A test agent",
        );

        assert!(wrapper.matches("my_agent"));
        assert!(!wrapper.matches("other_agent"));
    }
}