1use crate::agents::BaseAsyncAgent;
8use std::any::TypeId;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub struct Router {
27 routes: HashMap<TypeId, Vec<Arc<dyn BaseAsyncAgent>>>,
28}
29
30impl Router {
31 pub fn new() -> Self {
33 Self {
34 routes: HashMap::new(),
35 }
36 }
37
38 pub fn add_route<T: 'static>(&mut self, agent: Arc<dyn BaseAsyncAgent>) {
44 let type_id = TypeId::of::<T>();
45 self.routes.entry(type_id).or_default().push(agent);
46 }
47
48 pub fn get_agents(&self, type_id: TypeId) -> Vec<Arc<dyn BaseAsyncAgent>> {
54 self.routes.get(&type_id).cloned().unwrap_or_default()
55 }
56}
57
58impl Default for Router {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::BaseAsyncAgent;
    use crate::event::Event;
    use crate::Result;
    use async_trait::async_trait;

    #[derive(Debug)]
    struct TestEvent1;
    #[derive(Debug)]
    struct TestEvent2;

    /// Agent stub that ignores every event; only its identity matters in these tests.
    struct TestAgent;

    #[async_trait]
    impl BaseAsyncAgent for TestAgent {
        async fn receive_event_async(&self, _event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
            Ok(vec![])
        }
    }

    /// Builds a fresh agent, already erased to the trait-object type the router stores,
    /// so tests can compare the router's output against it with `Arc::ptr_eq`.
    fn new_agent() -> Arc<dyn BaseAsyncAgent> {
        Arc::new(TestAgent)
    }

    #[test]
    fn test_router_new() {
        let router = Router::new();
        assert!(router.routes.is_empty());
    }

    #[test]
    fn test_router_default() {
        let router = Router::default();
        assert!(router.routes.is_empty());
    }

    #[test]
    fn test_add_route() {
        let mut router = Router::new();
        let agent = new_agent();

        router.add_route::<TestEvent1>(Arc::clone(&agent));

        let agents = router.get_agents(TypeId::of::<TestEvent1>());
        assert_eq!(agents.len(), 1);
        // The router must hand back the very agent that was registered.
        assert!(Arc::ptr_eq(&agents[0], &agent));
    }

    #[test]
    fn test_add_multiple_routes_same_type() {
        let mut router = Router::new();
        let agent1 = new_agent();
        let agent2 = new_agent();

        router.add_route::<TestEvent1>(Arc::clone(&agent1));
        router.add_route::<TestEvent1>(Arc::clone(&agent2));

        let agents = router.get_agents(TypeId::of::<TestEvent1>());
        assert_eq!(agents.len(), 2);
        // Registration order is preserved.
        assert!(Arc::ptr_eq(&agents[0], &agent1));
        assert!(Arc::ptr_eq(&agents[1], &agent2));
    }

    #[test]
    fn test_add_routes_different_types() {
        let mut router = Router::new();
        let agent1 = new_agent();
        let agent2 = new_agent();

        router.add_route::<TestEvent1>(Arc::clone(&agent1));
        router.add_route::<TestEvent2>(Arc::clone(&agent2));

        let agents1 = router.get_agents(TypeId::of::<TestEvent1>());
        let agents2 = router.get_agents(TypeId::of::<TestEvent2>());

        assert_eq!(agents1.len(), 1);
        assert_eq!(agents2.len(), 1);
        // Each event type routes to its own agent, not the other one.
        assert!(Arc::ptr_eq(&agents1[0], &agent1));
        assert!(Arc::ptr_eq(&agents2[0], &agent2));
    }

    #[test]
    fn test_get_agents_no_routes() {
        let router = Router::new();
        let agents = router.get_agents(TypeId::of::<TestEvent1>());
        assert!(agents.is_empty());
    }

    #[test]
    fn test_get_agents_different_type() {
        let mut router = Router::new();

        router.add_route::<TestEvent1>(new_agent());

        let agents = router.get_agents(TypeId::of::<TestEvent2>());
        assert!(agents.is_empty());
    }

    #[test]
    fn test_get_agents_returns_snapshot() {
        let mut router = Router::new();
        router.add_route::<TestEvent1>(new_agent());

        let snapshot = router.get_agents(TypeId::of::<TestEvent1>());
        router.add_route::<TestEvent1>(new_agent());

        // The earlier result is an owned copy and is unaffected by later registrations.
        assert_eq!(snapshot.len(), 1);
        assert_eq!(router.get_agents(TypeId::of::<TestEvent1>()).len(), 2);
    }
}