mojentic/llm/tools/
web_search_tool.rs1use crate::error::Result;
2use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};
3use scraper::{Html, Selector};
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7
/// Endpoint of DuckDuckGo's lite HTML interface; the search query is appended as `?q=`.
const BASE_URL: &str = "https://lite.duckduckgo.com/lite/";
/// Upper bound on how many results a single search returns.
const MAX_RESULTS: usize = 10;
/// Request timeout, in seconds, configured on the client built by `WebSearchTool::new`.
const TIMEOUT_SECONDS: u64 = 10;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct SearchResult {
15 pub title: String,
17 pub url: String,
19 pub snippet: String,
21}
22
23#[derive(Clone)]
42pub struct WebSearchTool {
43 client: reqwest::Client,
44}
45
46impl WebSearchTool {
47 pub fn new() -> Self {
49 let client = reqwest::Client::builder()
50 .timeout(std::time::Duration::from_secs(TIMEOUT_SECONDS))
51 .build()
52 .expect("Failed to create HTTP client");
53
54 Self { client }
55 }
56
57 #[cfg(test)]
59 pub fn with_client(client: reqwest::Client) -> Self {
60 Self { client }
61 }
62
63 async fn perform_search(&self, query: &str) -> Result<Vec<SearchResult>> {
65 let url = format!("{}?q={}", BASE_URL, urlencoding::encode(query));
66
67 let response = self
68 .client
69 .get(&url)
70 .send()
71 .await
72 .map_err(crate::error::MojenticError::HttpError)?;
73
74 if !response.status().is_success() {
75 return Err(crate::error::MojenticError::ApiError(format!(
76 "HTTP request failed with status {}",
77 response.status()
78 )));
79 }
80
81 let html = response.text().await.map_err(crate::error::MojenticError::HttpError)?;
82
83 self.parse_results(&html)
84 }
85
86 fn parse_results(&self, html: &str) -> Result<Vec<SearchResult>> {
88 let document = Html::parse_document(html);
89
90 let link_selector = Selector::parse("a.result-link").map_err(|e| {
92 crate::error::MojenticError::ParseError(format!("Invalid selector: {:?}", e))
93 })?;
94 let snippet_selector = Selector::parse("td.result-snippet").map_err(|e| {
95 crate::error::MojenticError::ParseError(format!("Invalid selector: {:?}", e))
96 })?;
97
98 let mut results = Vec::new();
99
100 let links: Vec<_> = document.select(&link_selector).collect();
102 let snippets: Vec<_> = document.select(&snippet_selector).collect();
103
104 for (i, link) in links.iter().take(MAX_RESULTS).enumerate() {
105 if let Some(href) = link.value().attr("href") {
106 let title = link.text().collect::<Vec<_>>().join(" ");
107 let url = Self::decode_url(href);
108
109 let snippet = snippets
110 .get(i)
111 .map(|s| Self::clean_text(&s.text().collect::<Vec<_>>().join(" ")))
112 .unwrap_or_default();
113
114 results.push(SearchResult {
115 title: Self::clean_text(&title),
116 url,
117 snippet,
118 });
119 }
120 }
121
122 Ok(results)
123 }
124
125 fn decode_url(url: &str) -> String {
127 if url.contains("uddg=") {
129 url.split("uddg=")
130 .nth(1)
131 .and_then(|s| s.split('&').next())
132 .map(|s| urlencoding::decode(s).unwrap_or_default().to_string())
133 .unwrap_or_else(|| url.to_string())
134 } else {
135 url.to_string()
136 }
137 }
138
139 fn clean_text(text: &str) -> String {
141 let text = text.trim().replace(|c: char| c.is_whitespace(), " ");
142 let text = text.split_whitespace().collect::<Vec<_>>().join(" ");
143
144 text.replace("&", "&")
146 .replace("<", "<")
147 .replace(">", ">")
148 .replace(""", "\"")
149 .replace("'", "'")
150 }
151}
152
153impl Default for WebSearchTool {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl LlmTool for WebSearchTool {
160 fn run(&self, args: &HashMap<String, Value>) -> Result<Value> {
161 let query = args.get("query").and_then(|v| v.as_str()).ok_or_else(|| {
162 crate::error::MojenticError::InvalidArgument("query parameter is required".to_string())
163 })?;
164
165 if query.is_empty() {
166 return Err(crate::error::MojenticError::InvalidArgument(
167 "query parameter cannot be empty".to_string(),
168 ));
169 }
170
171 let rt = tokio::runtime::Runtime::new().map_err(|e| {
173 crate::error::MojenticError::RuntimeError(format!("Failed to create runtime: {}", e))
174 })?;
175
176 let results = rt.block_on(self.perform_search(query)).map_err(|e| {
177 crate::error::MojenticError::ToolExecutionError(format!("Search failed: {}", e))
178 })?;
179
180 Ok(json!(results))
181 }
182
183 fn descriptor(&self) -> ToolDescriptor {
184 ToolDescriptor {
185 r#type: "function".to_string(),
186 function: FunctionDescriptor {
187 name: "web_search".to_string(),
188 description: "Search the web for information using DuckDuckGo. Returns organic search results including title, URL, and snippet for each result.".to_string(),
189 parameters: json!({
190 "type": "object",
191 "properties": {
192 "query": {
193 "type": "string",
194 "description": "The search query"
195 }
196 },
197 "required": ["query"]
198 }),
199 },
200 }
201 }
202
203 fn clone_box(&self) -> Box<dyn LlmTool> {
204 Box::new(self.clone())
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    /// A minimal DuckDuckGo-lite-shaped page with three results. Each result is an
    /// `a.result-link` row followed by a `td.result-snippet` row and uses a `uddg=`
    /// redirect href.
    fn sample_html() -> String {
        r#"
<!DOCTYPE html>
<html>
<body>
    <table>
        <tr>
            <td>
                <a class="result-link" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fwww.rust-lang.org%2F">The Rust Programming Language</a>
            </td>
        </tr>
        <tr>
            <td class="result-snippet">A language empowering everyone to build reliable and efficient software.</td>
        </tr>
        <tr>
            <td>
                <a class="result-link" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fdoc.rust-lang.org%2F">Rust Documentation</a>
            </td>
        </tr>
        <tr>
            <td class="result-snippet">The official Rust documentation and learning resources.</td>
        </tr>
        <tr>
            <td>
                <a class="result-link" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fcrates.io%2F">crates.io: Rust Package Registry</a>
            </td>
        </tr>
        <tr>
            <td class="result-snippet">The Rust community's crate registry.</td>
        </tr>
    </table>
</body>
</html>
"#
        .to_string()
    }

    #[test]
    fn test_descriptor() {
        let tool = WebSearchTool::new();
        let descriptor = tool.descriptor();

        assert_eq!(descriptor.r#type, "function");
        assert_eq!(descriptor.function.name, "web_search");
        assert!(descriptor
            .function
            .description
            .contains("Search the web for information using DuckDuckGo"));

        let params = descriptor.function.parameters;
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["query"].is_object());
        assert_eq!(params["required"][0], "query");
    }

    #[test]
    fn test_parse_results() {
        let tool = WebSearchTool::new();
        let html = sample_html();

        let results = tool.parse_results(&html).unwrap();

        assert_eq!(results.len(), 3);

        assert_eq!(results[0].title, "The Rust Programming Language");
        assert_eq!(results[0].url, "https://www.rust-lang.org/");
        assert!(results[0].snippet.contains("A language empowering everyone"));

        assert_eq!(results[1].title, "Rust Documentation");
        assert_eq!(results[1].url, "https://doc.rust-lang.org/");
        assert!(results[1].snippet.contains("official Rust documentation"));

        assert_eq!(results[2].title, "crates.io: Rust Package Registry");
        assert_eq!(results[2].url, "https://crates.io/");
        assert!(results[2].snippet.contains("crate registry"));
    }

    #[test]
    fn test_decode_url() {
        let url = "//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpath";
        let decoded = WebSearchTool::decode_url(url);
        assert_eq!(decoded, "https://example.com/path");

        // A direct link (no redirect wrapper) passes through untouched.
        let url = "https://example.com/direct";
        let decoded = WebSearchTool::decode_url(url);
        assert_eq!(decoded, "https://example.com/direct");
    }

    #[test]
    fn test_clean_text() {
        assert_eq!(WebSearchTool::clean_text("  Multiple   spaces  "), "Multiple spaces");
        assert_eq!(WebSearchTool::clean_text("Text&amp;more"), "Text&more");
        assert_eq!(WebSearchTool::clean_text("&lt;tag&gt;"), "<tag>");
        assert_eq!(WebSearchTool::clean_text("&quot;quoted&quot;"), "\"quoted\"");
        assert_eq!(WebSearchTool::clean_text("it&#x27;s"), "it's");
    }

    #[test]
    fn test_tool_matches() {
        let tool = WebSearchTool::new();
        assert!(tool.matches("web_search"));
        assert!(!tool.matches("other_tool"));
    }

    #[test]
    fn test_run_missing_query() {
        let tool = WebSearchTool::new();
        let args = HashMap::new();

        let result = tool.run(&args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("query parameter is required"));
    }

    #[test]
    fn test_run_empty_query() {
        let tool = WebSearchTool::new();
        let mut args = HashMap::new();
        args.insert("query".to_string(), json!(""));

        let result = tool.run(&args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("query parameter cannot be empty"));
    }

    #[tokio::test]
    async fn test_perform_search_success() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("GET", mockito::Matcher::Any)
            .with_status(200)
            .with_body(sample_html())
            .create_async()
            .await;

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap();

        let tool = WebSearchTool::with_client(client);

        // `perform_search` always targets the fixed `BASE_URL`, so this drives the
        // same client + parser path against the mock server directly.
        let url = format!("{}?q=rust", server.url());
        let response = tool.client.get(&url).send().await.unwrap();
        let html = response.text().await.unwrap();
        let results = tool.parse_results(&html).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].title, "The Rust Programming Language");

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_perform_search_http_error() {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", mockito::Matcher::Any).with_status(500).create_async().await;

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap();

        let tool = WebSearchTool::with_client(client);

        let url = format!("{}?q=test", server.url());
        let response = tool.client.get(&url).send().await.unwrap();

        assert_eq!(response.status(), 500);

        mock.assert_async().await;
    }

    #[test]
    fn test_max_results_limit() {
        let tool = WebSearchTool::new();

        // Build more results than the cap allows.
        let mut html = String::from("<html><body><table>");
        for i in 0..15 {
            html.push_str(&format!(
                r#"<tr><td><a class="result-link" href="https://example.com/{}">Result {}</a></td></tr>"#,
                i, i
            ));
            html.push_str(&format!(r#"<tr><td class="result-snippet">Snippet {}</td></tr>"#, i));
        }
        html.push_str("</table></body></html>");

        let results = tool.parse_results(&html).unwrap();

        assert_eq!(results.len(), MAX_RESULTS);
    }

    #[test]
    fn test_clone_box() {
        let tool = WebSearchTool::new();
        let cloned = tool.clone_box();

        assert_eq!(cloned.descriptor().function.name, tool.descriptor().function.name);
    }

    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult {
            title: "Test Title".to_string(),
            url: "https://example.com".to_string(),
            snippet: "Test snippet".to_string(),
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("Test Title"));
        assert!(json.contains("https://example.com"));
        assert!(json.contains("Test snippet"));

        let deserialized: SearchResult = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, result);
    }

    #[test]
    fn test_parse_empty_html() {
        let tool = WebSearchTool::new();
        let html = "<html><body></body></html>";

        let results = tool.parse_results(html).unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_parse_malformed_html() {
        let tool = WebSearchTool::new();
        let html = "<html><body><a class=\"result-link\">No href</a></body></html>";

        let results = tool.parse_results(html).unwrap();
        assert_eq!(results.len(), 0);
    }
}