1use crate::agents::BaseAsyncAgent;
6use crate::event::Event;
7use crate::llm::tools::simple_date_tool::SimpleDateTool;
8use crate::llm::tools::LlmTool;
9use crate::llm::{LlmBroker, LlmMessage, MessageRole};
10use crate::Result;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use super::events::{
17 FailureOccurred, FinishAndSummarize, InvokeDecisioning, InvokeThinking, InvokeToolCall,
18};
19use super::formatters::{format_available_tools, format_current_context};
20use super::models::NextAction;
21
22#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
24pub struct DecisionResponse {
25 pub thought: String,
27 pub next_action: NextAction,
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub tool_name: Option<String>,
32 #[serde(default)]
34 pub tool_arguments: HashMap<String, serde_json::Value>,
35}
36
37pub struct DecisioningAgent {
42 llm: Arc<LlmBroker>,
43 tools: Vec<Arc<dyn LlmTool>>,
44}
45
46impl DecisioningAgent {
47 const MAX_ITERATIONS: usize = 10;
49
50 pub fn new(llm: Arc<LlmBroker>) -> Self {
56 Self {
57 llm,
58 tools: vec![Arc::new(SimpleDateTool)],
59 }
60 }
61
62 fn prompt(&self, event: &InvokeDecisioning) -> String {
64 let tools_list: Vec<&dyn LlmTool> = self.tools.iter().map(|t| t.as_ref()).collect();
65
66 format!(
67 "You are to solve a problem by reasoning and acting on the information you have. Here is the current context:
68
69{}
70{}
71
72Your Instructions:
73Review the current plan and history. Decide what to do next:
74
751. PLAN - If the plan is incomplete or needs refinement
762. ACT - If you should take an action using one of the available tools
773. FINISH - If you have enough information to answer the user's query
78
79If you choose ACT, specify which tool to use and what arguments to pass.
80Think carefully about whether each step in the plan has been completed.",
81 format_current_context(&event.context),
82 format_available_tools(&tools_list)
83 )
84 }
85}
86
87#[async_trait]
88impl BaseAsyncAgent for DecisioningAgent {
89 async fn receive_event_async(&self, event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
90 let decisioning_event = match event.as_any().downcast_ref::<InvokeDecisioning>() {
92 Some(e) => e,
93 None => return Ok(vec![]),
94 };
95
96 if decisioning_event.context.iteration >= Self::MAX_ITERATIONS {
98 return Ok(vec![Box::new(FailureOccurred {
99 source: "DecisioningAgent".to_string(),
100 correlation_id: decisioning_event.correlation_id.clone(),
101 context: decisioning_event.context.clone(),
102 reason: format!("Maximum iterations ({}) exceeded", Self::MAX_ITERATIONS),
103 }) as Box<dyn Event>]);
104 }
105
106 let mut updated_context = decisioning_event.context.clone();
108 updated_context.iteration += 1;
109
110 let prompt = self.prompt(decisioning_event);
111 println!("\n{}\n{}\n{}\n", "=".repeat(80), prompt, "=".repeat(80));
112
113 let decision = match self
115 .llm
116 .generate_object::<DecisionResponse>(
117 &[LlmMessage {
118 role: MessageRole::User,
119 content: Some(prompt),
120 tool_calls: None,
121 image_paths: None,
122 }],
123 None,
124 decisioning_event.correlation_id.clone(),
125 )
126 .await
127 {
128 Ok(d) => d,
129 Err(e) => {
130 return Ok(vec![Box::new(FailureOccurred {
131 source: "DecisioningAgent".to_string(),
132 correlation_id: decisioning_event.correlation_id.clone(),
133 context: updated_context,
134 reason: format!("Error during decision making: {}", e),
135 }) as Box<dyn Event>]);
136 }
137 };
138
139 println!("\n{}\nDecision: {:?}\n{}\n", "=".repeat(80), decision, "=".repeat(80));
140
141 match decision.next_action {
143 NextAction::Finish => Ok(vec![Box::new(FinishAndSummarize {
144 source: "DecisioningAgent".to_string(),
145 correlation_id: decisioning_event.correlation_id.clone(),
146 context: updated_context,
147 thought: decision.thought,
148 }) as Box<dyn Event>]),
149
150 NextAction::Act => {
151 let tool_name = match decision.tool_name {
152 Some(name) => name,
153 None => {
154 return Ok(vec![Box::new(FailureOccurred {
155 source: "DecisioningAgent".to_string(),
156 correlation_id: decisioning_event.correlation_id.clone(),
157 context: updated_context,
158 reason: "ACT decision made but no tool specified".to_string(),
159 }) as Box<dyn Event>]);
160 }
161 };
162
163 let tool =
165 match self.tools.iter().find(|t| t.descriptor().function.name == tool_name) {
166 Some(t) => t.clone(),
167 None => {
168 return Ok(vec![Box::new(FailureOccurred {
169 source: "DecisioningAgent".to_string(),
170 correlation_id: decisioning_event.correlation_id.clone(),
171 context: updated_context,
172 reason: format!("Tool '{}' not found", tool_name),
173 }) as Box<dyn Event>]);
174 }
175 };
176
177 Ok(vec![Box::new(InvokeToolCall {
178 source: "DecisioningAgent".to_string(),
179 correlation_id: decisioning_event.correlation_id.clone(),
180 context: updated_context,
181 thought: decision.thought,
182 action: NextAction::Act,
183 tool,
184 tool_arguments: decision.tool_arguments,
185 }) as Box<dyn Event>])
186 }
187
188 NextAction::Plan => Ok(vec![Box::new(InvokeThinking {
189 source: "DecisioningAgent".to_string(),
190 correlation_id: decisioning_event.correlation_id.clone(),
191 context: updated_context,
192 }) as Box<dyn Event>]),
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateways::OllamaGateway;

    use super::super::models::{CurrentContext, Plan};

    /// Builds an agent backed by an Ollama broker; no test below sends a
    /// request through it.
    fn agent() -> DecisioningAgent {
        let gateway = Arc::new(OllamaGateway::new());
        let broker = Arc::new(LlmBroker::new("qwen3:32b", gateway, None));
        DecisioningAgent::new(broker)
    }

    /// Wraps `context` in an `InvokeDecisioning` event from "TestSource".
    fn decisioning_event(correlation_id: &str, context: CurrentContext) -> InvokeDecisioning {
        InvokeDecisioning {
            source: "TestSource".to_string(),
            correlation_id: Some(correlation_id.to_string()),
            context,
        }
    }

    #[test]
    fn test_decisioning_agent_prompt_generation() {
        let event =
            decisioning_event("test-123", CurrentContext::new("What is the date tomorrow?"));
        let prompt = agent().prompt(&event);

        let expected_fragments = [
            "What is the date tomorrow?",
            "Decide what to do next",
            "PLAN",
            "ACT",
            "FINISH",
        ];
        for fragment in expected_fragments {
            assert!(prompt.contains(fragment), "prompt is missing {:?}", fragment);
        }
    }

    #[test]
    fn test_decisioning_agent_with_plan() {
        let mut context = CurrentContext::new("What day is it?");
        context.plan = Plan {
            steps: vec!["Call resolve_date tool".to_string()],
        };

        let prompt = agent().prompt(&decisioning_event("test-456", context));

        assert!(prompt.contains("Current plan:"));
        assert!(prompt.contains("Call resolve_date tool"));
    }

    #[tokio::test]
    async fn test_decisioning_agent_max_iterations() {
        let mut context = CurrentContext::new("Test query");
        context.iteration = DecisioningAgent::MAX_ITERATIONS;
        let event: Box<dyn Event> = Box::new(decisioning_event("test-max", context));

        let emitted = agent().receive_event_async(event).await.unwrap();
        assert_eq!(emitted.len(), 1);

        let failure = emitted[0]
            .as_any()
            .downcast_ref::<FailureOccurred>()
            .expect("expected a FailureOccurred event");
        assert!(failure.reason.contains("Maximum iterations"));
    }

    #[tokio::test]
    async fn test_decisioning_agent_ignores_wrong_event_type() {
        let wrong_event: Box<dyn Event> = Box::new(InvokeThinking {
            source: "Wrong".to_string(),
            correlation_id: None,
            context: CurrentContext::new("Test"),
        });

        let emitted = agent().receive_event_async(wrong_event).await.unwrap();
        assert!(emitted.is_empty());
    }
}