use crate::error::Result;
use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::io::{self, BufRead, Write};

/// Tool that lets an LLM ask the human user a question or ask them for help.
///
/// When invoked, the request is printed to standard output and the tool waits for the
/// user to type one line on standard input. That line is handed back to the model as
/// the tool result. Because it blocks on the terminal, it is only useful in
/// interactive sessions.
///
/// The tool carries no state, so constructing or cloning it is trivial.
#[derive(Clone, Debug)]
pub struct AskUserTool;

31impl AskUserTool {
32 pub fn new() -> Self {
34 Self
35 }
36
37 fn prompt_user(&self, request: &str) -> Result<String> {
39 println!("\n\n\nI NEED YOUR HELP!\n{}", request);
40 print!("Your response: ");
41 io::stdout().flush().map_err(crate::error::MojenticError::IoError)?;
42
43 let mut response = String::new();
44 io::stdin()
45 .read_line(&mut response)
46 .map_err(crate::error::MojenticError::IoError)?;
47
48 Ok(response.trim().to_string())
49 }
50}
51
52impl Default for AskUserTool {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl LlmTool for AskUserTool {
59 fn run(&self, args: &HashMap<String, Value>) -> Result<Value> {
60 let user_request = args.get("user_request").and_then(|v| v.as_str()).ok_or_else(|| {
61 crate::error::MojenticError::ToolError(
62 "Missing required argument: user_request".to_string(),
63 )
64 })?;
65
66 let response = self.prompt_user(user_request)?;
67 Ok(json!(response))
68 }
69
70 fn descriptor(&self) -> ToolDescriptor {
71 ToolDescriptor {
72 r#type: "function".to_string(),
73 function: FunctionDescriptor {
74 name: "ask_user".to_string(),
75 description: "If you do not know how to proceed, ask the user a question, or ask them for help or to do something for you.".to_string(),
76 parameters: json!({
77 "type": "object",
78 "properties": {
79 "user_request": {
80 "type": "string",
81 "description": "The question you need the user to answer, or the task you need the user to do for you."
82 }
83 },
84 "required": ["user_request"]
85 }),
86 },
87 }
88 }
89
90 fn clone_box(&self) -> Box<dyn LlmTool> {
91 Box::new(self.clone())
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an argument map containing only `user_request` set to `value`.
    fn args_with_request(value: Value) -> HashMap<String, Value> {
        let mut args = HashMap::new();
        args.insert("user_request".to_string(), value);
        args
    }

    #[test]
    fn test_descriptor() {
        let tool = AskUserTool::new();
        let descriptor = tool.descriptor();

        assert_eq!(descriptor.r#type, "function");
        assert_eq!(descriptor.function.name, "ask_user");
        assert!(descriptor.function.description.contains("ask the user a question"));

        let params = &descriptor.function.parameters;
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["user_request"].is_object());
        assert_eq!(params["properties"]["user_request"]["type"], "string");
        assert!(params["required"].as_array().unwrap().contains(&json!("user_request")));
    }

    #[test]
    fn test_tool_matches() {
        let tool = AskUserTool::new();
        assert!(tool.matches("ask_user"));
        assert!(!tool.matches("other_tool"));
    }

    #[test]
    fn test_missing_argument() {
        let tool = AskUserTool::new();
        let args = HashMap::new();

        let result = tool.run(&args);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), crate::error::MojenticError::ToolError(_)));
    }

    // These argument-validation tests must fail before any stdin read, or they would hang.
    #[test]
    fn test_null_argument_is_rejected() {
        let result = AskUserTool::new().run(&args_with_request(Value::Null));
        assert!(matches!(result, Err(crate::error::MojenticError::ToolError(_))));
    }

    #[test]
    fn test_non_string_argument_is_rejected() {
        let result = AskUserTool::new().run(&args_with_request(json!(42)));
        assert!(matches!(result, Err(crate::error::MojenticError::ToolError(_))));
    }

    #[test]
    fn test_default_implementation() {
        // Exercise the `Default` impl itself rather than the unit-struct literal.
        let tool = AskUserTool::default();
        let descriptor = tool.descriptor();

        assert_eq!(descriptor.function.name, "ask_user");
    }

    #[test]
    fn test_clone_box_preserves_descriptor() {
        let boxed = AskUserTool::new().clone_box();
        assert_eq!(boxed.descriptor().function.name, "ask_user");
    }
}