1use crate::error::Result;
2use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
3use crate::llm::models::{LlmGatewayResponse, LlmMessage, MessageRole};
4use crate::llm::tools::LlmTool;
5use crate::tracer::TracerSystem;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use std::pin::Pin;
9use std::sync::Arc;
10use tracing::{info, warn};
11use uuid::Uuid;
12
/// Mediates between callers and an [`LlmGateway`] backend: builds requests for
/// a fixed model, executes requested tools, and (optionally) records every
/// call, response, and tool execution to a [`TracerSystem`].
#[derive(Clone)]
pub struct LlmBroker {
    // Model identifier forwarded to the gateway on every request.
    model: String,
    // Backend that performs completions, JSON completions, streaming, and embeddings.
    gateway: Arc<dyn LlmGateway>,
    // Optional tracing sink; when `None`, no trace events are recorded.
    tracer: Option<Arc<TracerSystem>>,
}
20
impl LlmBroker {
    /// Creates a broker for `model` backed by `gateway`.
    ///
    /// `tracer` is optional; when `Some`, each LLM call/response and every
    /// tool execution is recorded to it with a correlation id.
    pub fn new(
        model: impl Into<String>,
        gateway: Arc<dyn LlmGateway>,
        tracer: Option<Arc<TracerSystem>>,
    ) -> Self {
        Self {
            model: model.into(),
            gateway,
            tracer,
        }
    }
40
41 pub async fn generate(
50 &self,
51 messages: &[LlmMessage],
52 tools: Option<&[Box<dyn LlmTool>]>,
53 config: Option<CompletionConfig>,
54 correlation_id: Option<String>,
55 ) -> Result<String> {
56 let config = config.unwrap_or_default();
57 let current_messages = messages.to_vec();
58 let correlation_id = correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string());
59
60 if let Some(tracer) = &self.tracer {
62 let messages_json: Vec<std::collections::HashMap<String, serde_json::Value>> =
63 current_messages
64 .iter()
65 .map(|m| {
66 let mut map = std::collections::HashMap::new();
67 map.insert("role".to_string(), serde_json::json!(format!("{:?}", m.role)));
68 if let Some(content) = &m.content {
69 map.insert("content".to_string(), serde_json::json!(content));
70 }
71 map
72 })
73 .collect();
74
75 let tools_json = tools.map(|t| {
76 t.iter()
77 .map(|tool| {
78 let desc = tool.descriptor();
79 let mut map = std::collections::HashMap::new();
80 map.insert("name".to_string(), serde_json::json!(desc.function.name));
81 map.insert(
82 "description".to_string(),
83 serde_json::json!(desc.function.description),
84 );
85 map
86 })
87 .collect()
88 });
89
90 tracer.record_llm_call(
91 &self.model,
92 messages_json,
93 config.temperature as f64,
94 tools_json,
95 "LlmBroker",
96 &correlation_id,
97 );
98 }
99
100 let start = std::time::Instant::now();
102
103 let response =
105 self.gateway.complete(&self.model, ¤t_messages, tools, &config).await?;
106
107 let call_duration_ms = start.elapsed().as_secs_f64() * 1000.0;
108
109 if let Some(tracer) = &self.tracer {
111 let tool_calls_json = if !response.tool_calls.is_empty() {
112 Some(
113 response
114 .tool_calls
115 .iter()
116 .map(|tc| {
117 let mut map = std::collections::HashMap::new();
118 map.insert("name".to_string(), serde_json::json!(&tc.name));
119 if let Some(id) = &tc.id {
120 map.insert("id".to_string(), serde_json::json!(id));
121 }
122 map
123 })
124 .collect(),
125 )
126 } else {
127 None
128 };
129
130 tracer.record_llm_response(
131 &self.model,
132 response.content.as_ref().unwrap_or(&String::new()),
133 tool_calls_json,
134 Some(call_duration_ms),
135 "LlmBroker",
136 &correlation_id,
137 );
138 }
139
140 if !response.tool_calls.is_empty() {
142 if let Some(tools) = tools {
143 return self
144 .handle_tool_calls(current_messages, response, tools, &config, &correlation_id)
145 .await;
146 }
147 }
148
149 Ok(response.content.unwrap_or_default())
150 }
151
    /// Executes the first requested tool that matches a provided tool, appends
    /// the assistant/tool message pair to the transcript, and re-enters
    /// [`Self::generate`] so the model can continue with the tool output.
    ///
    /// If no requested tool matches, each miss is logged and the original
    /// response content is returned unchanged.
    ///
    /// Returns a boxed future because this method is mutually recursive with
    /// `generate`; type erasure breaks the otherwise infinitely-sized future type.
    fn handle_tool_calls<'a>(
        &'a self,
        mut messages: Vec<LlmMessage>,
        response: LlmGatewayResponse,
        tools: &'a [Box<dyn LlmTool>],
        config: &'a CompletionConfig,
        correlation_id: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send + 'a>> {
        Box::pin(async move {
            info!("Tool calls requested: {}", response.tool_calls.len());

            for tool_call in &response.tool_calls {
                if let Some(tool) = tools.iter().find(|t| t.matches(&tool_call.name)) {
                    info!("Executing tool: {}", tool_call.name);

                    // `?` propagates tool failures to the caller immediately.
                    let start = std::time::Instant::now();
                    let output = tool.run(&tool_call.arguments)?;
                    let tool_duration_ms = start.elapsed().as_secs_f64() * 1000.0;

                    if let Some(tracer) = &self.tracer {
                        tracer.record_tool_call(
                            &tool_call.name,
                            tool_call.arguments.clone(),
                            output.clone(),
                            Some("LlmBroker".to_string()),
                            Some(tool_duration_ms),
                            "LlmBroker",
                            correlation_id,
                        );
                    }

                    // Mirror the round-trip into the transcript: the
                    // assistant's call, then the tool's JSON output.
                    // NOTE(review): the Tool message also carries `tool_calls`
                    // — looks intentional for downstream matching, but confirm
                    // the gateway expects it on Tool-role messages.
                    messages.push(LlmMessage {
                        role: MessageRole::Assistant,
                        content: None,
                        tool_calls: Some(vec![tool_call.clone()]),
                        image_paths: None,
                    });
                    messages.push(LlmMessage {
                        role: MessageRole::Tool,
                        content: Some(serde_json::to_string(&output)?),
                        tool_calls: Some(vec![tool_call.clone()]),
                        image_paths: None,
                    });

                    // Only the FIRST matching tool call is executed per round;
                    // the recursive `generate` call lets the model re-request
                    // any remaining tools on the next round.
                    return self
                        .generate(
                            &messages,
                            Some(tools),
                            Some(config.clone()),
                            Some(correlation_id.to_string()),
                        )
                        .await;
                } else {
                    warn!("Tool not found: {}", tool_call.name);
                }
            }

            Ok(response.content.unwrap_or_default())
        })
    }
