mojentic/
async_dispatcher.rs1use crate::event::{Event, TerminateEvent};
7use crate::router::Router;
8use crate::{MojenticError, Result};
9use std::collections::VecDeque;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::Mutex;
14use tokio::task::JoinHandle;
15use tracing::{debug, info};
16use uuid::Uuid;
17
18pub struct AsyncDispatcher {
39 router: Arc<Router>,
40 event_queue: Arc<Mutex<VecDeque<Box<dyn Event>>>>,
41 stop_flag: Arc<AtomicBool>,
42 task_handle: Option<JoinHandle<()>>,
43 batch_size: usize,
44}
45
46impl AsyncDispatcher {
47 pub fn new(router: Arc<Router>) -> Self {
53 Self {
54 router,
55 event_queue: Arc::new(Mutex::new(VecDeque::new())),
56 stop_flag: Arc::new(AtomicBool::new(false)),
57 task_handle: None,
58 batch_size: 5,
59 }
60 }
61
62 pub fn with_batch_size(mut self, size: usize) -> Self {
68 self.batch_size = size;
69 self
70 }
71
72 pub async fn start(&mut self) -> Result<()> {
76 if self.task_handle.is_some() {
77 return Err(MojenticError::DispatcherError("Dispatcher already started".to_string()));
78 }
79
80 debug!("Starting async dispatcher");
81 self.stop_flag.store(false, Ordering::Relaxed);
82
83 let router = self.router.clone();
84 let queue = self.event_queue.clone();
85 let stop_flag = self.stop_flag.clone();
86 let batch_size = self.batch_size;
87
88 let handle = tokio::spawn(async move {
89 Self::dispatch_loop(router, queue, stop_flag, batch_size).await;
90 });
91
92 self.task_handle = Some(handle);
93 info!("Async dispatcher started");
94
95 Ok(())
96 }
97
98 pub async fn stop(&mut self) -> Result<()> {
102 if let Some(handle) = self.task_handle.take() {
103 debug!("Stopping async dispatcher");
104 self.stop_flag.store(true, Ordering::Relaxed);
105 handle.await.map_err(|e| {
106 MojenticError::DispatcherError(format!("Failed to stop dispatcher: {}", e))
107 })?;
108 info!("Async dispatcher stopped");
109 }
110
111 Ok(())
112 }
113
114 pub fn dispatch(&self, mut event: Box<dyn Event>) {
122 if event.correlation_id().is_none() {
124 event.set_correlation_id(Uuid::new_v4().to_string());
125 }
126
127 let queue = self.event_queue.clone();
128 tokio::spawn(async move {
129 let mut q = queue.lock().await;
130 debug!("Dispatching event: {:?}", event);
131 q.push_back(event);
132 });
133 }
134
135 pub async fn wait_for_empty_queue(&self, timeout: Option<Duration>) -> Result<bool> {
145 let start = tokio::time::Instant::now();
146
147 loop {
148 let len = {
149 let queue = self.event_queue.lock().await;
150 queue.len()
151 };
152
153 if len == 0 {
154 return Ok(true);
155 }
156
157 if let Some(timeout_duration) = timeout {
158 if start.elapsed() > timeout_duration {
159 return Ok(false);
160 }
161 }
162
163 tokio::time::sleep(Duration::from_millis(100)).await;
164 }
165 }
166
167 pub async fn queue_len(&self) -> usize {
169 let queue = self.event_queue.lock().await;
170 queue.len()
171 }
172
173 async fn dispatch_loop(
175 router: Arc<Router>,
176 queue: Arc<Mutex<VecDeque<Box<dyn Event>>>>,
177 stop_flag: Arc<AtomicBool>,
178 batch_size: usize,
179 ) {
180 while !stop_flag.load(Ordering::Relaxed) {
181 for _ in 0..batch_size {
182 let event = {
183 let mut q = queue.lock().await;
184 q.pop_front()
185 };
186
187 if let Some(event) = event {
188 debug!("Processing event: {:?}", event);
189
190 if event.as_any().is::<TerminateEvent>() {
192 info!("Received TerminateEvent, stopping dispatcher");
193 stop_flag.store(true, Ordering::Relaxed);
194 break;
195 }
196
197 let type_id = event.as_any().type_id();
199
200 let agents = router.get_agents(type_id);
202 debug!("Found {} agents for event type", agents.len());
203
204 for agent in agents {
206 debug!("Sending event to agent");
207 match agent.receive_event_async(event.clone_box()).await {
208 Ok(new_events) => {
209 debug!("Agent returned {} events", new_events.len());
210 let mut q = queue.lock().await;
212 for new_event in new_events {
213 q.push_back(new_event);
214 }
215 }
216 Err(e) => {
217 tracing::error!("Agent error processing event: {}", e);
218 }
219 }
220 }
221 }
222 }
223
224 tokio::time::sleep(Duration::from_millis(100)).await;
225 }
226
227 debug!("Dispatch loop exiting");
228 }
229}
230
231impl Drop for AsyncDispatcher {
232 fn drop(&mut self) {
233 self.stop_flag.store(true, Ordering::Relaxed);
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::BaseAsyncAgent;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use std::any::Any;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestEvent {
        source: String,
        correlation_id: Option<String>,
        data: String,
    }

    impl Event for TestEvent {
        fn source(&self) -> &str {
            &self.source
        }
        fn correlation_id(&self) -> Option<&str> {
            self.correlation_id.as_deref()
        }
        fn set_correlation_id(&mut self, id: String) {
            self.correlation_id = Some(id);
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn clone_box(&self) -> Box<dyn Event> {
            Box::new(self.clone())
        }
    }

    /// Builds a boxed `TestEvent` carrying `correlation_id`.
    fn test_event(correlation_id: Option<&str>) -> Box<dyn Event> {
        Box::new(TestEvent {
            source: "Test".to_string(),
            correlation_id: correlation_id.map(String::from),
            data: "test".to_string(),
        })
    }

    /// Agent that counts how many events it receives.
    struct CountingAgent {
        count: Arc<Mutex<usize>>,
    }

    #[async_trait]
    impl BaseAsyncAgent for CountingAgent {
        async fn receive_event_async(&self, _event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
            let mut count = self.count.lock().await;
            *count += 1;
            Ok(vec![])
        }
    }

    /// Agent that records the correlation id of every event it receives.
    struct RecordingAgent {
        seen: Arc<Mutex<Vec<Option<String>>>>,
    }

    #[async_trait]
    impl BaseAsyncAgent for RecordingAgent {
        async fn receive_event_async(&self, event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
            let id = event.correlation_id().map(String::from);
            self.seen.lock().await.push(id);
            Ok(vec![])
        }
    }

    #[tokio::test]
    async fn test_dispatcher_new() {
        let router = Arc::new(Router::new());
        let dispatcher = AsyncDispatcher::new(router);
        assert!(dispatcher.task_handle.is_none());
        assert_eq!(dispatcher.batch_size, 5);
    }

    #[tokio::test]
    async fn test_dispatcher_with_batch_size() {
        let router = Arc::new(Router::new());
        let dispatcher = AsyncDispatcher::new(router).with_batch_size(10);
        assert_eq!(dispatcher.batch_size, 10);
    }

    #[tokio::test]
    async fn test_start_and_stop() {
        let router = Arc::new(Router::new());
        let mut dispatcher = AsyncDispatcher::new(router);

        dispatcher.start().await.unwrap();
        assert!(dispatcher.task_handle.is_some());

        dispatcher.stop().await.unwrap();
        assert!(dispatcher.task_handle.is_none());
    }

    #[tokio::test]
    async fn test_start_twice_fails() {
        let router = Arc::new(Router::new());
        let mut dispatcher = AsyncDispatcher::new(router);

        dispatcher.start().await.unwrap();
        let result = dispatcher.start().await;
        assert!(result.is_err());

        dispatcher.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_dispatch_event() {
        let mut router = Router::new();
        let count = Arc::new(Mutex::new(0));
        let agent = Arc::new(CountingAgent {
            count: count.clone(),
        });

        router.add_route::<TestEvent>(agent);

        let mut dispatcher = AsyncDispatcher::new(Arc::new(router));
        dispatcher.start().await.unwrap();

        dispatcher.dispatch(test_event(Some("test-123")));

        tokio::time::sleep(Duration::from_millis(500)).await;

        let final_count = *count.lock().await;
        assert_eq!(final_count, 1);

        dispatcher.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_dispatch_assigns_correlation_id() {
        let mut router = Router::new();
        let seen = Arc::new(Mutex::new(Vec::new()));
        let agent = Arc::new(RecordingAgent { seen: seen.clone() });

        router.add_route::<TestEvent>(agent);

        let mut dispatcher = AsyncDispatcher::new(Arc::new(router));
        dispatcher.start().await.unwrap();

        dispatcher.dispatch(test_event(None));
        dispatcher.dispatch(test_event(Some("keep-me")));

        tokio::time::sleep(Duration::from_millis(500)).await;

        let recorded = seen.lock().await.clone();
        assert_eq!(recorded.len(), 2);
        // Every delivered event carries a non-empty correlation id...
        assert!(recorded
            .iter()
            .all(|id| matches!(id.as_deref(), Some(s) if !s.is_empty())));
        // ...and an id supplied by the caller is passed through unchanged.
        assert!(recorded.contains(&Some("keep-me".to_string())));

        dispatcher.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_wait_for_empty_queue() {
        let mut router = Router::new();
        let count = Arc::new(Mutex::new(0));
        let agent = Arc::new(CountingAgent {
            count: count.clone(),
        });

        router.add_route::<TestEvent>(agent);

        let mut dispatcher = AsyncDispatcher::new(Arc::new(router));
        dispatcher.start().await.unwrap();

        dispatcher.dispatch(test_event(Some("test-456")));

        let result = dispatcher.wait_for_empty_queue(Some(Duration::from_secs(2))).await.unwrap();

        assert!(result);
        dispatcher.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_wait_for_empty_queue_timeout() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        struct SlowAgent {
            processing_count: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl BaseAsyncAgent for SlowAgent {
            async fn receive_event_async(
                &self,
                _event: Box<dyn Event>,
            ) -> Result<Vec<Box<dyn Event>>> {
                tokio::time::sleep(Duration::from_millis(200)).await;
                self.processing_count.fetch_add(1, Ordering::Relaxed);
                Ok(vec![])
            }
        }

        let mut router = Router::new();
        let processing_count = Arc::new(AtomicUsize::new(0));
        let agent = Arc::new(SlowAgent {
            processing_count: processing_count.clone(),
        });
        router.add_route::<TestEvent>(agent);

        let mut dispatcher = AsyncDispatcher::new(Arc::new(router));
        dispatcher.start().await.unwrap();

        // Ten 200 ms events cannot drain within the 300 ms timeout below.
        for i in 0..10 {
            dispatcher.dispatch(test_event(Some(&format!("slow-{}", i))));
        }

        tokio::time::sleep(Duration::from_millis(100)).await;

        let result =
            dispatcher.wait_for_empty_queue(Some(Duration::from_millis(300))).await.unwrap();

        assert!(!result);
        dispatcher.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_queue_len() {
        let router = Arc::new(Router::new());
        let mut dispatcher = AsyncDispatcher::new(router);

        assert_eq!(dispatcher.queue_len().await, 0);

        // The dispatcher is not started yet, so the event must stay queued.
        dispatcher.dispatch(test_event(Some("len-test")));
        // Yield briefly so the event is enqueued even if `dispatch` defers the push to a task.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(dispatcher.queue_len().await, 1);

        dispatcher.start().await.unwrap();
        assert!(dispatcher.wait_for_empty_queue(Some(Duration::from_secs(2))).await.unwrap());
        assert_eq!(dispatcher.queue_len().await, 0);

        dispatcher.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_terminate_event_stops_dispatcher() {
        let mut router = Router::new();
        let count = Arc::new(Mutex::new(0));
        let agent = Arc::new(CountingAgent {
            count: count.clone(),
        });

        router.add_route::<TestEvent>(agent.clone());
        router.add_route::<TerminateEvent>(agent);

        let mut dispatcher = AsyncDispatcher::new(Arc::new(router));
        dispatcher.start().await.unwrap();

        dispatcher.dispatch(test_event(Some("before-stop")));

        let terminate = Box::new(TerminateEvent::new("System")) as Box<dyn Event>;
        dispatcher.dispatch(terminate);

        tokio::time::sleep(Duration::from_secs(1)).await;

        assert!(dispatcher.stop_flag.load(Ordering::Relaxed));
        // The TestEvent was delivered. The TerminateEvent was intercepted by the dispatcher,
        // not forwarded to the agent, even though a route exists for it.
        assert_eq!(*count.lock().await, 1);

        dispatcher.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_multiple_agents_receive_event() {
        let mut router = Router::new();
        let count1 = Arc::new(Mutex::new(0));
        let count2 = Arc::new(Mutex::new(0));

        let agent1 = Arc::new(CountingAgent {
            count: count1.clone(),
        });
        let agent2 = Arc::new(CountingAgent {
            count: count2.clone(),
        });

        router.add_route::<TestEvent>(agent1);
        router.add_route::<TestEvent>(agent2);

        let mut dispatcher = AsyncDispatcher::new(Arc::new(router));
        dispatcher.start().await.unwrap();

        dispatcher.dispatch(test_event(Some("multi-agent")));

        tokio::time::sleep(Duration::from_millis(500)).await;

        assert_eq!(*count1.lock().await, 1);
        assert_eq!(*count2.lock().await, 1);

        dispatcher.stop().await.unwrap();
    }
}