mojentic/agents/
async_llm_agent.rs

//! Async LLM-powered agent implementation.
//!
//! This module provides [`AsyncLlmAgent`], an agent that uses an LLM broker to generate
//! text or structured responses. It supports a system prompt (behaviour), structured
//! output deserialized into a caller-chosen type, and tool calling for text responses.
6
7use crate::agents::BaseAsyncAgent;
8use crate::event::Event;
9use crate::llm::{LlmBroker, LlmMessage, LlmTool};
10use crate::Result;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14
15/// An async agent powered by an LLM.
16///
17/// This agent uses an LLM broker to generate responses. It can be configured
18/// with a system prompt (behaviour), optional tools, and a response model for
19/// structured output.
20///
21/// # Examples
22///
23/// ```ignore
24/// use mojentic::agents::AsyncLlmAgent;
25/// use mojentic::llm::LlmBroker;
26///
27/// let broker = Arc::new(LlmBroker::new("model-name", gateway, None));
28/// let agent = AsyncLlmAgent::new(
29///     broker,
30///     "You are a helpful assistant.",
31///     None, // tools
32/// );
33/// ```
34pub struct AsyncLlmAgent {
35    broker: Arc<LlmBroker>,
36    behaviour: String,
37    tools: Vec<Box<dyn LlmTool>>,
38}
39
40impl AsyncLlmAgent {
41    /// Create a new AsyncLlmAgent.
42    ///
43    /// # Arguments
44    ///
45    /// * `broker` - The LLM broker to use for generating responses
46    /// * `behaviour` - System prompt defining the agent's personality and behavior
47    /// * `tools` - Optional tools available to the LLM
48    pub fn new(
49        broker: Arc<LlmBroker>,
50        behaviour: impl Into<String>,
51        tools: Option<Vec<Box<dyn LlmTool>>>,
52    ) -> Self {
53        Self {
54            broker,
55            behaviour: behaviour.into(),
56            tools: tools.unwrap_or_default(),
57        }
58    }
59
60    /// Add a tool to the agent.
61    ///
62    /// # Arguments
63    ///
64    /// * `tool` - The tool to add
65    pub fn add_tool(&mut self, tool: Box<dyn LlmTool>) {
66        self.tools.push(tool);
67    }
68
69    /// Generate a text response using the LLM.
70    ///
71    /// # Arguments
72    ///
73    /// * `content` - The user message content
74    /// * `correlation_id` - Optional correlation ID for tracing
75    ///
76    /// # Returns
77    ///
78    /// The generated text response
79    pub async fn generate_response(
80        &self,
81        content: &str,
82        correlation_id: Option<String>,
83    ) -> Result<String> {
84        let messages = vec![
85            LlmMessage::system(&self.behaviour),
86            LlmMessage::user(content),
87        ];
88
89        let tools = if self.tools.is_empty() {
90            None
91        } else {
92            Some(self.tools.as_slice())
93        };
94
95        self.broker.generate(&messages, tools, None, correlation_id).await
96    }
97
98    /// Generate a structured object response using the LLM.
99    ///
100    /// # Arguments
101    ///
102    /// * `content` - The user message content
103    /// * `correlation_id` - Optional correlation ID for tracing
104    ///
105    /// # Returns
106    ///
107    /// The generated structured object
108    pub async fn generate_object<T>(
109        &self,
110        content: &str,
111        correlation_id: Option<String>,
112    ) -> Result<T>
113    where
114        T: for<'de> Deserialize<'de> + Serialize + schemars::JsonSchema + Send,
115    {
116        let messages = vec![
117            LlmMessage::system(&self.behaviour),
118            LlmMessage::user(content),
119        ];
120
121        self.broker.generate_object(&messages, None, correlation_id).await
122    }
123}
124
125#[async_trait]
126impl BaseAsyncAgent for AsyncLlmAgent {
127    async fn receive_event_async(&self, _event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
128        // Default implementation returns no events
129        // Subclasses should override this to handle specific event types
130        Ok(vec![])
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{CompletionConfig, StreamChunk};
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use crate::llm::{LlmGateway, LlmGatewayResponse};
    use futures::stream::{self, Stream};
    use serde_json::Value;
    use std::any::Any;
    use std::collections::HashMap;
    use std::pin::Pin;

    /// Gateway stub that answers every request with a fixed response.
    struct MockGateway {
        response: String,
    }

    impl MockGateway {
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            Ok(LlmGatewayResponse {
                content: Some(self.response.clone()),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(serde_json::json!({
                "message": self.response,
                "confidence": 0.95
            }))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content(self.response.clone()))]))
        }
    }

    /// Minimal tool stub shared by the tool-related tests.
    #[derive(Clone)]
    struct MockTool;

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(serde_json::json!({"result": "ok"}))
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: "mock_tool".to_string(),
                    description: "A mock tool".to_string(),
                    parameters: serde_json::json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(self.clone())
        }
    }

    /// Event stub used to exercise `receive_event_async`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestEvent {
        source: String,
        correlation_id: Option<String>,
    }

    impl Event for TestEvent {
        fn source(&self) -> &str {
            &self.source
        }

        fn correlation_id(&self) -> Option<&str> {
            self.correlation_id.as_deref()
        }

        fn set_correlation_id(&mut self, id: String) {
            self.correlation_id = Some(id);
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn clone_box(&self) -> Box<dyn Event> {
            Box::new(self.clone())
        }
    }

    /// Build a broker whose gateway always replies with `response`.
    fn mock_broker(response: &str) -> Arc<LlmBroker> {
        let gateway = Arc::new(MockGateway::new(response));
        Arc::new(LlmBroker::new("test-model", gateway, None))
    }

    #[tokio::test]
    async fn test_new_agent() {
        let agent = AsyncLlmAgent::new(mock_broker("test response"), "You are helpful", None);

        assert_eq!(agent.behaviour, "You are helpful");
        assert!(agent.tools.is_empty());
    }

    #[tokio::test]
    async fn test_new_agent_with_tools() {
        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool)];
        let agent = AsyncLlmAgent::new(mock_broker("test"), "You are helpful", Some(tools));

        assert_eq!(agent.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_add_tool() {
        let mut agent = AsyncLlmAgent::new(mock_broker("test"), "You are helpful", None);

        assert!(agent.tools.is_empty());
        agent.add_tool(Box::new(MockTool));
        assert_eq!(agent.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_generate_response() {
        let agent = AsyncLlmAgent::new(mock_broker("Hello from LLM"), "You are helpful", None);

        let response = agent.generate_response("Test message", None).await.unwrap();

        assert_eq!(response, "Hello from LLM");
    }

    #[tokio::test]
    async fn test_generate_response_with_tools() {
        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool)];
        let agent =
            AsyncLlmAgent::new(mock_broker("Tool-aware response"), "You are helpful", Some(tools));

        let response = agent.generate_response("Test", None).await.unwrap();

        assert_eq!(response, "Tool-aware response");
    }

    #[tokio::test]
    async fn test_generate_response_with_correlation_id() {
        let agent = AsyncLlmAgent::new(mock_broker("Response"), "You are helpful", None);

        let response = agent
            .generate_response("Test", Some("correlation-123".to_string()))
            .await
            .unwrap();

        assert_eq!(response, "Response");
    }

    #[tokio::test]
    async fn test_generate_object() {
        #[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)]
        struct TestResponse {
            message: String,
            confidence: f64,
        }

        let agent = AsyncLlmAgent::new(mock_broker("Test message"), "You are helpful", None);

        let response: TestResponse = agent.generate_object("Generate object", None).await.unwrap();

        assert_eq!(response.message, "Test message");
        assert_eq!(response.confidence, 0.95);
    }

    #[tokio::test]
    async fn test_receive_event_async_default() {
        let agent = AsyncLlmAgent::new(mock_broker("test"), "You are helpful", None);

        let event = Box::new(TestEvent {
            source: "Test".to_string(),
            correlation_id: None,
        }) as Box<dyn Event>;

        let result = agent.receive_event_async(event).await.unwrap();
        // The built-in implementation ignores events.
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_agent_with_custom_behaviour() {
        let agent = AsyncLlmAgent::new(
            mock_broker("Custom response"),
            "You are a specialized agent with custom behavior",
            None,
        );

        assert_eq!(agent.behaviour, "You are a specialized agent with custom behavior");

        let response = agent.generate_response("Test", None).await.unwrap();
        assert_eq!(response, "Custom response");
    }
}