mojentic/examples/react/thinking_agent.rs

//! Planning agent for the ReAct pattern.
//!
//! This agent creates structured plans for solving user queries.

5use crate::agents::BaseAsyncAgent;
6use crate::event::Event;
7use crate::llm::tools::simple_date_tool::SimpleDateTool;
8use crate::llm::tools::LlmTool;
9use crate::llm::{LlmBroker, LlmMessage, MessageRole};
10use crate::Result;
11use async_trait::async_trait;
12use std::sync::Arc;
13
14use super::events::{FailureOccurred, InvokeDecisioning, InvokeThinking};
15use super::formatters::{format_available_tools, format_current_context};
16use super::models::{Plan, ThoughtActionObservation};
17
18/// Agent responsible for creating plans in the ReAct loop.
19///
20/// This agent analyzes the user query and available tools to create
21/// a step-by-step plan for answering the query.
22pub struct ThinkingAgent {
23    llm: Arc<LlmBroker>,
24    tools: Vec<Box<dyn LlmTool>>,
25}
26
27impl ThinkingAgent {
28    /// Initialize the thinking agent.
29    ///
30    /// # Arguments
31    ///
32    /// * `llm` - The LLM broker to use for generating plans.
33    pub fn new(llm: Arc<LlmBroker>) -> Self {
34        Self {
35            llm,
36            tools: vec![Box::new(SimpleDateTool)],
37        }
38    }
39
40    /// Generate the prompt for the planning LLM.
41    fn prompt(&self, event: &InvokeThinking) -> String {
42        let tools_list: Vec<&dyn LlmTool> = self.tools.iter().map(|t| t.as_ref()).collect();
43
44        format!(
45            "You are to solve a problem by reasoning and acting on the information you have. Here is the current context:
46
47{}
48{}
49
50Your Instructions:
51Given our context and what we've done so far, and the tools available, create a step-by-step plan to answer the query.
52Each step should be concrete and actionable. Consider which tools you'll need to use.",
53            format_current_context(&event.context),
54            format_available_tools(&tools_list)
55        )
56    }
57}
58
59#[async_trait]
60impl BaseAsyncAgent for ThinkingAgent {
61    async fn receive_event_async(&self, event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
62        // Downcast to InvokeThinking
63        let thinking_event = match event.as_any().downcast_ref::<InvokeThinking>() {
64            Some(e) => e,
65            None => return Ok(vec![]),
66        };
67
68        let prompt = self.prompt(thinking_event);
69        println!("\n{}\n{}\n{}\n", "=".repeat(80), prompt, "=".repeat(80));
70
71        // Generate plan using structured output
72        let plan = match self
73            .llm
74            .generate_object::<Plan>(
75                &[LlmMessage {
76                    role: MessageRole::User,
77                    content: Some(prompt),
78                    tool_calls: None,
79                    image_paths: None,
80                }],
81                None,
82                thinking_event.correlation_id.clone(),
83            )
84            .await
85        {
86            Ok(p) => p,
87            Err(e) => {
88                return Ok(vec![Box::new(FailureOccurred {
89                    source: "ThinkingAgent".to_string(),
90                    correlation_id: thinking_event.correlation_id.clone(),
91                    context: thinking_event.context.clone(),
92                    reason: format!("Error during planning: {}", e),
93                }) as Box<dyn Event>]);
94            }
95        };
96
97        println!("\n{}\nPlan: {:?}\n{}\n", "=".repeat(80), plan, "=".repeat(80));
98
99        // Update context with new plan
100        let mut updated_context = thinking_event.context.clone();
101        updated_context.plan = plan.clone();
102
103        // Add planning step to history
104        updated_context.history.push(ThoughtActionObservation {
105            thought: "I need to create a plan to solve this query.".to_string(),
106            action: "Created a step-by-step plan.".to_string(),
107            observation: format!("Plan has {} steps.", plan.steps.len()),
108        });
109
110        Ok(vec![Box::new(InvokeDecisioning {
111            source: "ThinkingAgent".to_string(),
112            correlation_id: thinking_event.correlation_id.clone(),
113            context: updated_context,
114        }) as Box<dyn Event>])
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateways::OllamaGateway;

    use super::super::models::CurrentContext;

    /// Construct a `ThinkingAgent` backed by an Ollama gateway.
    ///
    /// None of the tests below send a request to the model. `prompt` never calls the
    /// broker, and the wrong-event test returns before any LLM request is made.
    fn make_agent() -> ThinkingAgent {
        let gateway = Arc::new(OllamaGateway::new());
        let broker = Arc::new(LlmBroker::new("qwen3:32b", gateway, None));
        ThinkingAgent::new(broker)
    }

    /// Wrap `context` in an `InvokeThinking` event coming from a test source.
    fn thinking_event(context: CurrentContext, correlation_id: &str) -> InvokeThinking {
        InvokeThinking {
            source: "TestSource".to_string(),
            correlation_id: Some(correlation_id.to_string()),
            context,
        }
    }

    #[test]
    fn test_thinking_agent_prompt_generation() {
        let agent = make_agent();
        let event = thinking_event(CurrentContext::new("What is the date tomorrow?"), "test-123");

        let prompt = agent.prompt(&event);

        // The query, the closing instructions, and `resolve_date` (presumably the
        // date tool's name) must all appear in the prompt.
        assert!(prompt.contains("What is the date tomorrow?"));
        assert!(prompt.contains("create a step-by-step plan"));
        assert!(prompt.contains("resolve_date"));
    }

    #[test]
    fn test_thinking_agent_with_existing_plan() {
        let agent = make_agent();
        let mut context = CurrentContext::new("What day is it?");
        context.plan = Plan {
            steps: vec!["Get current date".to_string()],
        };
        let event = thinking_event(context, "test-456");

        let prompt = agent.prompt(&event);

        // A plan that is already populated is rendered back into the prompt.
        assert!(prompt.contains("Current plan:"));
        assert!(prompt.contains("Get current date"));
    }

    #[tokio::test]
    async fn test_thinking_agent_ignores_wrong_event_type() {
        let agent = make_agent();
        let unrelated: Box<dyn Event> = Box::new(InvokeDecisioning {
            source: "Wrong".to_string(),
            correlation_id: None,
            context: CurrentContext::new("Test"),
        });

        let emitted = agent.receive_event_async(unrelated).await.unwrap();

        assert!(emitted.is_empty());
    }
}