mojentic/llm/tools/web_search_tool.rs

1use crate::error::Result;
2use crate::llm::tools::{FunctionDescriptor, LlmTool, ToolDescriptor};
3use scraper::{Html, Selector};
4use serde::{Deserialize, Serialize};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7
/// Per-request HTTP timeout applied by the default client, in seconds.
const TIMEOUT_SECONDS: u64 = 10;
/// Upper bound on the number of results extracted from a single results page.
const MAX_RESULTS: usize = 10;
/// DuckDuckGo's lightweight HTML search endpoint; the query is appended as `?q=`.
const BASE_URL: &str = "https://lite.duckduckgo.com/lite/";
11
12/// A web search result from DuckDuckGo
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct SearchResult {
15    /// The title of the search result
16    pub title: String,
17    /// The URL of the search result
18    pub url: String,
19    /// A snippet/description of the search result
20    pub snippet: String,
21}
22
23/// Tool for searching the web using DuckDuckGo
24///
25/// This tool searches DuckDuckGo's lite endpoint and returns organic search results.
26/// It does not require an API key, making it a free alternative to paid search APIs.
27///
28/// # Examples
29///
30/// ```ignore
31/// use mojentic::llm::tools::web_search_tool::WebSearchTool;
32/// use std::collections::HashMap;
33///
34/// let tool = WebSearchTool::new();
35/// let mut args = HashMap::new();
36/// args.insert("query".to_string(), serde_json::json!("Rust programming"));
37///
38/// let results = tool.run(&args)?;
39/// // results contains an array of search results with title, url, and snippet
40/// ```
41#[derive(Clone)]
42pub struct WebSearchTool {
43    client: reqwest::Client,
44}
45
46impl WebSearchTool {
47    /// Creates a new WebSearchTool instance
48    pub fn new() -> Self {
49        let client = reqwest::Client::builder()
50            .timeout(std::time::Duration::from_secs(TIMEOUT_SECONDS))
51            .build()
52            .expect("Failed to create HTTP client");
53
54        Self { client }
55    }
56
57    /// Creates a new WebSearchTool with a custom HTTP client (for testing)
58    #[cfg(test)]
59    pub fn with_client(client: reqwest::Client) -> Self {
60        Self { client }
61    }
62
63    /// Perform the web search
64    async fn perform_search(&self, query: &str) -> Result<Vec<SearchResult>> {
65        let url = format!("{}?q={}", BASE_URL, urlencoding::encode(query));
66
67        let response = self
68            .client
69            .get(&url)
70            .send()
71            .await
72            .map_err(crate::error::MojenticError::HttpError)?;
73
74        if !response.status().is_success() {
75            return Err(crate::error::MojenticError::ApiError(format!(
76                "HTTP request failed with status {}",
77                response.status()
78            )));
79        }
80
81        let html = response.text().await.map_err(crate::error::MojenticError::HttpError)?;
82
83        self.parse_results(&html)
84    }
85
86    /// Parse HTML results from DuckDuckGo lite
87    fn parse_results(&self, html: &str) -> Result<Vec<SearchResult>> {
88        let document = Html::parse_document(html);
89
90        // DuckDuckGo lite uses a simple structure with result-link class
91        let link_selector = Selector::parse("a.result-link").map_err(|e| {
92            crate::error::MojenticError::ParseError(format!("Invalid selector: {:?}", e))
93        })?;
94        let snippet_selector = Selector::parse("td.result-snippet").map_err(|e| {
95            crate::error::MojenticError::ParseError(format!("Invalid selector: {:?}", e))
96        })?;
97
98        let mut results = Vec::new();
99
100        // Extract links and titles
101        let links: Vec<_> = document.select(&link_selector).collect();
102        let snippets: Vec<_> = document.select(&snippet_selector).collect();
103
104        for (i, link) in links.iter().take(MAX_RESULTS).enumerate() {
105            if let Some(href) = link.value().attr("href") {
106                let title = link.text().collect::<Vec<_>>().join(" ");
107                let url = Self::decode_url(href);
108
109                let snippet = snippets
110                    .get(i)
111                    .map(|s| Self::clean_text(&s.text().collect::<Vec<_>>().join(" ")))
112                    .unwrap_or_default();
113
114                results.push(SearchResult {
115                    title: Self::clean_text(&title),
116                    url,
117                    snippet,
118                });
119            }
120        }
121
122        Ok(results)
123    }
124
125    /// Decode DuckDuckGo redirect URLs
126    fn decode_url(url: &str) -> String {
127        // DuckDuckGo uses redirect URLs like //duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com
128        if url.contains("uddg=") {
129            url.split("uddg=")
130                .nth(1)
131                .and_then(|s| s.split('&').next())
132                .map(|s| urlencoding::decode(s).unwrap_or_default().to_string())
133                .unwrap_or_else(|| url.to_string())
134        } else {
135            url.to_string()
136        }
137    }
138
139    /// Clean text by removing extra whitespace and decoding HTML entities
140    fn clean_text(text: &str) -> String {
141        let text = text.trim().replace(|c: char| c.is_whitespace(), " ");
142        let text = text.split_whitespace().collect::<Vec<_>>().join(" ");
143
144        // Decode common HTML entities
145        text.replace("&amp;", "&")
146            .replace("&lt;", "<")
147            .replace("&gt;", ">")
148            .replace("&quot;", "\"")
149            .replace("&#39;", "'")
150    }
151}
152
153impl Default for WebSearchTool {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159impl LlmTool for WebSearchTool {
160    fn run(&self, args: &HashMap<String, Value>) -> Result<Value> {
161        let query = args.get("query").and_then(|v| v.as_str()).ok_or_else(|| {
162            crate::error::MojenticError::InvalidArgument("query parameter is required".to_string())
163        })?;
164
165        if query.is_empty() {
166            return Err(crate::error::MojenticError::InvalidArgument(
167                "query parameter cannot be empty".to_string(),
168            ));
169        }
170
171        // Since we're in a sync context, we need to run the async function
172        let rt = tokio::runtime::Runtime::new().map_err(|e| {
173            crate::error::MojenticError::RuntimeError(format!("Failed to create runtime: {}", e))
174        })?;
175
176        let results = rt.block_on(self.perform_search(query)).map_err(|e| {
177            crate::error::MojenticError::ToolExecutionError(format!("Search failed: {}", e))
178        })?;
179
180        Ok(json!(results))
181    }
182
183    fn descriptor(&self) -> ToolDescriptor {
184        ToolDescriptor {
185            r#type: "function".to_string(),
186            function: FunctionDescriptor {
187                name: "web_search".to_string(),
188                description: "Search the web for information using DuckDuckGo. Returns organic search results including title, URL, and snippet for each result.".to_string(),
189                parameters: json!({
190                    "type": "object",
191                    "properties": {
192                        "query": {
193                            "type": "string",
194                            "description": "The search query"
195                        }
196                    },
197                    "required": ["query"]
198                }),
199            },
200        }
201    }
202
203    fn clone_box(&self) -> Box<dyn LlmTool> {
204        Box::new(self.clone())
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    /// Minimal DuckDuckGo-lite-shaped page with three link/snippet pairs.
    fn sample_html() -> String {
        r#"
        <!DOCTYPE html>
        <html>
        <body>
            <table>
                <tr>
                    <td>
                        <a class="result-link" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fwww.rust-lang.org%2F">The Rust Programming Language</a>
                    </td>
                </tr>
                <tr>
                    <td class="result-snippet">A language empowering everyone to build reliable and efficient software.</td>
                </tr>
                <tr>
                    <td>
                        <a class="result-link" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fdoc.rust-lang.org%2F">Rust Documentation</a>
                    </td>
                </tr>
                <tr>
                    <td class="result-snippet">The official Rust documentation and learning resources.</td>
                </tr>
                <tr>
                    <td>
                        <a class="result-link" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fcrates.io%2F">crates.io: Rust Package Registry</a>
                    </td>
                </tr>
                <tr>
                    <td class="result-snippet">The Rust community's crate registry.</td>
                </tr>
            </table>
        </body>
        </html>
        "#.to_string()
    }

    #[test]
    fn test_descriptor() {
        let tool = WebSearchTool::new();
        let descriptor = tool.descriptor();

        assert_eq!(descriptor.r#type, "function");
        assert_eq!(descriptor.function.name, "web_search");
        assert!(descriptor
            .function
            .description
            .contains("Search the web for information using DuckDuckGo"));

        let params = descriptor.function.parameters;
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["query"].is_object());
        assert_eq!(params["required"][0], "query");
    }

    #[test]
    fn test_parse_results() {
        let tool = WebSearchTool::new();
        let html = sample_html();

        let results = tool.parse_results(&html).unwrap();

        assert_eq!(results.len(), 3);

        assert_eq!(results[0].title, "The Rust Programming Language");
        assert_eq!(results[0].url, "https://www.rust-lang.org/");
        assert!(results[0].snippet.contains("A language empowering everyone"));

        assert_eq!(results[1].title, "Rust Documentation");
        assert_eq!(results[1].url, "https://doc.rust-lang.org/");
        assert!(results[1].snippet.contains("official Rust documentation"));

        assert_eq!(results[2].title, "crates.io: Rust Package Registry");
        assert_eq!(results[2].url, "https://crates.io/");
        assert!(results[2].snippet.contains("crate registry"));
    }

    #[test]
    fn test_decode_url() {
        let url = "//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpath";
        let decoded = WebSearchTool::decode_url(url);
        assert_eq!(decoded, "https://example.com/path");

        // Test URL without encoding
        let url = "https://example.com/direct";
        let decoded = WebSearchTool::decode_url(url);
        assert_eq!(decoded, "https://example.com/direct");
    }

    #[test]
    fn test_decode_url_strips_trailing_params() {
        // DuckDuckGo appends extra query parameters after the encoded target URL.
        let url = "//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2F&rut=abc123";
        assert_eq!(WebSearchTool::decode_url(url), "https://example.com/");
    }

    #[test]
    fn test_clean_text() {
        assert_eq!(WebSearchTool::clean_text("  Multiple   spaces  "), "Multiple spaces");
        assert_eq!(WebSearchTool::clean_text("Text&amp;more"), "Text&more");
        assert_eq!(WebSearchTool::clean_text("&lt;tag&gt;"), "<tag>");
        assert_eq!(WebSearchTool::clean_text("&quot;quoted&quot;"), "\"quoted\"");
        assert_eq!(WebSearchTool::clean_text("it&#39;s"), "it's");
    }

    #[test]
    fn test_clean_text_normalizes_mixed_whitespace() {
        assert_eq!(WebSearchTool::clean_text("\n\tHello\n   world\t"), "Hello world");
        assert_eq!(WebSearchTool::clean_text("   "), "");
    }

    #[test]
    fn test_tool_matches() {
        let tool = WebSearchTool::new();
        assert!(tool.matches("web_search"));
        assert!(!tool.matches("other_tool"));
    }

    #[test]
    fn test_run_missing_query() {
        let tool = WebSearchTool::new();
        let args = HashMap::new();

        let result = tool.run(&args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("query parameter is required"));
    }

    #[test]
    fn test_run_empty_query() {
        let tool = WebSearchTool::new();
        let mut args = HashMap::new();
        args.insert("query".to_string(), json!(""));

        let result = tool.run(&args);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("query parameter cannot be empty"));
    }

    #[tokio::test]
    async fn test_perform_search_success() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("GET", mockito::Matcher::Any)
            .with_status(200)
            .with_body(sample_html())
            .create_async()
            .await;

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap();

        let tool = WebSearchTool::with_client(client);

        // BASE_URL is a compile-time constant, so `perform_search` itself cannot be pointed at
        // the mock server. Instead, issue an equivalent request with the tool's client and
        // feed the body through `parse_results`.
        let url = format!("{}?q=rust", server.url());
        let response = tool.client.get(&url).send().await.unwrap();
        let html = response.text().await.unwrap();
        let results = tool.parse_results(&html).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].title, "The Rust Programming Language");

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_perform_search_http_error() {
        let mut server = Server::new_async().await;
        let mock = server.mock("GET", mockito::Matcher::Any).with_status(500).create_async().await;

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap();

        let tool = WebSearchTool::with_client(client);

        // Same limitation as above: this checks the status seen through the tool's client,
        // not the `ApiError` branch of `perform_search`.
        let url = format!("{}?q=test", server.url());
        let response = tool.client.get(&url).send().await.unwrap();

        assert_eq!(response.status(), 500);

        mock.assert_async().await;
    }

    #[test]
    fn test_max_results_limit() {
        let tool = WebSearchTool::new();

        // Create HTML with more than MAX_RESULTS entries
        let mut html = String::from("<html><body><table>");
        for i in 0..15 {
            html.push_str(&format!(
                r#"<tr><td><a class="result-link" href="https://example.com/{}">Result {}</a></td></tr>"#,
                i, i
            ));
            html.push_str(&format!(r#"<tr><td class="result-snippet">Snippet {}</td></tr>"#, i));
        }
        html.push_str("</table></body></html>");

        let results = tool.parse_results(&html).unwrap();

        assert_eq!(results.len(), MAX_RESULTS);
    }

    #[test]
    fn test_clone_box() {
        let tool = WebSearchTool::new();
        let cloned = tool.clone_box();

        assert_eq!(cloned.descriptor().function.name, tool.descriptor().function.name);
    }

    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult {
            title: "Test Title".to_string(),
            url: "https://example.com".to_string(),
            snippet: "Test snippet".to_string(),
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("Test Title"));
        assert!(json.contains("https://example.com"));
        assert!(json.contains("Test snippet"));

        let deserialized: SearchResult = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, result);
    }

    #[test]
    fn test_parse_empty_html() {
        let tool = WebSearchTool::new();
        let html = "<html><body></body></html>";

        let results = tool.parse_results(html).unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_parse_malformed_html() {
        let tool = WebSearchTool::new();
        let html = "<html><body><a class=\"result-link\">No href</a></body></html>";

        let results = tool.parse_results(html).unwrap();
        assert_eq!(results.len(), 0);
    }
}