mojentic/agents/async_aggregator_agent.rs

1//! Async event aggregator agent implementation.
2//!
3//! This module provides an agent that aggregates events by correlation ID,
4//! waiting for all required event types before processing them together.
5
6use crate::agents::BaseAsyncAgent;
7use crate::event::Event;
8use crate::{MojenticError, Result};
9use async_trait::async_trait;
10use std::any::TypeId;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{oneshot, Mutex};
15use tracing::debug;
16
17type EventStore = Arc<Mutex<HashMap<String, Vec<Box<dyn Event>>>>>;
18type WaiterStore = Arc<Mutex<HashMap<String, Vec<oneshot::Sender<Vec<Box<dyn Event>>>>>>>;
19
20/// An agent that aggregates events by correlation ID.
21///
22/// This agent waits for all specified event types to arrive for a given
23/// correlation ID before processing them together. This is useful for
24/// workflows where multiple independent operations must complete before
25/// a final action can be taken.
26///
27/// # Examples
28///
29/// ```ignore
30/// use mojentic::agents::AsyncAggregatorAgent;
31/// use std::any::TypeId;
32///
33/// let agent = AsyncAggregatorAgent::new(vec![
34///     TypeId::of::<Event1>(),
35///     TypeId::of::<Event2>(),
36/// ]);
37/// ```
38pub struct AsyncAggregatorAgent {
39    event_types_needed: Vec<TypeId>,
40    results: EventStore,
41    waiters: WaiterStore,
42}
43
44impl AsyncAggregatorAgent {
45    /// Create a new AsyncAggregatorAgent.
46    ///
47    /// # Arguments
48    ///
49    /// * `event_types_needed` - Vector of TypeIds representing the event types
50    ///   that must be collected before processing
51    pub fn new(event_types_needed: Vec<TypeId>) -> Self {
52        Self {
53            event_types_needed,
54            results: Arc::new(Mutex::new(HashMap::new())),
55            waiters: Arc::new(Mutex::new(HashMap::new())),
56        }
57    }
58
59    /// Wait for all needed events for a specific correlation ID.
60    ///
61    /// This method blocks until all required event types have been received
62    /// for the given correlation ID, or until the timeout expires.
63    ///
64    /// # Arguments
65    ///
66    /// * `correlation_id` - The correlation ID to wait for
67    /// * `timeout` - Optional timeout duration
68    ///
69    /// # Returns
70    ///
71    /// Vector of all events collected for this correlation ID
72    pub async fn wait_for_events(
73        &self,
74        correlation_id: &str,
75        timeout: Option<Duration>,
76    ) -> Result<Vec<Box<dyn Event>>> {
77        // Check if we already have all needed events
78        {
79            let results = self.results.lock().await;
80            if let Some(events) = results.get(correlation_id) {
81                if self.has_all_needed_types(events) {
82                    debug!(
83                        "All needed events already available for correlation_id: {}",
84                        correlation_id
85                    );
86                    return Ok(events.iter().map(|e| e.clone_box()).collect());
87                }
88            }
89        }
90
91        // Create a oneshot channel to wait for events
92        let (tx, rx) = oneshot::channel();
93
94        // Register the waiter
95        {
96            let mut waiters = self.waiters.lock().await;
97            waiters.entry(correlation_id.to_string()).or_default().push(tx);
98        }
99
100        // Wait for the events with optional timeout
101        if let Some(timeout_duration) = timeout {
102            match tokio::time::timeout(timeout_duration, rx).await {
103                Ok(Ok(events)) => Ok(events),
104                Ok(Err(_)) => Err(MojenticError::EventError(
105                    "Channel closed before events arrived".to_string(),
106                )),
107                Err(_) => {
108                    debug!("Timeout waiting for events for correlation_id: {}", correlation_id);
109                    // Return whatever we have collected so far
110                    Err(MojenticError::TimeoutError(format!(
111                        "Timeout waiting for events for correlation_id: {}",
112                        correlation_id
113                    )))
114                }
115            }
116        } else {
117            rx.await.map_err(|_| {
118                MojenticError::EventError("Channel closed before events arrived".to_string())
119            })
120        }
121    }
122
123    /// Process collected events.
124    ///
125    /// This method is called when all needed event types have been collected.
126    /// Override this in subclasses to implement custom processing logic.
127    ///
128    /// # Arguments
129    ///
130    /// * `events` - All collected events for a correlation ID
131    ///
132    /// # Returns
133    ///
134    /// Vector of new events to emit
135    pub async fn process_events(
136        &self,
137        _events: Vec<Box<dyn Event>>,
138    ) -> Result<Vec<Box<dyn Event>>> {
139        // Default implementation returns empty
140        // Subclasses should override this
141        Ok(vec![])
142    }
143
144    /// Check if we have all needed event types.
145    fn has_all_needed_types(&self, events: &[Box<dyn Event>]) -> bool {
146        let event_types: Vec<TypeId> = events.iter().map(|e| e.as_any().type_id()).collect();
147
148        self.event_types_needed
149            .iter()
150            .all(|needed_type| event_types.contains(needed_type))
151    }
152
153    /// Capture an event and check if we have all needed types.
154    async fn capture_event(&self, event: Box<dyn Event>) -> Result<Option<Vec<Box<dyn Event>>>> {
155        let correlation_id = event
156            .correlation_id()
157            .ok_or_else(|| MojenticError::EventError("Event missing correlation_id".to_string()))?
158            .to_string();
159
160        // Add event to results
161        {
162            let mut results = self.results.lock().await;
163            results.entry(correlation_id.clone()).or_default().push(event);
164        }
165
166        // Check if we have all needed events
167        let all_events: Option<Vec<Box<dyn Event>>> = {
168            let results = self.results.lock().await;
169            results
170                .get(&correlation_id)
171                .map(|events| events.iter().map(|e| e.clone_box()).collect())
172        };
173
174        if let Some(events) = all_events {
175            if self.has_all_needed_types(&events) {
176                debug!("All needed events collected for correlation_id: {}", correlation_id);
177
178                // Notify all waiters
179                {
180                    let mut waiters = self.waiters.lock().await;
181                    if let Some(senders) = waiters.remove(&correlation_id) {
182                        for sender in senders {
183                            let events_for_waiter: Vec<Box<dyn Event>> =
184                                events.iter().map(|e| e.clone_box()).collect();
185                            let _ = sender.send(events_for_waiter);
186                        }
187                    }
188                }
189
190                // Clear results for this correlation_id
191                {
192                    let mut results = self.results.lock().await;
193                    results.remove(&correlation_id);
194                }
195
196                return Ok(Some(events));
197            }
198        }
199
200        Ok(None)
201    }
202}
203
204#[async_trait]
205impl BaseAsyncAgent for AsyncAggregatorAgent {
206    async fn receive_event_async(&self, event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
207        debug!("AsyncAggregatorAgent received event");
208
209        // Capture the event
210        if let Some(events) = self.capture_event(event).await? {
211            // We have all needed events, process them
212            return self.process_events(events).await;
213        }
214
215        // Still waiting for more events
216        Ok(vec![])
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::any::Any;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Event1 {
        source: String,
        correlation_id: Option<String>,
        data: String,
    }

    impl Event for Event1 {
        fn source(&self) -> &str {
            &self.source
        }
        fn correlation_id(&self) -> Option<&str> {
            self.correlation_id.as_deref()
        }
        fn set_correlation_id(&mut self, id: String) {
            self.correlation_id = Some(id);
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn clone_box(&self) -> Box<dyn Event> {
            Box::new(self.clone())
        }
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Event2 {
        source: String,
        correlation_id: Option<String>,
        value: i32,
    }

    impl Event for Event2 {
        fn source(&self) -> &str {
            &self.source
        }
        fn correlation_id(&self) -> Option<&str> {
            self.correlation_id.as_deref()
        }
        fn set_correlation_id(&mut self, id: String) {
            self.correlation_id = Some(id);
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn clone_box(&self) -> Box<dyn Event> {
            Box::new(self.clone())
        }
    }

    /// Boxed `Event1` from source "Test"; `None` yields an event without a correlation ID.
    fn event1(correlation_id: Option<&str>, data: &str) -> Box<dyn Event> {
        Box::new(Event1 {
            source: "Test".to_string(),
            correlation_id: correlation_id.map(str::to_string),
            data: data.to_string(),
        })
    }

    /// Boxed `Event2` from source "Test" with the given correlation ID.
    fn event2(correlation_id: &str, value: i32) -> Box<dyn Event> {
        Box::new(Event2 {
            source: "Test".to_string(),
            correlation_id: Some(correlation_id.to_string()),
            value,
        })
    }

    /// Aggregator that needs one `Event1` and one `Event2` per correlation ID.
    fn two_type_agent() -> AsyncAggregatorAgent {
        AsyncAggregatorAgent::new(vec![TypeId::of::<Event1>(), TypeId::of::<Event2>()])
    }

    #[tokio::test]
    async fn test_new_aggregator() {
        let agent = two_type_agent();
        assert_eq!(agent.event_types_needed.len(), 2);
    }

    #[tokio::test]
    async fn test_single_event_does_not_trigger() {
        let agent = two_type_agent();

        let result = agent
            .receive_event_async(event1(Some("test-123"), "data"))
            .await
            .unwrap();
        assert!(result.is_empty()); // Should not process yet

        // The partial set stays buffered under its correlation ID.
        assert!(agent.results.lock().await.contains_key("test-123"));
    }

    #[tokio::test]
    async fn test_both_events_trigger_processing() {
        let agent = two_type_agent();

        // First event should not trigger
        let first = agent
            .receive_event_async(event1(Some("test-123"), "data"))
            .await
            .unwrap();
        assert!(first.is_empty());

        // Second event completes the set. The default process_events emits
        // nothing, so observe completion through the cleared buffer instead.
        let second = agent.receive_event_async(event2("test-123", 42)).await.unwrap();
        assert!(second.is_empty());
        assert!(!agent.results.lock().await.contains_key("test-123"));
    }

    #[tokio::test]
    async fn test_capture_event_returns_complete_set() {
        let agent = two_type_agent();

        let partial = agent.capture_event(event1(Some("set-1"), "data")).await.unwrap();
        assert!(partial.is_none());

        let set = agent
            .capture_event(event2("set-1", 7))
            .await
            .unwrap()
            .expect("set should be complete after both event types arrive");
        assert_eq!(set.len(), 2);
        assert!(set.iter().any(|e| e.as_any().is::<Event1>()));
        assert!(set.iter().any(|e| e.as_any().is::<Event2>()));
    }

    #[tokio::test]
    async fn test_wait_for_events() {
        let agent = Arc::new(two_type_agent());
        let correlation_id = "wait-test-456";
        let producer_agent = Arc::clone(&agent);

        // Deliver the events from a separate task after the waiter has registered.
        let producer = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            producer_agent
                .receive_event_async(event1(Some(correlation_id), "data"))
                .await
                .unwrap();

            tokio::time::sleep(Duration::from_millis(100)).await;
            producer_agent
                .receive_event_async(event2(correlation_id, 42))
                .await
                .unwrap();
        });

        let result = agent
            .wait_for_events(correlation_id, Some(Duration::from_secs(5)))
            .await
            .unwrap();
        assert_eq!(result.len(), 2);

        // Surface any panic from the producer task instead of silently dropping it.
        producer.await.unwrap();
    }

    #[tokio::test]
    async fn test_wait_for_events_timeout() {
        let agent = two_type_agent();

        // Send only one of the two needed event types.
        agent
            .receive_event_async(event1(Some("timeout-test"), "data"))
            .await
            .unwrap();

        let result = agent
            .wait_for_events("timeout-test", Some(Duration::from_millis(100)))
            .await;

        assert!(
            matches!(result, Err(MojenticError::TimeoutError(_))),
            "Expected TimeoutError"
        );
    }

    #[tokio::test]
    async fn test_different_correlation_ids() {
        let agent = two_type_agent();

        agent
            .receive_event_async(event1(Some("corr-a"), "data-a"))
            .await
            .unwrap();
        agent
            .receive_event_async(event1(Some("corr-b"), "data-b"))
            .await
            .unwrap();

        // Completing corr-a must not disturb the partial set buffered for corr-b.
        let result = agent.receive_event_async(event2("corr-a", 1)).await.unwrap();
        assert!(result.is_empty());
        {
            let results = agent.results.lock().await;
            assert!(!results.contains_key("corr-a"));
            assert!(results.contains_key("corr-b"));
        }

        // Complete corr-b
        let result = agent.receive_event_async(event2("corr-b", 2)).await.unwrap();
        assert!(result.is_empty());
        assert!(agent.results.lock().await.is_empty());
    }

    #[tokio::test]
    async fn test_event_without_correlation_id_fails() {
        let agent = AsyncAggregatorAgent::new(vec![TypeId::of::<Event1>()]);

        let result = agent.receive_event_async(event1(None, "data")).await;

        assert!(
            matches!(result, Err(MojenticError::EventError(_))),
            "Expected EventError"
        );
    }

    #[tokio::test]
    async fn test_process_events_override() {
        /// Composition-based customization: wraps the aggregator and counts completed sets.
        struct CustomAggregator {
            inner: AsyncAggregatorAgent,
            processed_count: Arc<Mutex<usize>>,
        }

        impl CustomAggregator {
            fn new(event_types: Vec<TypeId>) -> Self {
                Self {
                    inner: AsyncAggregatorAgent::new(event_types),
                    processed_count: Arc::new(Mutex::new(0)),
                }
            }
        }

        #[async_trait]
        impl BaseAsyncAgent for CustomAggregator {
            async fn receive_event_async(
                &self,
                event: Box<dyn Event>,
            ) -> Result<Vec<Box<dyn Event>>> {
                if let Some(_events) = self.inner.capture_event(event).await? {
                    // Custom processing
                    *self.processed_count.lock().await += 1;
                }
                Ok(vec![])
            }
        }

        let agent = CustomAggregator::new(vec![TypeId::of::<Event1>(), TypeId::of::<Event2>()]);
        let count_clone = Arc::clone(&agent.processed_count);

        agent
            .receive_event_async(event1(Some("custom-test"), "data"))
            .await
            .unwrap();
        agent
            .receive_event_async(event2("custom-test", 42))
            .await
            .unwrap();

        let count = *count_clone.lock().await;
        assert_eq!(count, 1); // the custom processing should have run exactly once
    }
}