mojentic/examples/react/
decisioning_agent.rs

1//! Decision-making agent for the ReAct pattern.
2//!
3//! This agent evaluates the current context and decides on the next action to take.
4
5use crate::agents::BaseAsyncAgent;
6use crate::event::Event;
7use crate::llm::tools::simple_date_tool::SimpleDateTool;
8use crate::llm::tools::LlmTool;
9use crate::llm::{LlmBroker, LlmMessage, MessageRole};
10use crate::Result;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use super::events::{
17    FailureOccurred, FinishAndSummarize, InvokeDecisioning, InvokeThinking, InvokeToolCall,
18};
19use super::formatters::{format_available_tools, format_current_context};
20use super::models::NextAction;
21
/// Structured response from the decisioning agent.
//
// NOTE(review): this type derives `schemars::JsonSchema`, and schemars copies
// `///` doc comments into the generated schema as `description` text. The doc
// comments on this struct and its fields are therefore presumably shown to the
// model when the type is requested through `generate_object` (confirm against
// the broker). Reword them with the same care as a prompt; use plain `//`
// comments for maintainer-only notes like this one.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct DecisionResponse {
    /// The reasoning behind the decision
    pub thought: String,
    /// What should happen next: PLAN, ACT, or FINISH
    pub next_action: NextAction,
    /// Name of tool to use if next_action is ACT
    // Left out of serialized output entirely when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Arguments for the tool if next_action is ACT
    // `default` allows the key to be absent on input; it then becomes an empty map.
    #[serde(default)]
    pub tool_arguments: HashMap<String, serde_json::Value>,
}
36
/// Agent responsible for deciding the next action in the ReAct loop.
///
/// This agent evaluates the current context, plan, and history to determine
/// whether to continue planning, take an action, or finish and summarize.
///
/// It handles `InvokeDecisioning` events; any other event type is ignored.
/// For a handled event it emits a single follow-up event: `InvokeThinking`,
/// `InvokeToolCall`, `FinishAndSummarize`, or `FailureOccurred`.
pub struct DecisioningAgent {
    /// Broker used to request the structured `DecisionResponse` from the LLM.
    llm: Arc<LlmBroker>,
    /// Tools listed in the prompt and searched by name when the decision is ACT.
    tools: Vec<Arc<dyn LlmTool>>,
}
45
46impl DecisioningAgent {
47    /// Maximum iterations before failing
48    const MAX_ITERATIONS: usize = 10;
49
50    /// Initialize the decisioning agent.
51    ///
52    /// # Arguments
53    ///
54    /// * `llm` - The LLM broker to use for generating decisions.
55    pub fn new(llm: Arc<LlmBroker>) -> Self {
56        Self {
57            llm,
58            tools: vec![Arc::new(SimpleDateTool)],
59        }
60    }
61
62    /// Generate the prompt for the decision-making LLM.
63    fn prompt(&self, event: &InvokeDecisioning) -> String {
64        let tools_list: Vec<&dyn LlmTool> = self.tools.iter().map(|t| t.as_ref()).collect();
65
66        format!(
67            "You are to solve a problem by reasoning and acting on the information you have. Here is the current context:
68
69{}
70{}
71
72Your Instructions:
73Review the current plan and history. Decide what to do next:
74
751. PLAN - If the plan is incomplete or needs refinement
762. ACT - If you should take an action using one of the available tools
773. FINISH - If you have enough information to answer the user's query
78
79If you choose ACT, specify which tool to use and what arguments to pass.
80Think carefully about whether each step in the plan has been completed.",
81            format_current_context(&event.context),
82            format_available_tools(&tools_list)
83        )
84    }
85}
86
87#[async_trait]
88impl BaseAsyncAgent for DecisioningAgent {
89    async fn receive_event_async(&self, event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
90        // Downcast to InvokeDecisioning
91        let decisioning_event = match event.as_any().downcast_ref::<InvokeDecisioning>() {
92            Some(e) => e,
93            None => return Ok(vec![]),
94        };
95
96        // Check iteration limit
97        if decisioning_event.context.iteration >= Self::MAX_ITERATIONS {
98            return Ok(vec![Box::new(FailureOccurred {
99                source: "DecisioningAgent".to_string(),
100                correlation_id: decisioning_event.correlation_id.clone(),
101                context: decisioning_event.context.clone(),
102                reason: format!("Maximum iterations ({}) exceeded", Self::MAX_ITERATIONS),
103            }) as Box<dyn Event>]);
104        }
105
106        // Increment iteration counter
107        let mut updated_context = decisioning_event.context.clone();
108        updated_context.iteration += 1;
109
110        let prompt = self.prompt(decisioning_event);
111        println!("\n{}\n{}\n{}\n", "=".repeat(80), prompt, "=".repeat(80));
112
113        // Generate decision using structured output
114        let decision = match self
115            .llm
116            .generate_object::<DecisionResponse>(
117                &[LlmMessage {
118                    role: MessageRole::User,
119                    content: Some(prompt),
120                    tool_calls: None,
121                    image_paths: None,
122                }],
123                None,
124                decisioning_event.correlation_id.clone(),
125            )
126            .await
127        {
128            Ok(d) => d,
129            Err(e) => {
130                return Ok(vec![Box::new(FailureOccurred {
131                    source: "DecisioningAgent".to_string(),
132                    correlation_id: decisioning_event.correlation_id.clone(),
133                    context: updated_context,
134                    reason: format!("Error during decision making: {}", e),
135                }) as Box<dyn Event>]);
136            }
137        };
138
139        println!("\n{}\nDecision: {:?}\n{}\n", "=".repeat(80), decision, "=".repeat(80));
140
141        // Route based on decision
142        match decision.next_action {
143            NextAction::Finish => Ok(vec![Box::new(FinishAndSummarize {
144                source: "DecisioningAgent".to_string(),
145                correlation_id: decisioning_event.correlation_id.clone(),
146                context: updated_context,
147                thought: decision.thought,
148            }) as Box<dyn Event>]),
149
150            NextAction::Act => {
151                let tool_name = match decision.tool_name {
152                    Some(name) => name,
153                    None => {
154                        return Ok(vec![Box::new(FailureOccurred {
155                            source: "DecisioningAgent".to_string(),
156                            correlation_id: decisioning_event.correlation_id.clone(),
157                            context: updated_context,
158                            reason: "ACT decision made but no tool specified".to_string(),
159                        }) as Box<dyn Event>]);
160                    }
161                };
162
163                // Find the requested tool
164                let tool =
165                    match self.tools.iter().find(|t| t.descriptor().function.name == tool_name) {
166                        Some(t) => t.clone(),
167                        None => {
168                            return Ok(vec![Box::new(FailureOccurred {
169                                source: "DecisioningAgent".to_string(),
170                                correlation_id: decisioning_event.correlation_id.clone(),
171                                context: updated_context,
172                                reason: format!("Tool '{}' not found", tool_name),
173                            }) as Box<dyn Event>]);
174                        }
175                    };
176
177                Ok(vec![Box::new(InvokeToolCall {
178                    source: "DecisioningAgent".to_string(),
179                    correlation_id: decisioning_event.correlation_id.clone(),
180                    context: updated_context,
181                    thought: decision.thought,
182                    action: NextAction::Act,
183                    tool,
184                    tool_arguments: decision.tool_arguments,
185                }) as Box<dyn Event>])
186            }
187
188            NextAction::Plan => Ok(vec![Box::new(InvokeThinking {
189                source: "DecisioningAgent".to_string(),
190                correlation_id: decisioning_event.correlation_id.clone(),
191                context: updated_context,
192            }) as Box<dyn Event>]),
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateways::OllamaGateway;

    use super::super::models::{CurrentContext, Plan};

    /// Build an agent wired to an Ollama-backed broker.
    ///
    /// Every test below returns before the agent would call the LLM.
    fn build_agent() -> DecisioningAgent {
        let gateway = Arc::new(OllamaGateway::new());
        let broker = LlmBroker::new("qwen3:32b", gateway, None);
        DecisioningAgent::new(Arc::new(broker))
    }

    /// Build a decisioning request from the shared test source.
    fn decisioning_request(correlation_id: &str, context: CurrentContext) -> InvokeDecisioning {
        InvokeDecisioning {
            source: "TestSource".to_string(),
            correlation_id: Some(correlation_id.to_string()),
            context,
        }
    }

    /// The prompt carries the user query and the three possible actions.
    #[test]
    fn test_decisioning_agent_prompt_generation() {
        let request = decisioning_request(
            "test-123",
            CurrentContext::new("What is the date tomorrow?"),
        );

        let prompt = build_agent().prompt(&request);

        assert!(prompt.contains("What is the date tomorrow?"));
        assert!(prompt.contains("Decide what to do next"));
        assert!(prompt.contains("PLAN"));
        assert!(prompt.contains("ACT"));
        assert!(prompt.contains("FINISH"));
    }

    /// A plan stored in the context shows up in the prompt.
    #[test]
    fn test_decisioning_agent_with_plan() {
        let mut context = CurrentContext::new("What day is it?");
        context.plan = Plan {
            steps: vec!["Call resolve_date tool".to_string()],
        };
        let request = decisioning_request("test-456", context);

        let prompt = build_agent().prompt(&request);

        assert!(prompt.contains("Current plan:"));
        assert!(prompt.contains("Call resolve_date tool"));
    }

    /// A context already at the iteration limit yields a single failure event.
    #[tokio::test]
    async fn test_decisioning_agent_max_iterations() {
        let mut context = CurrentContext::new("Test query");
        context.iteration = DecisioningAgent::MAX_ITERATIONS;
        let request: Box<dyn Event> = Box::new(decisioning_request("test-max", context));

        let emitted = build_agent().receive_event_async(request).await.unwrap();
        assert_eq!(emitted.len(), 1);

        let failure = emitted[0].as_any().downcast_ref::<FailureOccurred>().unwrap();
        assert!(failure.reason.contains("Maximum iterations"));
    }

    /// Events of any other type are ignored.
    #[tokio::test]
    async fn test_decisioning_agent_ignores_wrong_event_type() {
        let unrelated: Box<dyn Event> = Box::new(InvokeThinking {
            source: "Wrong".to_string(),
            correlation_id: None,
            context: CurrentContext::new("Test"),
        });

        let emitted = build_agent().receive_event_async(unrelated).await.unwrap();
        assert!(emitted.is_empty());
    }
}