mojentic/agents/
iterative_problem_solver.rs

//! Iterative problem solver agent that uses tools to break down and solve complex problems.
//!
//! This agent uses a chat-based approach to work on a problem iteratively,
//! continuing until it succeeds, fails explicitly, or reaches the maximum number of iterations.

6use crate::error::Result;
7use crate::llm::chat_session::ChatSession;
8use crate::llm::tools::LlmTool;
9use crate::llm::LlmBroker;
10use tracing::{info, warn};
11
12/// An agent that iteratively attempts to solve a problem using available tools.
13///
14/// The solver uses a chat-based approach to break down and solve complex problems.
15/// It will continue attempting to solve the problem until it either succeeds,
16/// fails explicitly, or reaches the maximum number of iterations.
17///
18/// # Examples
19///
20/// ```ignore
21/// use mojentic::agents::IterativeProblemSolver;
22/// use mojentic::llm::{LlmBroker, LlmGateway};
23/// use mojentic::llm::gateways::OllamaGateway;
24/// use mojentic::llm::tools::simple_date_tool::SimpleDateTool;
25/// use std::sync::Arc;
26///
27/// #[tokio::main]
28/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
29///     let gateway = Arc::new(OllamaGateway::default());
30///     let broker = LlmBroker::new("qwen3:32b", gateway, None);
31///
32///     let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(SimpleDateTool)];
33///
34///     let solver = IterativeProblemSolver::builder(broker)
35///         .tools(tools)
36///         .max_iterations(5)
37///         .build();
38///
39///     let result = solver.solve("What's the date next Friday?").await?;
40///     println!("Result: {}", result);
41///
42///     Ok(())
43/// }
44/// ```
45pub struct IterativeProblemSolver {
46    chat: ChatSession,
47    max_iterations: usize,
48}
49
50impl IterativeProblemSolver {
51    /// Create a new problem solver with default settings.
52    ///
53    /// # Arguments
54    ///
55    /// * `broker` - The LLM broker to use for generating responses
56    ///
57    /// # Examples
58    ///
59    /// ```ignore
60    /// use mojentic::agents::IterativeProblemSolver;
61    /// use mojentic::llm::LlmBroker;
62    ///
63    /// let solver = IterativeProblemSolver::new(broker);
64    /// ```
65    pub fn new(broker: LlmBroker) -> Self {
66        Self::builder(broker).build()
67    }
68
69    /// Create a problem solver builder for custom configuration.
70    ///
71    /// # Arguments
72    ///
73    /// * `broker` - The LLM broker to use for generating responses
74    ///
75    /// # Examples
76    ///
77    /// ```ignore
78    /// use mojentic::agents::IterativeProblemSolver;
79    ///
80    /// let solver = IterativeProblemSolver::builder(broker)
81    ///     .max_iterations(10)
82    ///     .system_prompt("You are a specialized problem solver.")
83    ///     .tools(vec![Box::new(SimpleDateTool)])
84    ///     .build();
85    /// ```
86    pub fn builder(broker: LlmBroker) -> IterativeProblemSolverBuilder {
87        IterativeProblemSolverBuilder::new(broker)
88    }
89
90    /// Execute the problem-solving process.
91    ///
92    /// This method runs the iterative problem-solving process, continuing until one of
93    /// these conditions is met:
94    /// - The task is completed successfully (response contains "DONE")
95    /// - The task fails explicitly (response contains "FAIL")
96    /// - The maximum number of iterations is reached
97    ///
98    /// After completion, the agent requests a summary of the final result.
99    ///
100    /// # Arguments
101    ///
102    /// * `problem` - The problem or request to be solved
103    ///
104    /// # Returns
105    ///
106    /// A summary of the final result, excluding the process details
107    ///
108    /// # Examples
109    ///
110    /// ```ignore
111    /// let result = solver.solve("Calculate the date 7 days from now").await?;
112    /// println!("Solution: {}", result);
113    /// ```
114    pub async fn solve(&mut self, problem: &str) -> Result<String> {
115        let mut iterations_remaining = self.max_iterations;
116
117        loop {
118            let result = self.step(problem).await?;
119
120            // Check for explicit failure
121            if result.to_lowercase().contains("fail") {
122                info!(user_request = problem, result = result.as_str(), "Task failed");
123                break;
124            }
125
126            // Check for successful completion
127            if result.to_lowercase().contains("done") {
128                info!(user_request = problem, result = result.as_str(), "Task completed");
129                break;
130            }
131
132            iterations_remaining -= 1;
133            if iterations_remaining == 0 {
134                warn!(
135                    max_iterations = self.max_iterations,
136                    user_request = problem,
137                    result = result.as_str(),
138                    "Max iterations reached"
139                );
140                break;
141            }
142        }
143
144        // Request final summary
145        let summary = self
146            .chat
147            .send(
148                "Summarize the final result, and only the final result, \
149                 without commenting on the process by which you achieved it.",
150            )
151            .await?;
152
153        Ok(summary)
154    }
155
156    /// Execute a single problem-solving step.
157    ///
158    /// This method sends a prompt to the chat session asking it to work on the user's request
159    /// using available tools. The response should indicate success ("DONE") or failure ("FAIL").
160    ///
161    /// # Arguments
162    ///
163    /// * `problem` - The problem or request to be solved
164    ///
165    /// # Returns
166    ///
167    /// The response from the chat session, indicating the step's outcome
168    async fn step(&mut self, problem: &str) -> Result<String> {
169        let prompt = format!(
170            "Given the user request:\n\
171             {}\n\
172             \n\
173             Use the tools at your disposal to act on their request. \
174             You may wish to create a step-by-step plan for more complicated requests.\n\
175             \n\
176             If you cannot provide an answer, say only \"FAIL\".\n\
177             If you have the answer, say only \"DONE\".",
178            problem
179        );
180
181        self.chat.send(&prompt).await
182    }
183}
184
185/// Builder for constructing an `IterativeProblemSolver` with custom configuration.
186pub struct IterativeProblemSolverBuilder {
187    broker: LlmBroker,
188    tools: Option<Vec<Box<dyn LlmTool>>>,
189    max_iterations: usize,
190    system_prompt: Option<String>,
191}
192
193impl IterativeProblemSolverBuilder {
194    /// Create a new builder
195    fn new(broker: LlmBroker) -> Self {
196        Self {
197            broker,
198            tools: None,
199            max_iterations: 3,
200            system_prompt: None,
201        }
202    }
203
204    /// Set the tools available to the problem solver
205    pub fn tools(mut self, tools: Vec<Box<dyn LlmTool>>) -> Self {
206        self.tools = Some(tools);
207        self
208    }
209
210    /// Set the maximum number of iterations (default: 3)
211    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
212        self.max_iterations = max_iterations;
213        self
214    }
215
216    /// Set a custom system prompt
217    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
218        self.system_prompt = Some(prompt.into());
219        self
220    }
221
222    /// Build the problem solver
223    pub fn build(self) -> IterativeProblemSolver {
224        let system_prompt = self.system_prompt.unwrap_or_else(|| {
225            "You are a problem-solving assistant that can solve complex problems step by step. \
226             You analyze problems, break them down into smaller parts, and solve them systematically. \
227             If you cannot solve a problem completely in one step, you make progress and identify what to do next."
228                .to_string()
229        });
230
231        let mut chat_builder = ChatSession::builder(self.broker).system_prompt(system_prompt);
232
233        if let Some(tools) = self.tools {
234            chat_builder = chat_builder.tools(tools);
235        }
236
237        IterativeProblemSolver {
238            chat: chat_builder.build(),
239            max_iterations: self.max_iterations,
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
    use crate::llm::models::{LlmGatewayResponse, LlmMessage};
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use futures::stream::{self, Stream};
    use serde_json::{json, Value};
    use std::collections::HashMap;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    /// Scripted gateway for testing.
    ///
    /// The n-th call to `complete` returns `responses[n]`, or `"default response"` once the
    /// script is exhausted. Every `complete` call is counted, so tests can assert how many
    /// round-trips the solver made.
    struct MockGateway {
        responses: Vec<String>,
        call_count: Arc<Mutex<usize>>,
    }

    impl MockGateway {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: Arc::new(Mutex::new(0)),
            }
        }

        /// Shared handle to the call counter; take it before the gateway is moved into a broker.
        fn call_counter(&self) -> Arc<Mutex<usize>> {
            Arc::clone(&self.call_count)
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            // Keep the guard's scope minimal: read and bump the counter, then release it.
            let idx = {
                let mut count = self.call_count.lock().unwrap();
                let idx = *count;
                *count += 1;
                idx
            };

            let content = self
                .responses
                .get(idx)
                .cloned()
                .unwrap_or_else(|| "default response".to_string());

            Ok(LlmGatewayResponse {
                content: Some(content),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content("test".to_string()))]))
        }
    }

    /// Build a broker around a scripted gateway; also returns the gateway's call counter.
    fn scripted_broker(responses: &[&str]) -> (LlmBroker, Arc<Mutex<usize>>) {
        let gateway = Arc::new(MockGateway::new(
            responses.iter().map(|&r| r.to_owned()).collect(),
        ));
        let counter = gateway.call_counter();
        (LlmBroker::new("test-model", gateway, None), counter)
    }

    /// Current value of a counter obtained from [`MockGateway::call_counter`].
    fn calls(counter: &Arc<Mutex<usize>>) -> usize {
        *counter.lock().unwrap()
    }

    // Mock tool for testing
    #[derive(Clone)]
    struct MockTool {
        name: String,
    }

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(json!({"result": "success"}))
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: self.name.clone(),
                    description: "A mock tool".to_string(),
                    parameters: json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(self.clone())
        }
    }

    #[tokio::test]
    async fn test_builder_default_settings() {
        let (broker, _) = scripted_broker(&[]);
        let solver = IterativeProblemSolver::new(broker);

        assert_eq!(solver.max_iterations, 3);
    }

    #[tokio::test]
    async fn test_builder_custom_max_iterations() {
        let (broker, _) = scripted_broker(&[]);
        let solver = IterativeProblemSolver::builder(broker).max_iterations(5).build();

        assert_eq!(solver.max_iterations, 5);
    }

    #[tokio::test]
    async fn test_builder_with_tools() {
        let (broker, _) = scripted_broker(&[]);

        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool {
            name: "test_tool".to_string(),
        })];

        let _solver = IterativeProblemSolver::builder(broker).tools(tools).build();

        // If this compiles and runs, the builder pattern works
    }

    #[tokio::test]
    async fn test_solve_completes_with_done() {
        let (broker, counter) = scripted_broker(&["Working on it...", "DONE", "The answer is 42"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let result = solver.solve("Test problem").await.unwrap();

        assert_eq!(result, "The answer is 42");
        // Two steps plus the summary request.
        assert_eq!(calls(&counter), 3);
    }

    #[tokio::test]
    async fn test_solve_fails_with_fail() {
        let (broker, counter) =
            scripted_broker(&["Trying...", "FAIL", "Could not solve the problem"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let result = solver.solve("Impossible problem").await.unwrap();

        assert_eq!(result, "Could not solve the problem");
        assert_eq!(calls(&counter), 3);
    }

    #[tokio::test]
    async fn test_solve_stops_at_max_iterations() {
        let (broker, counter) = scripted_broker(&["Step 1", "Step 2", "Step 3", "Final summary"]);
        let mut solver = IterativeProblemSolver::builder(broker).max_iterations(3).build();

        let result = solver.solve("Long problem").await.unwrap();

        assert_eq!(result, "Final summary");
        // 3 iterations + 1 summary request.
        assert_eq!(calls(&counter), 4);
    }

    #[tokio::test]
    async fn test_solve_respects_smaller_max_iterations() {
        let (broker, counter) = scripted_broker(&["Step 1", "Step 2", "Summary after two steps"]);
        let mut solver = IterativeProblemSolver::builder(broker).max_iterations(2).build();

        let result = solver.solve("Long problem").await.unwrap();

        assert_eq!(result, "Summary after two steps");
        assert_eq!(calls(&counter), 3);
    }

    #[tokio::test]
    async fn test_solve_case_insensitive_done() {
        let (broker, counter) = scripted_broker(&[
            "done",                 // lowercase "done"
            "The task is complete", // summary
        ]);
        let mut solver = IterativeProblemSolver::new(broker);

        let result = solver.solve("Test problem").await.unwrap();

        assert_eq!(result, "The task is complete");
        assert_eq!(calls(&counter), 2);
    }

    #[tokio::test]
    async fn test_solve_case_insensitive_fail() {
        let (broker, counter) = scripted_broker(&[
            "fail",                    // lowercase "fail"
            "Unable to complete task", // summary
        ]);
        let mut solver = IterativeProblemSolver::new(broker);

        let result = solver.solve("Test problem").await.unwrap();

        assert_eq!(result, "Unable to complete task");
        assert_eq!(calls(&counter), 2);
    }

    #[tokio::test]
    async fn test_custom_system_prompt() {
        let (broker, _) = scripted_broker(&["DONE", "Custom response"]);
        let mut solver = IterativeProblemSolver::builder(broker)
            .system_prompt("Custom system prompt for testing")
            .build();

        let result = solver.solve("Test problem").await.unwrap();

        assert_eq!(result, "Custom response");
    }

    #[tokio::test]
    async fn test_step_method() {
        let (broker, counter) = scripted_broker(&["Step response"]);
        let mut solver = IterativeProblemSolver::new(broker);

        let result = solver.step("Test problem").await.unwrap();

        assert_eq!(result, "Step response");
        assert_eq!(calls(&counter), 1);
    }

    #[tokio::test]
    async fn test_multiple_iterations_before_done() {
        let (broker, counter) = scripted_broker(&[
            "Working...",
            "Still working...",
            "Almost there...",
            "DONE",
            "Completed successfully",
        ]);
        let mut solver = IterativeProblemSolver::builder(broker).max_iterations(5).build();

        let result = solver.solve("Complex problem").await.unwrap();

        assert_eq!(result, "Completed successfully");
        // Four steps (stopping at "DONE", before the budget of 5) plus the summary.
        assert_eq!(calls(&counter), 5);
    }

    #[tokio::test]
    async fn test_done_substring_detection() {
        let (broker, _) = scripted_broker(&[
            "I'm DONE with this task", // Contains "DONE"
            "Task completed",
        ]);
        let mut solver = IterativeProblemSolver::new(broker);

        let result = solver.solve("Test problem").await.unwrap();

        assert_eq!(result, "Task completed");
    }

    #[tokio::test]
    async fn test_fail_substring_detection() {
        let (broker, _) = scripted_broker(&[
            "This will FAIL", // Contains "FAIL"
            "Failed to complete",
        ]);
        let mut solver = IterativeProblemSolver::new(broker);

        let result = solver.solve("Test problem").await.unwrap();

        assert_eq!(result, "Failed to complete");
    }
}