1use crate::agents::BaseAsyncAgent;
7use crate::event::Event;
8use crate::{MojenticError, Result};
9use async_trait::async_trait;
10use std::any::TypeId;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{oneshot, Mutex};
15use tracing::debug;
16
17type EventStore = Arc<Mutex<HashMap<String, Vec<Box<dyn Event>>>>>;
18type WaiterStore = Arc<Mutex<HashMap<String, Vec<oneshot::Sender<Vec<Box<dyn Event>>>>>>>;
19
20pub struct AsyncAggregatorAgent {
39 event_types_needed: Vec<TypeId>,
40 results: EventStore,
41 waiters: WaiterStore,
42}
43
44impl AsyncAggregatorAgent {
45 pub fn new(event_types_needed: Vec<TypeId>) -> Self {
52 Self {
53 event_types_needed,
54 results: Arc::new(Mutex::new(HashMap::new())),
55 waiters: Arc::new(Mutex::new(HashMap::new())),
56 }
57 }
58
59 pub async fn wait_for_events(
73 &self,
74 correlation_id: &str,
75 timeout: Option<Duration>,
76 ) -> Result<Vec<Box<dyn Event>>> {
77 {
79 let results = self.results.lock().await;
80 if let Some(events) = results.get(correlation_id) {
81 if self.has_all_needed_types(events) {
82 debug!(
83 "All needed events already available for correlation_id: {}",
84 correlation_id
85 );
86 return Ok(events.iter().map(|e| e.clone_box()).collect());
87 }
88 }
89 }
90
91 let (tx, rx) = oneshot::channel();
93
94 {
96 let mut waiters = self.waiters.lock().await;
97 waiters.entry(correlation_id.to_string()).or_default().push(tx);
98 }
99
100 if let Some(timeout_duration) = timeout {
102 match tokio::time::timeout(timeout_duration, rx).await {
103 Ok(Ok(events)) => Ok(events),
104 Ok(Err(_)) => Err(MojenticError::EventError(
105 "Channel closed before events arrived".to_string(),
106 )),
107 Err(_) => {
108 debug!("Timeout waiting for events for correlation_id: {}", correlation_id);
109 Err(MojenticError::TimeoutError(format!(
111 "Timeout waiting for events for correlation_id: {}",
112 correlation_id
113 )))
114 }
115 }
116 } else {
117 rx.await.map_err(|_| {
118 MojenticError::EventError("Channel closed before events arrived".to_string())
119 })
120 }
121 }
122
123 pub async fn process_events(
136 &self,
137 _events: Vec<Box<dyn Event>>,
138 ) -> Result<Vec<Box<dyn Event>>> {
139 Ok(vec![])
142 }
143
144 fn has_all_needed_types(&self, events: &[Box<dyn Event>]) -> bool {
146 let event_types: Vec<TypeId> = events.iter().map(|e| e.as_any().type_id()).collect();
147
148 self.event_types_needed
149 .iter()
150 .all(|needed_type| event_types.contains(needed_type))
151 }
152
153 async fn capture_event(&self, event: Box<dyn Event>) -> Result<Option<Vec<Box<dyn Event>>>> {
155 let correlation_id = event
156 .correlation_id()
157 .ok_or_else(|| MojenticError::EventError("Event missing correlation_id".to_string()))?
158 .to_string();
159
160 {
162 let mut results = self.results.lock().await;
163 results.entry(correlation_id.clone()).or_default().push(event);
164 }
165
166 let all_events: Option<Vec<Box<dyn Event>>> = {
168 let results = self.results.lock().await;
169 results
170 .get(&correlation_id)
171 .map(|events| events.iter().map(|e| e.clone_box()).collect())
172 };
173
174 if let Some(events) = all_events {
175 if self.has_all_needed_types(&events) {
176 debug!("All needed events collected for correlation_id: {}", correlation_id);
177
178 {
180 let mut waiters = self.waiters.lock().await;
181 if let Some(senders) = waiters.remove(&correlation_id) {
182 for sender in senders {
183 let events_for_waiter: Vec<Box<dyn Event>> =
184 events.iter().map(|e| e.clone_box()).collect();
185 let _ = sender.send(events_for_waiter);
186 }
187 }
188 }
189
190 {
192 let mut results = self.results.lock().await;
193 results.remove(&correlation_id);
194 }
195
196 return Ok(Some(events));
197 }
198 }
199
200 Ok(None)
201 }
202}
203
204#[async_trait]
205impl BaseAsyncAgent for AsyncAggregatorAgent {
206 async fn receive_event_async(&self, event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
207 debug!("AsyncAggregatorAgent received event");
208
209 if let Some(events) = self.capture_event(event).await? {
211 return self.process_events(events).await;
213 }
214
215 Ok(vec![])
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::any::Any;

    /// First of the two event types the aggregators under test wait for.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Event1 {
        source: String,
        correlation_id: Option<String>,
        data: String,
    }

    impl Event for Event1 {
        fn source(&self) -> &str {
            &self.source
        }

        fn correlation_id(&self) -> Option<&str> {
            self.correlation_id.as_deref()
        }

        fn set_correlation_id(&mut self, id: String) {
            self.correlation_id = Some(id);
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn clone_box(&self) -> Box<dyn Event> {
            Box::new(self.clone())
        }
    }

    /// Second of the two event types the aggregators under test wait for.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Event2 {
        source: String,
        correlation_id: Option<String>,
        value: i32,
    }

    impl Event for Event2 {
        fn source(&self) -> &str {
            &self.source
        }

        fn correlation_id(&self) -> Option<&str> {
            self.correlation_id.as_deref()
        }

        fn set_correlation_id(&mut self, id: String) {
            self.correlation_id = Some(id);
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn clone_box(&self) -> Box<dyn Event> {
            Box::new(self.clone())
        }
    }

    /// Builds a boxed `Event1` with source "Test".
    fn event1(correlation_id: Option<&str>, data: &str) -> Box<dyn Event> {
        Box::new(Event1 {
            source: "Test".to_string(),
            correlation_id: correlation_id.map(String::from),
            data: data.to_string(),
        })
    }

    /// Builds a boxed `Event2` with source "Test".
    fn event2(correlation_id: &str, value: i32) -> Box<dyn Event> {
        Box::new(Event2 {
            source: "Test".to_string(),
            correlation_id: Some(correlation_id.to_string()),
            value,
        })
    }

    /// The type set used by most tests: one `Event1` and one `Event2`.
    fn needed_types() -> Vec<TypeId> {
        vec![TypeId::of::<Event1>(), TypeId::of::<Event2>()]
    }

    #[tokio::test]
    async fn test_new_aggregator() {
        let agent = AsyncAggregatorAgent::new(needed_types());
        assert_eq!(agent.event_types_needed.len(), 2);
    }

    #[tokio::test]
    async fn test_single_event_does_not_trigger() {
        let agent = AsyncAggregatorAgent::new(needed_types());

        let emitted = agent
            .receive_event_async(event1(Some("test-123"), "data"))
            .await
            .unwrap();

        // Only one of the two required types has arrived.
        assert!(emitted.is_empty());
    }

    #[tokio::test]
    async fn test_both_events_trigger_processing() {
        let agent = Arc::new(AsyncAggregatorAgent::new(needed_types()));

        let first = event1(Some("test-123"), "data");
        let second = event2("test-123", 42);

        let after_first = agent.receive_event_async(first).await.unwrap();
        assert!(after_first.is_empty());

        // The default `process_events` yields nothing even for a complete set.
        let after_second = agent.receive_event_async(second).await.unwrap();
        assert!(after_second.is_empty());
    }

    #[tokio::test]
    async fn test_wait_for_events() {
        let agent = Arc::new(AsyncAggregatorAgent::new(needed_types()));
        let correlation_id = "wait-test-456";

        // Feed both events from a background task, with a pause before each.
        let producer = Arc::clone(&agent);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            producer
                .receive_event_async(event1(Some(correlation_id), "data"))
                .await
                .unwrap();

            tokio::time::sleep(Duration::from_millis(100)).await;
            producer
                .receive_event_async(event2(correlation_id, 42))
                .await
                .unwrap();
        });

        let collected = agent
            .wait_for_events(correlation_id, Some(Duration::from_secs(5)))
            .await
            .unwrap();

        assert_eq!(collected.len(), 2);
    }

    #[tokio::test]
    async fn test_wait_for_events_timeout() {
        let agent = AsyncAggregatorAgent::new(needed_types());

        // Only half of the required set is ever delivered.
        agent
            .receive_event_async(event1(Some("timeout-test"), "data"))
            .await
            .unwrap();

        let outcome = agent
            .wait_for_events("timeout-test", Some(Duration::from_millis(100)))
            .await;

        assert!(outcome.is_err());
        match outcome {
            Err(MojenticError::TimeoutError(_)) => {}
            _ => panic!("Expected TimeoutError"),
        }
    }

    #[tokio::test]
    async fn test_different_correlation_ids() {
        let agent = Arc::new(AsyncAggregatorAgent::new(needed_types()));

        // Start two independent groups.
        let first_a = event1(Some("corr-a"), "data-a");
        let first_b = event1(Some("corr-b"), "data-b");
        agent.receive_event_async(first_a).await.unwrap();
        agent.receive_event_async(first_b).await.unwrap();

        // Complete group "corr-a".
        let emitted_a = agent
            .receive_event_async(event2("corr-a", 1))
            .await
            .unwrap();
        assert!(emitted_a.is_empty());

        // Complete group "corr-b".
        let emitted_b = agent
            .receive_event_async(event2("corr-b", 2))
            .await
            .unwrap();
        assert!(emitted_b.is_empty());
    }

    #[tokio::test]
    async fn test_event_without_correlation_id_fails() {
        let agent = AsyncAggregatorAgent::new(vec![TypeId::of::<Event1>()]);

        let outcome = agent.receive_event_async(event1(None, "data")).await;

        assert!(outcome.is_err());
        match outcome {
            Err(MojenticError::EventError(_)) => {}
            _ => panic!("Expected EventError"),
        }
    }

    #[tokio::test]
    async fn test_process_events_override() {
        /// Wrapper that counts completed groups instead of delegating to
        /// the inner agent's `process_events`.
        struct CustomAggregator {
            inner: AsyncAggregatorAgent,
            processed_count: Arc<Mutex<usize>>,
        }

        impl CustomAggregator {
            fn new(event_types: Vec<TypeId>) -> Self {
                Self {
                    inner: AsyncAggregatorAgent::new(event_types),
                    processed_count: Arc::new(Mutex::new(0)),
                }
            }
        }

        #[async_trait]
        impl BaseAsyncAgent for CustomAggregator {
            async fn receive_event_async(
                &self,
                event: Box<dyn Event>,
            ) -> Result<Vec<Box<dyn Event>>> {
                if self.inner.capture_event(event).await?.is_some() {
                    *self.processed_count.lock().await += 1;
                }
                Ok(vec![])
            }
        }

        let agent = CustomAggregator::new(needed_types());
        let counter = Arc::clone(&agent.processed_count);

        let first = event1(Some("custom-test"), "data");
        let second = event2("custom-test", 42);

        agent.receive_event_async(first).await.unwrap();
        agent.receive_event_async(second).await.unwrap();

        // Exactly one group was completed.
        let processed = *counter.lock().await;
        assert_eq!(processed, 1);
    }
}