// mojentic/agents/simple_recursive_agent.rs

//! Simple recursive agent using event-driven architecture.
//!
//! This module provides a declarative, event-driven agent that repeatedly attempts
//! to solve a problem using the available tools. The agent keeps iterating until it
//! succeeds, fails explicitly, or reaches the maximum number of iterations.
//!
//! # Architecture
//!
//! The agent uses three main components:
//!
//! 1. **GoalState** - Tracks the problem-solving state through iterations
//! 2. **EventEmitter** - Manages event subscriptions and dispatches each event to every
//!    subscriber on its own spawned task
//! 3. **SimpleRecursiveAgent** - Orchestrates the problem-solving process
//!
//! # Events
//!
//! The agent emits the following events during problem-solving:
//!
//! - `GoalSubmittedEvent` - When a problem is submitted
//! - `IterationCompletedEvent` - After each iteration completes
//! - `GoalAchievedEvent` - When the goal is successfully achieved, or when the iteration
//!   limit is reached (the solution then contains the last response)
//! - `GoalFailedEvent` - When the goal explicitly fails, or when generating a response errors
//! - `TimeoutEvent` - When the process times out (after 300 seconds)
//!
//! # Examples
//!
//! ```ignore
//! use mojentic::agents::SimpleRecursiveAgent;
//! use mojentic::llm::{LlmBroker, LlmGateway};
//! use mojentic::llm::gateways::OllamaGateway;
//! use mojentic::llm::tools::simple_date_tool::SimpleDateTool;
//! use mojentic::llm::tools::LlmTool;
//! use std::sync::Arc;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let gateway = Arc::new(OllamaGateway::default());
//!     let broker = Arc::new(LlmBroker::new("qwen3:32b", gateway, None));
//!
//!     let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(SimpleDateTool)];
//!
//!     let agent = SimpleRecursiveAgent::builder(broker)
//!         .tools(tools)
//!         .max_iterations(5)
//!         .build();
//!
//!     // Subscribe to events. Every callback receives every `AnySolverEvent` variant,
//!     // and `subscribe` is async, so it must be awaited.
//!     agent
//!         .emitter
//!         .subscribe(|event| {
//!             println!("Iteration {}: {:?}", event.state().iteration, event);
//!         })
//!         .await;
//!
//!     let result = agent.solve("What's the date next Friday?").await?;
//!     println!("Result: {}", result);
//!
//!     Ok(())
//! }
//! ```
//!
//! # Completion Indicators
//!
//! The agent monitors responses for these keywords (case-insensitive, whole words only):
//! - "DONE" - Task completed successfully
//! - "FAIL" - Task cannot be completed (checked first, so it wins if both appear)

use crate::error::Result;
use crate::llm::chat_session::ChatSession;
use crate::llm::tools::LlmTool;
use crate::llm::LlmBroker;
use regex::Regex;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use tokio::sync::{mpsc, Mutex};
use tokio::time::{timeout, Duration};
use tracing::warn;

/// Snapshot of a problem-solving run as it moves through iterations.
///
/// A freshly created state (see [`GoalState::new`]) has `iteration == 0`, no solution,
/// and `is_complete == false`.
#[derive(Debug, Clone)]
pub struct GoalState {
    /// Text of the problem or goal being solved
    pub goal: String,
    /// Number of iterations started so far (0 before the first one)
    pub iteration: usize,
    /// Iteration budget for this run
    pub max_iterations: usize,
    /// Final answer text, once one has been produced
    pub solution: Option<String>,
    /// Set once the run has reached a terminal outcome
    pub is_complete: bool,
}

impl GoalState {
    /// Build the starting state for `goal` with an iteration budget of `max_iterations`.
    pub fn new(goal: impl Into<String>, max_iterations: usize) -> Self {
        let goal: String = goal.into();
        GoalState {
            max_iterations,
            goal,
            solution: None,
            is_complete: false,
            iteration: 0,
        }
    }
}

