mojentic/examples/react/tool_call_agent.rs

1//! Tool execution agent for the ReAct pattern.
2//!
3//! This agent handles the actual execution of tools and captures the results.
4
5use crate::agents::BaseAsyncAgent;
6use crate::event::Event;
7use crate::Result;
8use async_trait::async_trait;
9
10use super::events::{FailureOccurred, InvokeDecisioning, InvokeToolCall};
11use super::models::ThoughtActionObservation;
12
/// Agent responsible for executing tool calls.
///
/// This agent receives [`InvokeToolCall`] events, runs the requested tool,
/// and records the outcome in the ReAct history before handing control back
/// to the decisioning phase via [`InvokeDecisioning`]. If the tool reports an
/// error, a [`FailureOccurred`] event is emitted instead.
///
/// The agent carries no state of its own.
#[derive(Debug, Clone)]
pub struct ToolCallAgent;
19
20impl ToolCallAgent {
21    /// Create a new tool call agent.
22    pub fn new() -> Self {
23        Self
24    }
25}
26
27impl Default for ToolCallAgent {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33#[async_trait]
34impl BaseAsyncAgent for ToolCallAgent {
35    async fn receive_event_async(&self, event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
36        // Downcast to InvokeToolCall
37        let tool_call_event = match event.as_any().downcast_ref::<InvokeToolCall>() {
38            Some(e) => e,
39            None => return Ok(vec![]),
40        };
41
42        let tool = &tool_call_event.tool;
43        let tool_name = tool.descriptor().function.name.clone();
44        let arguments = &tool_call_event.tool_arguments;
45
46        println!("\nExecuting tool: {}", tool_name);
47        println!("Arguments: {:?}", arguments);
48
49        // Execute the tool
50        let result = match tool.run(arguments) {
51            Ok(r) => r,
52            Err(e) => {
53                eprintln!("Tool execution error: {}", e);
54                return Ok(vec![Box::new(FailureOccurred {
55                    source: "ToolCallAgent".to_string(),
56                    correlation_id: tool_call_event.correlation_id.clone(),
57                    context: tool_call_event.context.clone(),
58                    reason: format!("Tool execution failed: {}", e),
59                }) as Box<dyn Event>]);
60            }
61        };
62
63        println!("Result: {:?}", result);
64
65        // Extract text content from result
66        let result_text = if result.is_object() {
67            // If it's an object, try to get a "summary" field, otherwise use the whole JSON
68            result
69                .get("summary")
70                .and_then(|v| v.as_str())
71                .map(|s| s.to_string())
72                .unwrap_or_else(|| result.to_string())
73        } else {
74            result.to_string()
75        };
76
77        // Update context with observation
78        let mut updated_context = tool_call_event.context.clone();
79        updated_context.history.push(ThoughtActionObservation {
80            thought: tool_call_event.thought.clone(),
81            action: format!("Called {} with {:?}", tool_name, tool_call_event.tool_arguments),
82            observation: result_text,
83        });
84
85        // Continue to decisioning
86        Ok(vec![Box::new(InvokeDecisioning {
87            source: "ToolCallAgent".to_string(),
88            correlation_id: tool_call_event.correlation_id.clone(),
89            context: updated_context,
90        }) as Box<dyn Event>])
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::tools::simple_date_tool::SimpleDateTool;
    use crate::llm::tools::LlmTool;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::super::events::InvokeToolCall;
    use super::super::models::{CurrentContext, NextAction};

    /// A successful tool call yields exactly one `InvokeDecisioning` event.
    /// That event's context gains one history entry describing the call.
    #[tokio::test]
    async fn test_tool_call_agent_successful_execution() {
        let agent = ToolCallAgent::new();
        let tool: Arc<dyn LlmTool> = Arc::new(SimpleDateTool);

        let mut args = HashMap::new();
        args.insert("relative_date".to_string(), json!("tomorrow"));

        let context = CurrentContext::new("What is the date tomorrow?");

        let event = Box::new(InvokeToolCall {
            source: "TestSource".to_string(),
            correlation_id: Some("test-123".to_string()),
            context,
            thought: "I need to resolve the date".to_string(),
            action: NextAction::Act,
            tool,
            tool_arguments: args,
        }) as Box<dyn Event>;

        let result = agent.receive_event_async(event).await.unwrap();
        assert_eq!(result.len(), 1);

        let decisioning = result[0]
            .as_any()
            .downcast_ref::<InvokeDecisioning>()
            .expect("expected an InvokeDecisioning event");

        // Routing metadata is carried through from the incoming event.
        assert_eq!(decisioning.source, "ToolCallAgent");
        assert_eq!(decisioning.correlation_id.as_deref(), Some("test-123"));

        assert_eq!(decisioning.context.history.len(), 1);
        let step = &decisioning.context.history[0];
        assert_eq!(step.thought, "I need to resolve the date");
        assert!(step.action.starts_with("Called "));
        assert!(step.action.contains("relative_date"));
        assert!(step.observation.contains("tomorrow"));
    }

    /// Events other than `InvokeToolCall` are ignored.
    #[tokio::test]
    async fn test_tool_call_agent_ignores_wrong_event_type() {
        let agent = ToolCallAgent::new();

        let wrong_event = Box::new(InvokeDecisioning {
            source: "Wrong".to_string(),
            correlation_id: None,
            context: CurrentContext::new("Test"),
        }) as Box<dyn Event>;

        let result = agent.receive_event_async(wrong_event).await.unwrap();
        assert!(result.is_empty());
    }

    /// Both construction paths, `new` and `Default`, are available.
    #[test]
    fn test_tool_call_agent_default() {
        let _agent1 = ToolCallAgent::new();
        let _agent2 = ToolCallAgent::default();
    }
}