mojentic/llm/tools/
ask_user_tool.rs

1use crate::error::Result;
2use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};
3use serde_json::{json, Value};
4use std::collections::HashMap;
5use std::io::{self, Write};
6
/// Tool for prompting the user for input or assistance.
///
/// This tool allows the LLM to ask the user questions or request help when it
/// doesn't have enough information to proceed. The user's response is returned
/// as the tool's result.
///
/// The type carries no state; it derives `Debug` so it can be logged or embedded
/// in other `Debug` types, and `Clone` so it can be duplicated freely.
///
/// # Examples
///
/// ```
/// use mojentic::llm::tools::ask_user_tool::AskUserTool;
/// use mojentic::llm::tools::LlmTool;
/// use std::collections::HashMap;
/// use serde_json::json;
///
/// let tool = AskUserTool::new();
/// let mut args = HashMap::new();
/// args.insert("user_request".to_string(), json!("What is your favorite color?"));
///
/// // This would prompt the user for input
/// // let result = tool.run(&args).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct AskUserTool;
30
31impl AskUserTool {
32    /// Creates a new AskUserTool instance
33    pub fn new() -> Self {
34        Self
35    }
36
37    /// Prompt the user with a message and read their response
38    fn prompt_user(&self, request: &str) -> Result<String> {
39        println!("\n\n\nI NEED YOUR HELP!\n{}", request);
40        print!("Your response: ");
41        io::stdout().flush().map_err(crate::error::MojenticError::IoError)?;
42
43        let mut response = String::new();
44        io::stdin()
45            .read_line(&mut response)
46            .map_err(crate::error::MojenticError::IoError)?;
47
48        Ok(response.trim().to_string())
49    }
50}
51
52impl Default for AskUserTool {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl LlmTool for AskUserTool {
59    fn run(&self, args: &HashMap<String, Value>) -> Result<Value> {
60        let user_request = args.get("user_request").and_then(|v| v.as_str()).ok_or_else(|| {
61            crate::error::MojenticError::ToolError(
62                "Missing required argument: user_request".to_string(),
63            )
64        })?;
65
66        let response = self.prompt_user(user_request)?;
67        Ok(json!(response))
68    }
69
70    fn descriptor(&self) -> ToolDescriptor {
71        ToolDescriptor {
72            r#type: "function".to_string(),
73            function: FunctionDescriptor {
74                name: "ask_user".to_string(),
75                description: "If you do not know how to proceed, ask the user a question, or ask them for help or to do something for you.".to_string(),
76                parameters: json!({
77                    "type": "object",
78                    "properties": {
79                        "user_request": {
80                            "type": "string",
81                            "description": "The question you need the user to answer, or the task you need the user to do for you."
82                        }
83                    },
84                    "required": ["user_request"]
85                }),
86            },
87        }
88    }
89
90    fn clone_box(&self) -> Box<dyn LlmTool> {
91        Box::new(self.clone())
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// The descriptor advertises the `ask_user` function and its parameter schema.
    #[test]
    fn test_descriptor() {
        let tool = AskUserTool::new();
        let descriptor = tool.descriptor();

        assert_eq!(descriptor.r#type, "function");
        assert_eq!(descriptor.function.name, "ask_user");
        assert!(descriptor.function.description.contains("ask the user a question"));

        // Verify parameters structure
        let params = &descriptor.function.parameters;
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["user_request"].is_object());
        assert_eq!(params["properties"]["user_request"]["type"], "string");
        assert!(params["required"].as_array().unwrap().contains(&json!("user_request")));
    }

    /// `matches` accepts the tool's own name and rejects any other.
    #[test]
    fn test_tool_matches() {
        let tool = AskUserTool::new();
        assert!(tool.matches("ask_user"));
        assert!(!tool.matches("other_tool"));
    }

    /// Running without `user_request` fails with a `ToolError`.
    #[test]
    fn test_missing_argument() {
        let tool = AskUserTool::new();
        let args = HashMap::new();

        let result = tool.run(&args);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), crate::error::MojenticError::ToolError(_)));
    }

    /// A `user_request` that is not a JSON string is rejected with a `ToolError`
    /// before the tool tries to prompt on stdin.
    #[test]
    fn test_non_string_argument() {
        let tool = AskUserTool::new();
        let mut args = HashMap::new();
        args.insert("user_request".to_string(), json!(42));

        let result = tool.run(&args);
        assert!(matches!(result, Err(crate::error::MojenticError::ToolError(_))));
    }

    /// The `Default` implementation yields a working tool.
    #[test]
    fn test_default_implementation() {
        // Go through `Default::default()` so the trait impl itself is exercised,
        // rather than constructing the unit struct directly.
        let tool = AskUserTool::default();
        let descriptor = tool.descriptor();

        assert_eq!(descriptor.function.name, "ask_user");
    }
}