mojentic/llm/tools/tool_wrapper.rs

1use crate::error::Result;
2use crate::llm::broker::LlmBroker;
3use crate::llm::models::{LlmMessage, MessageRole};
4use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9/// Wraps an agent (broker + tools + behaviour) as an LlmTool
10///
11/// This allows agents to be used as tools by other agents (delegation pattern).
12/// The tool's descriptor has a single "input" parameter (string).
13/// When run, it creates initial messages from the agent's behaviour, appends the input,
14/// and calls the agent's broker.
15pub struct ToolWrapper {
16    broker: Arc<LlmBroker>,
17    tools: Vec<Box<dyn LlmTool>>,
18    behaviour: String,
19    name: String,
20    description: String,
21}
22
23impl ToolWrapper {
24    /// Create a new ToolWrapper
25    ///
26    /// # Arguments
27    /// * `broker` - The LLM broker for this agent
28    /// * `tools` - The tools available to this agent
29    /// * `behaviour` - The system message defining the agent's behaviour
30    /// * `name` - The name of this tool (how other agents will call it)
31    /// * `description` - Description of what this agent/tool does
32    pub fn new(
33        broker: Arc<LlmBroker>,
34        tools: Vec<Box<dyn LlmTool>>,
35        behaviour: impl Into<String>,
36        name: impl Into<String>,
37        description: impl Into<String>,
38    ) -> Self {
39        Self {
40            broker,
41            tools,
42            behaviour: behaviour.into(),
43            name: name.into(),
44            description: description.into(),
45        }
46    }
47
48    /// Create initial messages with the agent's behaviour
49    fn create_initial_messages(&self) -> Vec<LlmMessage> {
50        vec![LlmMessage {
51            role: MessageRole::System,
52            content: Some(self.behaviour.clone()),
53            tool_calls: None,
54            image_paths: None,
55        }]
56    }
57}
58
59impl LlmTool for ToolWrapper {
60    fn run(&self, args: &HashMap<String, Value>) -> Result<Value> {
61        // Extract input from arguments
62        let input = args.get("input").and_then(|v| v.as_str()).ok_or_else(|| {
63            crate::error::MojenticError::ToolError("Missing 'input' parameter".to_string())
64        })?;
65
66        // Create initial messages with behaviour
67        let mut messages = self.create_initial_messages();
68
69        // Append the user input
70        messages.push(LlmMessage {
71            role: MessageRole::User,
72            content: Some(input.to_string()),
73            tool_calls: None,
74            image_paths: None,
75        });
76
77        // Call the broker with the messages and tools
78        // We need to handle the async call in a way that works with the sync trait
79        let response = tokio::task::block_in_place(|| {
80            tokio::runtime::Handle::current().block_on(async {
81                self.broker.generate(&messages, Some(&self.tools), None, None).await
82            })
83        })?;
84
85        Ok(json!(response))
86    }
87
88    fn descriptor(&self) -> ToolDescriptor {
89        ToolDescriptor {
90            r#type: "function".to_string(),
91            function: FunctionDescriptor {
92                name: self.name.clone(),
93                description: self.description.clone(),
94                parameters: json!({
95                    "type": "object",
96                    "properties": {
97                        "input": {
98                            "type": "string",
99                            "description": "Instructions for this agent."
100                        }
101                    },
102                    "required": ["input"],
103                    "additionalProperties": false
104                }),
105            },
106        }
107    }
108
109    fn clone_box(&self) -> Box<dyn LlmTool> {
110        Box::new(ToolWrapper {
111            broker: self.broker.clone(),
112            tools: self.tools.iter().map(|t| t.clone_box()).collect(),
113            behaviour: self.behaviour.clone(),
114            name: self.name.clone(),
115            description: self.description.clone(),
116        })
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
    use crate::llm::models::LlmGatewayResponse;
    use futures::stream::{self, Stream};
    use std::pin::Pin;

    /// Gateway stub. It checks that each conversation opens with the expected system
    /// prompt and always answers with a canned reply.
    struct MockGateway {
        expected_behaviour: String,
        response: String,
    }

    impl MockGateway {
        fn new(expected_behaviour: String, response: String) -> Self {
            MockGateway {
                expected_behaviour,
                response,
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            // The wrapper must send the behaviour as the opening system message.
            assert!(messages.len() >= 2, "Expected at least 2 messages (system + user)");
            let opening = &messages[0];
            assert_eq!(opening.role, MessageRole::System, "First message should be system");
            assert_eq!(
                opening.content.as_ref().unwrap(),
                &self.expected_behaviour,
                "System message should match behaviour"
            );

            let reply = self.response.clone();
            Ok(LlmGatewayResponse {
                content: Some(reply),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(Vec::new())
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(Vec::new())
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![]))
        }
    }

    /// Builds a `test-model` broker backed by a `MockGateway`.
    fn mock_broker(expected_behaviour: &str, response: &str) -> Arc<LlmBroker> {
        let gateway = Arc::new(MockGateway::new(
            expected_behaviour.to_string(),
            response.to_string(),
        ));
        Arc::new(LlmBroker::new("test-model", gateway, None))
    }

    #[tokio::test]
    async fn test_tool_wrapper_descriptor() {
        let broker = mock_broker("You are a test agent", "test response");
        let wrapper = ToolWrapper::new(
            broker,
            Vec::new(),
            "You are a test agent",
            "test_agent",
            "A test agent for unit testing",
        );

        let descriptor = wrapper.descriptor();
        assert_eq!(descriptor.r#type, "function");
        assert_eq!(descriptor.function.name, "test_agent");
        assert_eq!(descriptor.function.description, "A test agent for unit testing");

        let params = descriptor.function.parameters;
        let input_schema = &params["properties"]["input"];
        assert_eq!(params["type"], "object");
        assert!(input_schema.is_object());
        assert_eq!(input_schema["type"], "string");
        assert_eq!(input_schema["description"], "Instructions for this agent.");
        assert_eq!(params["required"], json!(["input"]));
        assert_eq!(params["additionalProperties"], false);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tool_wrapper_execution() {
        let broker = mock_broker("You are a helpful assistant", "I can help with that!");
        let wrapper = ToolWrapper::new(
            broker,
            Vec::new(),
            "You are a helpful assistant",
            "assistant",
            "A helpful assistant",
        );

        let args: HashMap<String, Value> =
            [("input".to_string(), json!("Help me with something"))].into_iter().collect();

        // The broker's text reply comes back as a JSON string value.
        let outcome = wrapper.run(&args).unwrap();
        assert_eq!(outcome, json!("I can help with that!"));
    }

    #[tokio::test]
    async fn test_tool_wrapper_missing_input() {
        let broker = mock_broker("You are a test agent", "test");
        let wrapper =
            ToolWrapper::new(broker, Vec::new(), "You are a test agent", "test_agent", "A test agent");

        let outcome = wrapper.run(&HashMap::new());

        assert!(outcome.is_err());
        if let Err(crate::error::MojenticError::ToolError(message)) = outcome {
            assert_eq!(message, "Missing 'input' parameter");
        } else {
            panic!("Expected ToolError");
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_tool_wrapper_with_tools() {
        /// Trivial tool handed to the wrapped agent.
        struct MockTool;

        impl LlmTool for MockTool {
            fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
                Ok(json!({"result": "tool executed"}))
            }

            fn descriptor(&self) -> ToolDescriptor {
                let function = FunctionDescriptor {
                    name: "mock_tool".to_string(),
                    description: "A mock tool".to_string(),
                    parameters: json!({}),
                };
                ToolDescriptor {
                    r#type: "function".to_string(),
                    function,
                }
            }

            fn clone_box(&self) -> Box<dyn LlmTool> {
                Box::new(MockTool)
            }
        }

        let broker = mock_broker("You are an agent with tools", "Task completed using tools");
        let agent_tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool)];
        let wrapper = ToolWrapper::new(
            broker,
            agent_tools,
            "You are an agent with tools",
            "tool_agent",
            "An agent that has access to tools",
        );

        let args: HashMap<String, Value> =
            [("input".to_string(), json!("Use your tools"))].into_iter().collect();

        let outcome = wrapper.run(&args).unwrap();
        assert_eq!(outcome, json!("Task completed using tools"));
    }

    #[tokio::test]
    async fn test_tool_wrapper_matches() {
        let broker = mock_broker("test", "test");
        let wrapper =
            ToolWrapper::new(broker, Vec::new(), "You are a test agent", "my_agent", "A test agent");

        assert!(wrapper.matches("my_agent"));
        assert!(!wrapper.matches("other_agent"));
    }
}