mojentic/llm/tools/
tool.rs

1use crate::error::Result;
2use serde_json::Value;
3use std::collections::HashMap;
4
5/// Descriptor for tool function parameters
6#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
7pub struct ToolDescriptor {
8    pub r#type: String,
9    pub function: FunctionDescriptor,
10}
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct FunctionDescriptor {
14    pub name: String,
15    pub description: String,
16    pub parameters: Value,
17}
18
19/// Trait for LLM tools
20pub trait LlmTool: Send + Sync {
21    /// Execute the tool with given arguments
22    fn run(&self, args: &HashMap<String, Value>) -> Result<Value>;
23
24    /// Get tool descriptor for LLM
25    fn descriptor(&self) -> ToolDescriptor;
26
27    /// Check if this tool matches the given name
28    fn matches(&self, name: &str) -> bool {
29        self.descriptor().function.name == name
30    }
31
32    /// Clone the tool into a Box
33    ///
34    /// This method is required to support cloning trait objects.
35    /// Implementations should return `Box::new(self.clone())`.
36    fn clone_box(&self) -> Box<dyn LlmTool>;
37}
38
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_descriptor_serialization() {
        let descriptor = ToolDescriptor {
            r#type: "function".to_string(),
            function: FunctionDescriptor {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "arg1": {"type": "string"}
                    }
                }),
            },
        };

        // Round-trip through a string and compare the full structure, so the
        // check also pins the `r#type` -> "type" key mapping and the nesting
        // (a substring check on "function" would always pass).
        let json = serde_json::to_string(&descriptor).unwrap();
        let value: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "function",
                "function": {
                    "name": "test_tool",
                    "description": "A test tool",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "arg1": {"type": "string"}
                        }
                    }
                }
            })
        );
    }

    #[test]
    fn test_tool_descriptor_deserialization() {
        let json = r#"{
            "type": "function",
            "function": {
                "name": "calculator",
                "description": "Perform calculations",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "expression": {"type": "string"}
                    }
                }
            }
        }"#;

        let descriptor: ToolDescriptor = serde_json::from_str(json).unwrap();
        assert_eq!(descriptor.r#type, "function");
        assert_eq!(descriptor.function.name, "calculator");
        assert_eq!(descriptor.function.description, "Perform calculations");
        assert_eq!(
            descriptor.function.parameters,
            json!({
                "type": "object",
                "properties": {
                    "expression": {"type": "string"}
                }
            })
        );
    }

    #[test]
    fn test_function_descriptor_clone() {
        let desc1 = FunctionDescriptor {
            name: "test".to_string(),
            description: "desc".to_string(),
            parameters: json!({"type": "object"}),
        };

        let desc2 = desc1.clone();
        assert_eq!(desc1.name, desc2.name);
        assert_eq!(desc1.description, desc2.description);
        assert_eq!(desc1.parameters, desc2.parameters);
    }

    /// Minimal tool used to exercise the `LlmTool` trait and its defaults.
    struct MockTool;

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(json!("result"))
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: "mock_tool".to_string(),
                    description: "A mock tool".to_string(),
                    parameters: json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(MockTool)
        }
    }

    #[test]
    fn test_tool_matches() {
        let tool = MockTool;
        assert!(tool.matches("mock_tool"));
        assert!(!tool.matches("other_tool"));
        // The default implementation is an exact, case-sensitive comparison.
        assert!(!tool.matches("Mock_Tool"));
        assert!(!tool.matches(""));
    }

    #[test]
    fn test_tool_run() {
        let tool = MockTool;
        let args = HashMap::new();
        let result = tool.run(&args).unwrap();
        assert_eq!(result, json!("result"));
    }

    #[test]
    fn test_tool_descriptor_from_trait() {
        let descriptor = MockTool.descriptor();
        assert_eq!(descriptor.r#type, "function");
        assert_eq!(descriptor.function.name, "mock_tool");
        assert_eq!(descriptor.function.description, "A mock tool");
        assert_eq!(descriptor.function.parameters, json!({}));
    }

    #[test]
    fn test_clone_box_preserves_behavior() {
        let boxed: Box<dyn LlmTool> = MockTool.clone_box();
        assert!(boxed.matches("mock_tool"));
        assert_eq!(boxed.run(&HashMap::new()).unwrap(), json!("result"));
    }
}
138}