104/// Base trait for solver events.
105pub trait SolverEvent: Send + Sync + std::fmt::Debug {
106    /// Get the current state
107    fn state(&self) -> &GoalState;
108}
109
110/// Event triggered when a goal is submitted for solving.
111#[derive(Debug, Clone)]
112pub struct GoalSubmittedEvent {
113    pub state: GoalState,
114}
115
116impl SolverEvent for GoalSubmittedEvent {
117    fn state(&self) -> &GoalState {
118        &self.state
119    }
120}
121
122/// Event triggered when an iteration of the problem-solving process is completed.
123#[derive(Debug, Clone)]
124pub struct IterationCompletedEvent {
125    pub state: GoalState,
126    /// The response from the LLM for this iteration
127    pub response: String,
128}
129
130impl SolverEvent for IterationCompletedEvent {
131    fn state(&self) -> &GoalState {
132        &self.state
133    }
134}
135
136/// Event triggered when a goal is successfully achieved.
137#[derive(Debug, Clone)]
138pub struct GoalAchievedEvent {
139    pub state: GoalState,
140}
141
142impl SolverEvent for GoalAchievedEvent {
143    fn state(&self) -> &GoalState {
144        &self.state
145    }
146}
147
148/// Event triggered when a goal cannot be solved.
149#[derive(Debug, Clone)]
150pub struct GoalFailedEvent {
151    pub state: GoalState,
152}
153
154impl SolverEvent for GoalFailedEvent {
155    fn state(&self) -> &GoalState {
156        &self.state
157    }
158}
159
160/// Event triggered when the problem-solving process times out.
161#[derive(Debug, Clone)]
162pub struct TimeoutEvent {
163    pub state: GoalState,
164}
165
166impl SolverEvent for TimeoutEvent {
167    fn state(&self) -> &GoalState {
168        &self.state
169    }
170}
171
172/// Union type of all solver events
173#[derive(Debug, Clone)]
174pub enum AnySolverEvent {
175    GoalSubmitted(GoalSubmittedEvent),
176    IterationCompleted(IterationCompletedEvent),
177    GoalAchieved(GoalAchievedEvent),
178    GoalFailed(GoalFailedEvent),
179    Timeout(TimeoutEvent),
180}
181
182impl AnySolverEvent {
183    /// Get the state from any event variant
184    pub fn state(&self) -> &GoalState {
185        match self {
186            AnySolverEvent::GoalSubmitted(e) => &e.state,
187            AnySolverEvent::IterationCompleted(e) => &e.state,
188            AnySolverEvent::GoalAchieved(e) => &e.state,
189            AnySolverEvent::GoalFailed(e) => &e.state,
190            AnySolverEvent::Timeout(e) => &e.state,
191        }
192    }
193}
194
195/// Event handler callback type
196type EventCallback = Arc<dyn Fn(AnySolverEvent) + Send + Sync>;
197
198/// A simple event emitter that allows subscribing to and emitting events.
199///
200/// This implementation uses async channels to dispatch events to subscribers
201/// asynchronously without blocking the emitter.
202pub struct EventEmitter {
203    subscribers: Arc<Mutex<Vec<EventCallback>>>,
204}
205
206impl EventEmitter {
207    /// Create a new event emitter
208    pub fn new() -> Self {
209        Self {
210            subscribers: Arc::new(Mutex::new(Vec::new())),
211        }
212    }
213
214    /// Subscribe to events with a callback function.
215    ///
216    /// # Examples
217    ///
218    /// ```ignore
219    /// emitter.subscribe(|event: AnySolverEvent| {
220    ///     println!("Event received: {:?}", event);
221    /// });
222    /// ```
223    pub async fn subscribe<F>(&self, callback: F)
224    where
225        F: Fn(AnySolverEvent) + Send + Sync + 'static,
226    {
227        let mut subscribers = self.subscribers.lock().await;
228        subscribers.push(Arc::new(callback));
229    }
230
231    /// Emit an event to all subscribers asynchronously.
232    ///
233    /// Events are dispatched to subscribers without blocking the emitter.
234    pub async fn emit(&self, event: AnySolverEvent) {
235        let subscribers = self.subscribers.lock().await.clone();
236
237        for callback in subscribers {
238            let event = event.clone();
239            let callback = callback.clone();
240
241            // Spawn a task to call the callback asynchronously
242            tokio::spawn(async move {
243                callback(event);
244            });
245        }
246    }
247}
248
249impl Default for EventEmitter {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255/// An agent that recursively attempts to solve a problem using available tools.
256///
257/// This agent uses an event-driven approach to manage the problem-solving process.
258/// It will continue attempting to solve the problem until it either succeeds,
259/// fails explicitly, or reaches the maximum number of iterations.
260pub struct SimpleRecursiveAgent {
261    broker: Arc<LlmBroker>,
262    tools: Vec<Box<dyn LlmTool>>,
263    max_iterations: usize,
264    system_prompt: String,
265    /// The event emitter used to manage events
266    pub emitter: Arc<EventEmitter>,
267}
268
269impl SimpleRecursiveAgent {
270    /// Create a new SimpleRecursiveAgent with default settings.
271    ///
272    /// # Arguments
273    ///
274    /// * `broker` - The LLM broker to use for generating responses
275    ///
276    /// # Examples
277    ///
278    /// ```ignore
279    /// use mojentic::agents::SimpleRecursiveAgent;
280    /// use mojentic::llm::LlmBroker;
281    /// use std::sync::Arc;
282    ///
283    /// let broker = Arc::new(LlmBroker::new("qwen3:32b", gateway, None));
284    /// let agent = SimpleRecursiveAgent::new(broker);
285    /// ```
286    pub fn new(broker: Arc<LlmBroker>) -> Self {
287        Self::builder(broker).build()
288    }
289
290    /// Create a SimpleRecursiveAgent builder for custom configuration.
291    ///
292    /// # Arguments
293    ///
294    /// * `broker` - The LLM broker to use for generating responses
295    ///
296    /// # Examples
297    ///
298    /// ```ignore
299    /// use mojentic::agents::SimpleRecursiveAgent;
300    ///
301    /// let agent = SimpleRecursiveAgent::builder(broker)
302    ///     .max_iterations(10)
303    ///     .system_prompt("You are a specialized assistant.")
304    ///     .tools(vec![Box::new(SimpleDateTool)])
305    ///     .build();
306    /// ```
307    pub fn builder(broker: Arc<LlmBroker>) -> SimpleRecursiveAgentBuilder {
308        SimpleRecursiveAgentBuilder::new(broker)
309    }
310
311    /// Solve a problem asynchronously.
312    ///
313    /// This method runs the event-driven problem-solving process with a 300-second timeout.
314    /// The agent will continue iterating until:
315    /// - The task is completed successfully ("DONE")
316    /// - The task fails explicitly ("FAIL")
317    /// - The maximum number of iterations is reached
318    /// - The process times out (300 seconds)
319    ///
320    /// # Arguments
321    ///
322    /// * `problem` - The problem or request to be solved
323    ///
324    /// # Returns
325    ///
326    /// The solution to the problem
327    ///
328    /// # Examples
329    ///
330    /// ```ignore
331    /// let solution = agent.solve("Calculate the factorial of 5").await?;
332    /// println!("Solution: {}", solution);
333    /// ```
334    pub async fn solve(&self, problem: impl Into<String>) -> Result<String> {
335        let problem = problem.into();
336
337        // Create a channel to receive the solution
338        let (solution_tx, mut solution_rx) = mpsc::channel::<String>(1);
339
340        // Create the initial goal state
341        let state = GoalState::new(problem.clone(), self.max_iterations);
342
343        // Clone what we need for the async task
344        let solution_tx_clone = solution_tx.clone();
345
346        // Subscribe to completion events
347        let emitter = self.emitter.clone();
348        emitter
349            .subscribe(move |event: AnySolverEvent| match &event {
350                AnySolverEvent::GoalAchieved(_)
351                | AnySolverEvent::GoalFailed(_)
352                | AnySolverEvent::Timeout(_) => {
353                    if let Some(solution) = &event.state().solution {
354                        let _ = solution_tx_clone.try_send(solution.clone());
355                    }
356                }
357                _ => {}
358            })
359            .await;
360
361        // Start the solving process
362        self.emitter
363            .emit(AnySolverEvent::GoalSubmitted(GoalSubmittedEvent {
364                state: state.clone(),
365            }))
366            .await;
367
368        // Spawn a task to handle the problem submission
369        let agent = self.clone_for_handler();
370        tokio::spawn(async move {
371            agent.handle_goal_submitted(state).await;
372        });
373
374        // Wait for solution or timeout (300 seconds)
375        match timeout(Duration::from_secs(300), solution_rx.recv()).await {
376            Ok(Some(solution)) => Ok(solution),
377            Ok(None) => {
378                let timeout_message =
379                    "Timeout: Could not solve the problem within 300 seconds.".to_string();
380                let mut timeout_state = GoalState::new(problem, self.max_iterations);
381                timeout_state.solution = Some(timeout_message.clone());
382                timeout_state.is_complete = true;
383
384                self.emitter
385                    .emit(AnySolverEvent::Timeout(TimeoutEvent {
386                        state: timeout_state,
387                    }))
388                    .await;
389
390                Ok(timeout_message)
391            }
392            Err(_) => {
393                let timeout_message =
394                    "Timeout: Could not solve the problem within 300 seconds.".to_string();
395                let mut timeout_state = GoalState::new(problem, self.max_iterations);
396                timeout_state.solution = Some(timeout_message.clone());
397                timeout_state.is_complete = true;
398
399                self.emitter
400                    .emit(AnySolverEvent::Timeout(TimeoutEvent {
401                        state: timeout_state,
402                    }))
403                    .await;
404
405                Ok(timeout_message)
406            }
407        }
408    }
409
410    /// Handle a goal submitted event
411    fn handle_goal_submitted(
412        &self,
413        state: GoalState,
414    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
415        Box::pin(async move {
416            self.process_iteration(state).await;
417        })
418    }
419
420    /// Handle an iteration completed event
421    fn handle_iteration_completed(
422        &self,
423        mut state: GoalState,
424        response: String,
425    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
426        Box::pin(async move {
427            let response_lower = response.to_lowercase();
428
429            // Create regex patterns for word boundary matching
430            let done_pattern = Regex::new(r"\bdone\b").unwrap();
431            let fail_pattern = Regex::new(r"\bfail\b").unwrap();
432
433            // Check if the task failed
434            if fail_pattern.is_match(&response_lower) {
435                state.solution = Some(format!(
436                    "Failed to solve after {} iterations:\n{}",
437                    state.iteration, response
438                ));
439                state.is_complete = true;
440
441                self.emitter.emit(AnySolverEvent::GoalFailed(GoalFailedEvent { state })).await;
442                return;
443            }
444
445            // Check if the task succeeded
446            if done_pattern.is_match(&response_lower) {
447                state.solution = Some(response);
448                state.is_complete = true;
449
450                self.emitter
451                    .emit(AnySolverEvent::GoalAchieved(GoalAchievedEvent { state }))
452                    .await;
453                return;
454            }
455
456            // Check if we've reached max iterations
457            if state.iteration >= state.max_iterations {
458                state.solution = Some(format!(
459                    "Best solution after {} iterations:\n{}",
460                    state.max_iterations, response
461                ));
462                state.is_complete = true;
463
464                self.emitter
465                    .emit(AnySolverEvent::GoalAchieved(GoalAchievedEvent { state }))
466                    .await;
467                return;
468            }
469
470            // Continue with next iteration
471            self.process_iteration(state).await;
472        })
473    }
474
475    /// Process a single iteration of the problem-solving process
476    fn process_iteration(
477        &self,
478        mut state: GoalState,
479    ) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
480        Box::pin(async move {
481            // Increment iteration counter
482            state.iteration += 1;
483
484            // Generate prompt for this iteration
485            let prompt = format!(
486                "Given the user request:\n\
487             {}\n\
488             \n\
489             Use the tools at your disposal to act on their request.\n\
490             You may wish to create a step-by-step plan for more complicated requests.\n\
491             \n\
492             If you cannot provide an answer, say only \"FAIL\".\n\
493             If you have the answer, say only \"DONE\".",
494                state.goal
495            );
496
497            // Generate response asynchronously
498            match self.generate_response(&prompt).await {
499                Ok(response) => {
500                    self.emitter
501                        .emit(AnySolverEvent::IterationCompleted(IterationCompletedEvent {
502                            state: state.clone(),
503                            response: response.clone(),
504                        }))
505                        .await;
506
507                    // Handle the completed iteration
508                    self.handle_iteration_completed(state, response).await;
509                }
510                Err(e) => {
511                    warn!("Error generating response: {}", e);
512                    let mut error_state = state;
513                    error_state.solution = Some(format!("Error: {}", e));
514                    error_state.is_complete = true;
515
516                    self.emitter
517                        .emit(AnySolverEvent::GoalFailed(GoalFailedEvent { state: error_state }))
518                        .await;
519                }
520            }
521        })
522    }
523
524    /// Generate a response using a ChatSession
525    async fn generate_response(&self, prompt: &str) -> Result<String> {
526        // Create a chat session for this request
527        let broker = Arc::clone(&self.broker);
528        let mut chat = ChatSession::builder((*broker).clone())
529            .system_prompt(&self.system_prompt)
530            .tools(self.tools.iter().map(|t| t.clone_box()).collect())
531            .build();
532
533        chat.send(prompt).await
534    }
535
536    /// Clone the agent for use in async handlers
537    ///
538    /// This creates a shallow clone suitable for spawned tasks
539    fn clone_for_handler(&self) -> Self {
540        Self {
541            broker: self.broker.clone(),
542            tools: self.tools.iter().map(|t| t.clone_box()).collect(),
543            max_iterations: self.max_iterations,
544            system_prompt: self.system_prompt.clone(),
545            emitter: self.emitter.clone(),
546        }
547    }
548}
549
550/// Builder for constructing a `SimpleRecursiveAgent` with custom configuration.
551pub struct SimpleRecursiveAgentBuilder {
552    broker: Arc<LlmBroker>,
553    tools: Vec<Box<dyn LlmTool>>,
554    max_iterations: usize,
555    system_prompt: Option<String>,
556}
557
558impl SimpleRecursiveAgentBuilder {
559    /// Create a new builder
560    fn new(broker: Arc<LlmBroker>) -> Self {
561        Self {
562            broker,
563            tools: Vec::new(),
564            max_iterations: 5,
565            system_prompt: None,
566        }
567    }
568
569    /// Set the tools available to the agent
570    pub fn tools(mut self, tools: Vec<Box<dyn LlmTool>>) -> Self {
571        self.tools = tools;
572        self
573    }
574
575    /// Set the maximum number of iterations (default: 5)
576    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
577        self.max_iterations = max_iterations;
578        self
579    }
580
581    /// Set a custom system prompt
582    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
583        self.system_prompt = Some(prompt.into());
584        self
585    }
586
587    /// Build the agent
588    pub fn build(self) -> SimpleRecursiveAgent {
589        let system_prompt = self.system_prompt.unwrap_or_else(|| {
590            "You are a problem-solving assistant that can solve complex problems step by step. \
591             You analyze problems, break them down into smaller parts, and solve them systematically. \
592             If you cannot solve a problem completely in one step, you make progress and identify what to do next."
593                .to_string()
594        });
595
596        SimpleRecursiveAgent {
597            broker: self.broker,
598            tools: self.tools,
599            max_iterations: self.max_iterations,
600            system_prompt,
601            emitter: Arc::new(EventEmitter::new()),
602        }
603    }
604}
605
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
    use crate::llm::models::{LlmGatewayResponse, LlmMessage};
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use futures::stream::{self, Stream};
    use serde_json::{json, Value};
    use std::collections::HashMap;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Upper bound on how long a test waits for asynchronously dispatched events.
    ///
    /// Only reached when a test is failing; passing tests continue as soon as the
    /// expected events have arrived, so this replaces fixed sleeps without slowing tests.
    const EVENT_WAIT: Duration = Duration::from_secs(5);

    /// Mock gateway that replays `responses` in order, then answers "default response".
    struct MockGateway {
        responses: Vec<String>,
        call_count: Arc<AtomicUsize>,
    }

    impl MockGateway {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);

            let content = if idx < self.responses.len() {
                self.responses[idx].clone()
            } else {
                "default response".to_string()
            };

            Ok(LlmGatewayResponse {
                content: Some(content),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content("test".to_string()))]))
        }
    }

    /// Mock tool that always reports success.
    #[derive(Clone)]
    struct MockTool {
        name: String,
    }

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(json!({"result": "success"}))
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: self.name.clone(),
                    description: "A mock tool".to_string(),
                    parameters: json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(self.clone())
        }
    }

    /// Broker backed by a [`MockGateway`] that replays `responses`.
    fn mock_broker(responses: &[&str]) -> Arc<LlmBroker> {
        let gateway = Arc::new(MockGateway::new(
            responses.iter().map(|r| r.to_string()).collect(),
        ));
        Arc::new(LlmBroker::new("test-model", gateway, None))
    }

    /// Subscribe to `emitter` and forward every event it emits into a channel.
    ///
    /// Callbacks run on spawned tasks, so tests must wait on the receiver rather than
    /// assume delivery has already happened when `emit`/`solve` returns.
    async fn record_events(emitter: &EventEmitter) -> mpsc::UnboundedReceiver<AnySolverEvent> {
        let (tx, rx) = mpsc::unbounded_channel();
        emitter
            .subscribe(move |event: AnySolverEvent| {
                // The receiver may already be dropped after a test finishes; ignore that.
                let _ = tx.send(event);
            })
            .await;
        rx
    }

    /// Receive events until `done` holds for everything seen so far, or `EVENT_WAIT`
    /// elapses; returns all events received.
    async fn collect_until<F>(
        rx: &mut mpsc::UnboundedReceiver<AnySolverEvent>,
        done: F,
    ) -> Vec<AnySolverEvent>
    where
        F: Fn(&[AnySolverEvent]) -> bool,
    {
        let mut seen = Vec::new();
        let _ = timeout(EVENT_WAIT, async {
            while !done(seen.as_slice()) {
                match rx.recv().await {
                    Some(event) => seen.push(event),
                    None => break,
                }
            }
        })
        .await;
        seen
    }

    fn is_submitted(event: &AnySolverEvent) -> bool {
        matches!(event, AnySolverEvent::GoalSubmitted(_))
    }

    fn is_iteration(event: &AnySolverEvent) -> bool {
        matches!(event, AnySolverEvent::IterationCompleted(_))
    }

    fn is_achieved(event: &AnySolverEvent) -> bool {
        matches!(event, AnySolverEvent::GoalAchieved(_))
    }

    fn is_failed(event: &AnySolverEvent) -> bool {
        matches!(event, AnySolverEvent::GoalFailed(_))
    }

    #[test]
    fn test_goal_state_creation() {
        let state = GoalState::new("Test problem", 5);

        assert_eq!(state.goal, "Test problem");
        assert_eq!(state.iteration, 0);
        assert_eq!(state.max_iterations, 5);
        assert_eq!(state.solution, None);
        assert!(!state.is_complete);
    }

    #[tokio::test]
    async fn test_event_emitter_subscribe_and_emit() {
        let emitter = EventEmitter::new();
        let mut rx = record_events(&emitter).await;

        let state = GoalState::new("Test", 5);
        emitter.emit(AnySolverEvent::GoalSubmitted(GoalSubmittedEvent { state })).await;

        let events = collect_until(&mut rx, |seen| !seen.is_empty()).await;

        assert_eq!(events.len(), 1, "subscriber did not receive the event");
        assert!(is_submitted(&events[0]));
        assert_eq!(events[0].state().goal, "Test");
    }

    #[tokio::test]
    async fn test_builder_default_settings() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&[]));

        assert_eq!(agent.max_iterations, 5);
        assert_eq!(agent.tools.len(), 0);
    }

    #[tokio::test]
    async fn test_builder_custom_max_iterations() {
        let agent = SimpleRecursiveAgent::builder(mock_broker(&[])).max_iterations(10).build();

        assert_eq!(agent.max_iterations, 10);
    }

    #[tokio::test]
    async fn test_builder_with_tools() {
        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool {
            name: "test_tool".to_string(),
        })];

        let agent = SimpleRecursiveAgent::builder(mock_broker(&[])).tools(tools).build();

        assert_eq!(agent.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_builder_custom_system_prompt() {
        let agent = SimpleRecursiveAgent::builder(mock_broker(&[]))
            .system_prompt("Custom prompt")
            .build();

        assert_eq!(agent.system_prompt, "Custom prompt");
    }

    #[tokio::test]
    async fn test_solve_completes_with_done() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["DONE"]));

        let result = agent.solve("Test problem").await.unwrap();

        assert_eq!(result, "DONE");
    }

    #[tokio::test]
    async fn test_solve_fails_with_fail() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["FAIL"]));

        let result = agent.solve("Impossible problem").await.unwrap();

        assert!(result.contains("Failed to solve after 1 iterations"));
        assert!(result.contains("FAIL"));
    }

    #[tokio::test]
    async fn test_solve_case_insensitive_done() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["done"]));

        let result = agent.solve("Test problem").await.unwrap();

        assert_eq!(result, "done");
    }

    #[tokio::test]
    async fn test_solve_case_insensitive_fail() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["fail"]));

        let result = agent.solve("Test problem").await.unwrap();

        assert!(result.contains("Failed to solve"));
        assert!(result.contains("fail"));
    }

    #[tokio::test]
    async fn test_solve_word_boundary_done() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["I'm DONE with this task"]));

        let result = agent.solve("Test problem").await.unwrap();

        assert_eq!(result, "I'm DONE with this task");
    }

    #[tokio::test]
    async fn test_solve_word_boundary_fail() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["This will FAIL"]));

        let result = agent.solve("Test problem").await.unwrap();

        assert!(result.contains("Failed to solve"));
    }

    #[tokio::test]
    async fn test_solve_stops_at_max_iterations() {
        let agent = SimpleRecursiveAgent::builder(mock_broker(&["Step 1", "Step 2", "Step 3"]))
            .max_iterations(3)
            .build();

        let result = agent.solve("Long problem").await.unwrap();

        assert!(result.contains("Best solution after 3 iterations"));
        assert!(result.contains("Step 3"));
    }

    #[tokio::test]
    async fn test_solve_multiple_iterations_before_done() {
        let agent =
            SimpleRecursiveAgent::builder(mock_broker(&["Working...", "Still working...", "DONE"]))
                .max_iterations(5)
                .build();

        let result = agent.solve("Complex problem").await.unwrap();

        assert_eq!(result, "DONE");
    }

    #[tokio::test]
    async fn test_sequential_solves_return_their_own_results() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["DONE", "FAIL"]));

        let first = agent.solve("First problem").await.unwrap();
        let second = agent.solve("Second problem").await.unwrap();

        assert_eq!(first, "DONE");
        assert!(second.contains("Failed to solve after 1 iterations"));
    }

    #[tokio::test]
    async fn test_event_emission_during_solve() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["DONE"]));
        let mut rx = record_events(&agent.emitter).await;

        let _result = agent.solve("Test problem").await.unwrap();

        let events = collect_until(&mut rx, |seen| {
            seen.iter().any(is_submitted)
                && seen.iter().any(is_iteration)
                && seen.iter().any(is_achieved)
        })
        .await;

        assert!(events.iter().any(is_submitted), "GoalSubmitted event not fired");
        assert!(events.iter().any(is_iteration), "IterationCompleted event not fired");
        assert!(events.iter().any(is_achieved), "GoalAchieved event not fired");
    }

    #[tokio::test]
    async fn test_event_emission_on_failure() {
        let agent = SimpleRecursiveAgent::new(mock_broker(&["FAIL"]));
        let mut rx = record_events(&agent.emitter).await;

        let _result = agent.solve("Test problem").await.unwrap();

        let events = collect_until(&mut rx, |seen| seen.iter().any(is_failed)).await;

        assert!(events.iter().any(is_failed), "GoalFailed event not fired");
    }
}