1use crate::agents::BaseAsyncAgent;
8use crate::event::Event;
9use crate::llm::{LlmBroker, LlmMessage, LlmTool};
10use crate::Result;
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14
15pub struct AsyncLlmAgent {
35 broker: Arc<LlmBroker>,
36 behaviour: String,
37 tools: Vec<Box<dyn LlmTool>>,
38}
39
40impl AsyncLlmAgent {
41 pub fn new(
49 broker: Arc<LlmBroker>,
50 behaviour: impl Into<String>,
51 tools: Option<Vec<Box<dyn LlmTool>>>,
52 ) -> Self {
53 Self {
54 broker,
55 behaviour: behaviour.into(),
56 tools: tools.unwrap_or_default(),
57 }
58 }
59
60 pub fn add_tool(&mut self, tool: Box<dyn LlmTool>) {
66 self.tools.push(tool);
67 }
68
69 pub async fn generate_response(
80 &self,
81 content: &str,
82 correlation_id: Option<String>,
83 ) -> Result<String> {
84 let messages = vec![
85 LlmMessage::system(&self.behaviour),
86 LlmMessage::user(content),
87 ];
88
89 let tools = if self.tools.is_empty() {
90 None
91 } else {
92 Some(self.tools.as_slice())
93 };
94
95 self.broker.generate(&messages, tools, None, correlation_id).await
96 }
97
98 pub async fn generate_object<T>(
109 &self,
110 content: &str,
111 correlation_id: Option<String>,
112 ) -> Result<T>
113 where
114 T: for<'de> Deserialize<'de> + Serialize + schemars::JsonSchema + Send,
115 {
116 let messages = vec![
117 LlmMessage::system(&self.behaviour),
118 LlmMessage::user(content),
119 ];
120
121 self.broker.generate_object(&messages, None, correlation_id).await
122 }
123}
124
125#[async_trait]
126impl BaseAsyncAgent for AsyncLlmAgent {
127 async fn receive_event_async(&self, _event: Box<dyn Event>) -> Result<Vec<Box<dyn Event>>> {
128 Ok(vec![])
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{CompletionConfig, StreamChunk};
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use crate::llm::{LlmGateway, LlmGatewayResponse};
    use futures::stream::{self, Stream};
    use serde_json::Value;
    use std::collections::HashMap;
    use std::pin::Pin;

    /// Gateway stub that answers every endpoint with the same canned text.
    struct MockGateway {
        /// Returned as completion content, as the single stream chunk, and as
        /// the `message` field of JSON completions.
        response: String,
    }

    impl MockGateway {
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            Ok(LlmGatewayResponse {
                content: Some(self.response.clone()),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(serde_json::json!({
                "message": self.response,
                "confidence": 0.95
            }))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content(self.response.clone()))]))
        }
    }

    /// Tool stub shared by the tool-registration tests; `run` always returns `{"result": "ok"}`.
    #[derive(Clone)]
    struct MockTool;

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(serde_json::json!({"result": "ok"}))
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: "mock_tool".to_string(),
                    description: "A mock tool".to_string(),
                    parameters: serde_json::json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(self.clone())
        }
    }

    /// Builds a broker for model "test-model" backed by a [`MockGateway`] that
    /// always answers with `response`.
    fn mock_broker(response: &str) -> Arc<LlmBroker> {
        let gateway = Arc::new(MockGateway::new(response));
        Arc::new(LlmBroker::new("test-model", gateway, None))
    }

    #[tokio::test]
    async fn test_new_agent() {
        let agent = AsyncLlmAgent::new(mock_broker("test response"), "You are helpful", None);

        assert_eq!(agent.behaviour, "You are helpful");
        assert!(agent.tools.is_empty());
    }

    #[tokio::test]
    async fn test_new_agent_with_tools() {
        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(MockTool)];
        let agent = AsyncLlmAgent::new(mock_broker("test"), "You are helpful", Some(tools));

        assert_eq!(agent.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_add_tool() {
        let mut agent = AsyncLlmAgent::new(mock_broker("test"), "You are helpful", None);

        assert!(agent.tools.is_empty());
        agent.add_tool(Box::new(MockTool));
        assert_eq!(agent.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_generate_response() {
        let agent = AsyncLlmAgent::new(mock_broker("Hello from LLM"), "You are helpful", None);

        let response = agent.generate_response("Test message", None).await.unwrap();

        assert_eq!(response, "Hello from LLM");
    }

    #[tokio::test]
    async fn test_generate_response_with_correlation_id() {
        let agent = AsyncLlmAgent::new(mock_broker("Response"), "You are helpful", None);

        let response = agent
            .generate_response("Test", Some("correlation-123".to_string()))
            .await
            .unwrap();

        assert_eq!(response, "Response");
    }

    #[tokio::test]
    async fn test_generate_object() {
        #[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)]
        struct TestResponse {
            message: String,
            confidence: f64,
        }

        let agent = AsyncLlmAgent::new(mock_broker("Test message"), "You are helpful", None);

        let response: TestResponse = agent.generate_object("Generate object", None).await.unwrap();

        assert_eq!(response.message, "Test message");
        assert_eq!(response.confidence, 0.95);
    }

    #[tokio::test]
    async fn test_receive_event_async_default() {
        use std::any::Any;

        #[derive(Debug, Clone, Serialize, Deserialize)]
        struct TestEvent {
            source: String,
            correlation_id: Option<String>,
        }

        impl Event for TestEvent {
            fn source(&self) -> &str {
                &self.source
            }
            fn correlation_id(&self) -> Option<&str> {
                self.correlation_id.as_deref()
            }
            fn set_correlation_id(&mut self, id: String) {
                self.correlation_id = Some(id);
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn clone_box(&self) -> Box<dyn Event> {
                Box::new(self.clone())
            }
        }

        let agent = AsyncLlmAgent::new(mock_broker("test"), "You are helpful", None);

        let event = Box::new(TestEvent {
            source: "Test".to_string(),
            correlation_id: None,
        }) as Box<dyn Event>;

        let result = agent.receive_event_async(event).await.unwrap();

        // The default handler ignores events and emits nothing.
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_agent_with_custom_behaviour() {
        let agent = AsyncLlmAgent::new(
            mock_broker("Custom response"),
            "You are a specialized agent with custom behavior",
            None,
        );

        assert_eq!(agent.behaviour, "You are a specialized agent with custom behavior");

        let response = agent.generate_response("Test", None).await.unwrap();
        assert_eq!(response, "Custom response");
    }
}