217
218 pub async fn generate_object<T>(
226 &self,
227 messages: &[LlmMessage],
228 config: Option<CompletionConfig>,
229 correlation_id: Option<String>,
230 ) -> Result<T>
231 where
232 T: for<'de> Deserialize<'de> + Serialize + schemars::JsonSchema + Send,
233 {
234 let config = config.unwrap_or_default();
235 let correlation_id = correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string());
236
237 let schema = serde_json::to_value(schemars::schema_for!(T))?;
239
240 if let Some(tracer) = &self.tracer {
242 let messages_json: Vec<std::collections::HashMap<String, serde_json::Value>> = messages
243 .iter()
244 .map(|m| {
245 let mut map = std::collections::HashMap::new();
246 map.insert("role".to_string(), serde_json::json!(format!("{:?}", m.role)));
247 if let Some(content) = &m.content {
248 map.insert("content".to_string(), serde_json::json!(content));
249 }
250 map
251 })
252 .collect();
253
254 tracer.record_llm_call(
255 &self.model,
256 messages_json,
257 config.temperature as f64,
258 None,
259 "LlmBroker::generate_object",
260 &correlation_id,
261 );
262 }
263
264 let start = std::time::Instant::now();
266
267 let json_response =
269 self.gateway.complete_json(&self.model, messages, schema, &config).await?;
270
271 let call_duration_ms = start.elapsed().as_secs_f64() * 1000.0;
272
273 let object: T = serde_json::from_value(json_response.clone())?;
275
276 if let Some(tracer) = &self.tracer {
278 let object_str = serde_json::to_string_pretty(&json_response).unwrap_or_default();
279 tracer.record_llm_response(
280 &self.model,
281 format!("Structured response: {}", object_str),
282 None,
283 Some(call_duration_ms),
284 "LlmBroker::generate_object",
285 &correlation_id,
286 );
287 }
288
289 Ok(object)
290 }
291
292 pub fn generate_stream<'a>(
322 &'a self,
323 messages: &'a [LlmMessage],
324 tools: Option<&'a [Box<dyn LlmTool>]>,
325 config: Option<CompletionConfig>,
326 correlation_id: Option<String>,
327 ) -> Pin<Box<dyn Stream<Item = Result<String>> + 'a>> {
328 let config = config.unwrap_or_default();
329 let current_messages = messages.to_vec();
330 let correlation_id = correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string());
331
332 Box::pin(async_stream::stream! {
333 if let Some(tracer) = &self.tracer {
335 let messages_json: Vec<std::collections::HashMap<String, serde_json::Value>> =
336 current_messages
337 .iter()
338 .map(|m| {
339 let mut map = std::collections::HashMap::new();
340 map.insert("role".to_string(), serde_json::json!(format!("{:?}", m.role)));
341 if let Some(content) = &m.content {
342 map.insert("content".to_string(), serde_json::json!(content));
343 }
344 map
345 })
346 .collect();
347
348 let tools_json = tools.map(|t| {
349 t.iter()
350 .map(|tool| {
351 let desc = tool.descriptor();
352 let mut map = std::collections::HashMap::new();
353 map.insert("name".to_string(), serde_json::json!(desc.function.name));
354 map.insert(
355 "description".to_string(),
356 serde_json::json!(desc.function.description),
357 );
358 map
359 })
360 .collect()
361 });
362
363 tracer.record_llm_call(
364 &self.model,
365 messages_json,
366 config.temperature as f64,
367 tools_json,
368 "LlmBroker::generate_stream",
369 &correlation_id,
370 );
371 }
372
373 let mut accumulated_content = String::new();
374 let mut accumulated_tool_calls = Vec::new();
375
376 let start = std::time::Instant::now();
378
379 let mut stream = self.gateway.complete_stream(
381 &self.model,
382 ¤t_messages,
383 tools,
384 &config,
385 );
386
387 while let Some(chunk_result) = stream.next().await {
388 match chunk_result {
389 Ok(StreamChunk::Content(content)) => {
390 accumulated_content.push_str(&content);
391 yield Ok(content);
392 }
393 Ok(StreamChunk::ToolCalls(tool_calls)) => {
394 accumulated_tool_calls = tool_calls;
395 }
396 Err(e) => {
397 yield Err(e);
398 return;
399 }
400 }
401 }
402
403 let call_duration_ms = start.elapsed().as_secs_f64() * 1000.0;
404
405 if let Some(tracer) = &self.tracer {
407 let tool_calls_json = if !accumulated_tool_calls.is_empty() {
408 Some(
409 accumulated_tool_calls
410 .iter()
411 .map(|tc| {
412 let mut map = std::collections::HashMap::new();
413 map.insert("name".to_string(), serde_json::json!(&tc.name));
414 if let Some(id) = &tc.id {
415 map.insert("id".to_string(), serde_json::json!(id));
416 }
417 map
418 })
419 .collect(),
420 )
421 } else {
422 None
423 };
424
425 tracer.record_llm_response(
426 &self.model,
427 &accumulated_content,
428 tool_calls_json,
429 Some(call_duration_ms),
430 "LlmBroker::generate_stream",
431 &correlation_id,
432 );
433 }
434
435 if !accumulated_tool_calls.is_empty() {
437 if let Some(tools) = tools {
438 info!("Processing {} tool call(s) in stream", accumulated_tool_calls.len());
439
440 let mut new_messages = current_messages.clone();
442
443 new_messages.push(LlmMessage {
445 role: MessageRole::Assistant,
446 content: Some(accumulated_content),
447 tool_calls: Some(accumulated_tool_calls.clone()),
448 image_paths: None,
449 });
450
451 for tool_call in &accumulated_tool_calls {
453 if let Some(tool) = tools.iter().find(|t| t.matches(&tool_call.name)) {
454 info!("Executing tool: {}", tool_call.name);
455
456 let tool_start = std::time::Instant::now();
458
459 match tool.run(&tool_call.arguments) {
460 Ok(output) => {
461 let tool_duration_ms = tool_start.elapsed().as_secs_f64() * 1000.0;
462
463 if let Some(tracer) = &self.tracer {
465 tracer.record_tool_call(
466 &tool_call.name,
467 tool_call.arguments.clone(),
468 output.clone(),
469 Some("LlmBroker::generate_stream".to_string()),
470 Some(tool_duration_ms),
471 "LlmBroker::generate_stream",
472 &correlation_id,
473 );
474 }
475
476 let output_str = match serde_json::to_string(&output) {
477 Ok(s) => s,
478 Err(e) => {
479 yield Err(e.into());
480 return;
481 }
482 };
483
484 new_messages.push(LlmMessage {
485 role: MessageRole::Tool,
486 content: Some(output_str),
487 tool_calls: Some(vec![tool_call.clone()]),
488 image_paths: None,
489 });
490 }
491 Err(e) => {
492 warn!("Tool execution failed: {}", e);
493 yield Err(e);
494 return;
495 }
496 }
497 } else {
498 warn!("Tool not found: {}", tool_call.name);
499 }
500 }
501
502 let mut recursive_stream = self.generate_stream(&new_messages, Some(tools), Some(config.clone()), Some(correlation_id.clone()));
504
505 while let Some(result) = recursive_stream.next().await {
506 yield result;
507 }
508 } else {
509 warn!("LLM requested tool calls but no tools provided");
510 }
511 }
512 })
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::models::LlmToolCall;
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::collections::HashMap;

    /// Gateway double that replays a fixed list of responses in order, then
    /// falls back to a canned "default response" once the list is exhausted.
    struct MockGateway {
        responses: Vec<LlmGatewayResponse>,
        // Interior mutability: `complete` takes `&self` but must advance the index.
        call_count: std::sync::Mutex<usize>,
    }

    impl MockGateway {
        fn new(responses: Vec<LlmGatewayResponse>) -> Self {
            Self {
                responses,
                call_count: std::sync::Mutex::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            // Pop the next scripted response; after that, a fixed fallback.
            let mut count = self.call_count.lock().unwrap();
            let idx = *count;
            *count += 1;

            if idx < self.responses.len() {
                Ok(self.responses[idx].clone())
            } else {
                Ok(LlmGatewayResponse {
                    content: Some("default response".to_string()),
                    object: None,
                    tool_calls: vec![],
                    thinking: None,
                })
            }
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            // Fixed payload matching the `TestObject`/`TestData` shapes below.
            Ok(serde_json::json!({"test": "value"}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            use futures::stream;
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content("test".to_string()))]))
        }
    }

    /// Tool double that returns a fixed JSON value regardless of arguments.
    struct MockTool {
        name: String,
        result: Value,
    }

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(self.result.clone())
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: self.name.clone(),
                    description: "A mock tool".to_string(),
                    parameters: serde_json::json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(MockTool {
                name: self.name.clone(),
                result: self.result.clone(),
            })
        }
    }

    // Constructor stores the model name verbatim.
    #[tokio::test]
    async fn test_broker_new() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        assert_eq!(broker.model, "test-model");
    }

    // `impl Into<String>` accepts an owned String as well as &str.
    #[tokio::test]
    async fn test_broker_new_string_conversion() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new(String::from("my-model"), gateway, None);
        assert_eq!(broker.model, "my-model");
    }

    // Happy path: no tools, content returned directly.
    #[tokio::test]
    async fn test_generate_simple_response() {
        let response = LlmGatewayResponse {
            content: Some("Hello, World!".to_string()),
            object: None,
            tool_calls: vec![],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![response]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let messages = vec![LlmMessage::user("Hi")];
        let result = broker.generate(&messages, None, None, None).await.unwrap();

        assert_eq!(result, "Hello, World!");
    }

    // Explicit config is accepted (the mock ignores it).
    #[tokio::test]
    async fn test_generate_with_custom_config() {
        let response = LlmGatewayResponse {
            content: Some("Response".to_string()),
            object: None,
            tool_calls: vec![],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![response]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let config = CompletionConfig {
            temperature: 0.5,
            num_ctx: 2048,
            max_tokens: 100,
            num_predict: Some(50),
            top_p: None,
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let messages = vec![LlmMessage::user("Hi")];
        let result = broker.generate(&messages, None, Some(config), None).await.unwrap();

        assert_eq!(result, "Response");
    }

    // `content: None` is normalized to an empty string.
    #[tokio::test]
    async fn test_generate_empty_response_content() {
        let response = LlmGatewayResponse {
            content: None,
            object: None,
            tool_calls: vec![],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![response]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let messages = vec![LlmMessage::user("Hi")];
        let result = broker.generate(&messages, None, None, None).await.unwrap();

        assert_eq!(result, "");
    }

    // Tool round-trip: first response requests a tool, second response is
    // the post-tool answer returned to the caller.
    #[tokio::test]
    async fn test_generate_with_tool_call() {
        let tool_call = LlmToolCall {
            id: Some("call_1".to_string()),
            name: "test_tool".to_string(),
            arguments: HashMap::new(),
        };

        let first_response = LlmGatewayResponse {
            content: None,
            object: None,
            tool_calls: vec![tool_call],
            thinking: None,
        };

        let second_response = LlmGatewayResponse {
            content: Some("After tool execution".to_string()),
            object: None,
            tool_calls: vec![],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![first_response, second_response]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let tool = MockTool {
            name: "test_tool".to_string(),
            result: serde_json::json!({"result": "success"}),
        };

        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(tool)];

        let messages = vec![LlmMessage::user("Use the tool")];
        let result = broker.generate(&messages, Some(&tools), None, None).await.unwrap();

        assert_eq!(result, "After tool execution");
    }

    // With no tools supplied, a tool-call response falls through to its content.
    #[tokio::test]
    async fn test_generate_with_tool_call_no_tools_provided() {
        let tool_call = LlmToolCall {
            id: Some("call_1".to_string()),
            name: "test_tool".to_string(),
            arguments: HashMap::new(),
        };

        let response = LlmGatewayResponse {
            content: Some("fallback".to_string()),
            object: None,
            tool_calls: vec![tool_call],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![response]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let messages = vec![LlmMessage::user("Use the tool")];
        let result = broker.generate(&messages, None, None, None).await.unwrap();

        assert_eq!(result, "fallback");
    }

    // Structured output deserializes the mock's {"test": "value"} payload.
    #[tokio::test]
    async fn test_generate_object() {
        #[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)]
        struct TestObject {
            test: String,
        }

        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let messages = vec![LlmMessage::user("Generate object")];
        let result: TestObject = broker.generate_object(&messages, None, None).await.unwrap();

        assert_eq!(result.test, "value");
    }

    #[tokio::test]
    async fn test_generate_object_with_config() {
        #[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)]
        struct TestData {
            test: String,
        }

        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let config = CompletionConfig {
            temperature: 0.1,
            num_ctx: 1024,
            max_tokens: 50,
            num_predict: None,
            top_p: None,
            top_k: None,
            response_format: None,
            reasoning_effort: None,
        };

        let messages = vec![LlmMessage::user("Generate")];
        let result: TestData = broker.generate_object(&messages, Some(config), None).await.unwrap();

        assert_eq!(result.test, "value");
    }

    // Multi-turn history is forwarded as-is.
    #[tokio::test]
    async fn test_multiple_messages() {
        let response = LlmGatewayResponse {
            content: Some("Response to conversation".to_string()),
            object: None,
            tool_calls: vec![],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![response]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let messages = vec![
            LlmMessage::system("You are helpful"),
            LlmMessage::user("First message"),
            LlmMessage::assistant("First response"),
            LlmMessage::user("Second message"),
        ];

        let result = broker.generate(&messages, None, None, None).await.unwrap();
        assert_eq!(result, "Response to conversation");
    }

    // Streaming concatenates content chunks in arrival order.
    #[tokio::test]
    async fn test_generate_stream_basic() {
        use futures::stream;

        struct StreamingMockGateway;

        #[async_trait::async_trait]
        impl LlmGateway for StreamingMockGateway {
            async fn complete(
                &self,
                _model: &str,
                _messages: &[LlmMessage],
                _tools: Option<&[Box<dyn LlmTool>]>,
                _config: &CompletionConfig,
            ) -> Result<LlmGatewayResponse> {
                Ok(LlmGatewayResponse {
                    content: Some("test".to_string()),
                    object: None,
                    tool_calls: vec![],
                    thinking: None,
                })
            }

            async fn complete_json(
                &self,
                _model: &str,
                _messages: &[LlmMessage],
                _schema: Value,
                _config: &CompletionConfig,
            ) -> Result<Value> {
                Ok(serde_json::json!({}))
            }

            async fn get_available_models(&self) -> Result<Vec<String>> {
                Ok(vec![])
            }

            async fn calculate_embeddings(
                &self,
                _text: &str,
                _model: Option<&str>,
            ) -> Result<Vec<f32>> {
                Ok(vec![])
            }

            fn complete_stream<'a>(
                &'a self,
                _model: &'a str,
                _messages: &'a [LlmMessage],
                _tools: Option<&'a [Box<dyn LlmTool>]>,
                _config: &'a CompletionConfig,
            ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
                Box::pin(stream::iter(vec![
                    Ok(StreamChunk::Content("Hello".to_string())),
                    Ok(StreamChunk::Content(" ".to_string())),
                    Ok(StreamChunk::Content("World".to_string())),
                ]))
            }
        }

        let gateway = Arc::new(StreamingMockGateway);
        let broker = LlmBroker::new("test-model", gateway, None);
        let messages = vec![LlmMessage::user("Hello")];

        let mut stream = broker.generate_stream(&messages, None, None, None);
        let mut result = String::new();

        while let Some(chunk) = stream.next().await {
            result.push_str(&chunk.unwrap());
        }

        assert_eq!(result, "Hello World");
    }

    // Streaming tool flow: first stream emits content + a tool call, the
    // recursive follow-up stream emits the post-tool content.
    #[tokio::test]
    async fn test_generate_stream_with_tool_calls() {
        use futures::stream;

        struct ToolCallMockGateway {
            // Distinguishes the initial request from the recursive follow-up.
            call_count: std::sync::Mutex<usize>,
        }

        impl ToolCallMockGateway {
            fn new() -> Self {
                Self {
                    call_count: std::sync::Mutex::new(0),
                }
            }
        }

        #[async_trait::async_trait]
        impl LlmGateway for ToolCallMockGateway {
            async fn complete(
                &self,
                _model: &str,
                _messages: &[LlmMessage],
                _tools: Option<&[Box<dyn LlmTool>]>,
                _config: &CompletionConfig,
            ) -> Result<LlmGatewayResponse> {
                Ok(LlmGatewayResponse {
                    content: Some("test".to_string()),
                    object: None,
                    tool_calls: vec![],
                    thinking: None,
                })
            }

            async fn complete_json(
                &self,
                _model: &str,
                _messages: &[LlmMessage],
                _schema: Value,
                _config: &CompletionConfig,
            ) -> Result<Value> {
                Ok(serde_json::json!({}))
            }

            async fn get_available_models(&self) -> Result<Vec<String>> {
                Ok(vec![])
            }

            async fn calculate_embeddings(
                &self,
                _text: &str,
                _model: Option<&str>,
            ) -> Result<Vec<f32>> {
                Ok(vec![])
            }

            fn complete_stream<'a>(
                &'a self,
                _model: &'a str,
                _messages: &'a [LlmMessage],
                _tools: Option<&'a [Box<dyn LlmTool>]>,
                _config: &'a CompletionConfig,
            ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
                let mut count = self.call_count.lock().unwrap();
                let call_num = *count;
                *count += 1;

                if call_num == 0 {
                    Box::pin(stream::iter(vec![
                        Ok(StreamChunk::Content("Initial ".to_string())),
                        Ok(StreamChunk::Content("response".to_string())),
                        Ok(StreamChunk::ToolCalls(vec![LlmToolCall {
                            id: Some("call_1".to_string()),
                            name: "test_tool".to_string(),
                            arguments: HashMap::new(),
                        }])),
                    ]))
                } else {
                    Box::pin(stream::iter(vec![
                        Ok(StreamChunk::Content("After ".to_string())),
                        Ok(StreamChunk::Content("tool".to_string())),
                    ]))
                }
            }
        }

        let gateway = Arc::new(ToolCallMockGateway::new());
        let broker = LlmBroker::new("test-model", gateway, None);

        let tool = MockTool {
            name: "test_tool".to_string(),
            result: serde_json::json!({"result": "success"}),
        };
        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(tool)];

        let messages = vec![LlmMessage::user("Use the tool")];
        let mut stream = broker.generate_stream(&messages, Some(&tools), None, None);

        let mut result = String::new();
        while let Some(chunk) = stream.next().await {
            result.push_str(&chunk.unwrap());
        }

        assert!(result.contains("Initial response"));
        assert!(result.contains("After tool"));
    }

    #[tokio::test]
    async fn test_generate_stream_without_tools() {
        use futures::stream;

        struct SimpleStreamGateway;

        #[async_trait::async_trait]
        impl LlmGateway for SimpleStreamGateway {
            async fn complete(
                &self,
                _model: &str,
                _messages: &[LlmMessage],
                _tools: Option<&[Box<dyn LlmTool>]>,
                _config: &CompletionConfig,
            ) -> Result<LlmGatewayResponse> {
                Ok(LlmGatewayResponse {
                    content: Some("test".to_string()),
                    object: None,
                    tool_calls: vec![],
                    thinking: None,
                })
            }

            async fn complete_json(
                &self,
                _model: &str,
                _messages: &[LlmMessage],
                _schema: Value,
                _config: &CompletionConfig,
            ) -> Result<Value> {
                Ok(serde_json::json!({}))
            }

            async fn get_available_models(&self) -> Result<Vec<String>> {
                Ok(vec![])
            }

            async fn calculate_embeddings(
                &self,
                _text: &str,
                _model: Option<&str>,
            ) -> Result<Vec<f32>> {
                Ok(vec![])
            }

            fn complete_stream<'a>(
                &'a self,
                _model: &'a str,
                _messages: &'a [LlmMessage],
                _tools: Option<&'a [Box<dyn LlmTool>]>,
                _config: &'a CompletionConfig,
            ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
                Box::pin(stream::iter(vec![
                    Ok(StreamChunk::Content("Simple ".to_string())),
                    Ok(StreamChunk::Content("stream".to_string())),
                ]))
            }
        }

        let gateway = Arc::new(SimpleStreamGateway);
        let broker = LlmBroker::new("test-model", gateway, None);

        let messages = vec![LlmMessage::user("Test")];
        let mut stream = broker.generate_stream(&messages, None, None, None);

        let mut result = String::new();
        while let Some(chunk) = stream.next().await {
            result.push_str(&chunk.unwrap());
        }

        assert_eq!(result, "Simple stream");
    }

    // Tracer records one llm_call + one llm_response, both tagged with the
    // caller-supplied correlation id.
    #[tokio::test]
    async fn test_tracer_integration() {
        use crate::tracer::TracerSystem;

        let response = LlmGatewayResponse {
            content: Some("Test response".to_string()),
            object: None,
            tool_calls: vec![],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![response]));
        let tracer = Arc::new(TracerSystem::default());
        let broker = LlmBroker::new("test-model", gateway, Some(tracer.clone()));

        let messages = vec![LlmMessage::user("Test")];
        let correlation_id = "test-correlation-123";

        let result = broker
            .generate(&messages, None, None, Some(correlation_id.to_string()))
            .await
            .unwrap();

        assert_eq!(result, "Test response");

        // llm_call + llm_response
        assert_eq!(tracer.len(), 2);
        let summaries = tracer.get_event_summaries(None, None, None);
        assert!(summaries[0].contains(correlation_id));
        assert!(summaries[1].contains(correlation_id));
    }

    // Tool round with tracing: two call/response pairs (initial + recursive)
    // plus one tool-call event = 5, all sharing the correlation id.
    #[tokio::test]
    async fn test_tracer_with_tool_calls() {
        use crate::tracer::TracerSystem;

        let tool_call = LlmToolCall {
            id: Some("call_1".to_string()),
            name: "test_tool".to_string(),
            arguments: HashMap::new(),
        };

        let first_response = LlmGatewayResponse {
            content: None,
            object: None,
            tool_calls: vec![tool_call],
            thinking: None,
        };

        let second_response = LlmGatewayResponse {
            content: Some("After tool".to_string()),
            object: None,
            tool_calls: vec![],
            thinking: None,
        };

        let gateway = Arc::new(MockGateway::new(vec![first_response, second_response]));
        let tracer = Arc::new(TracerSystem::default());
        let broker = LlmBroker::new("test-model", gateway, Some(tracer.clone()));

        let tool = MockTool {
            name: "test_tool".to_string(),
            result: serde_json::json!({"result": "success"}),
        };
        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(tool)];

        let messages = vec![LlmMessage::user("Use tool")];
        let correlation_id = "tool-test-456";

        let result = broker
            .generate(&messages, Some(&tools), None, Some(correlation_id.to_string()))
            .await
            .unwrap();

        assert_eq!(result, "After tool");

        assert_eq!(tracer.len(), 5);

        let summaries = tracer.get_event_summaries(None, None, None);
        for summary in &summaries {
            assert!(summary.contains(correlation_id));
        }
    }
}