mojentic/llm/
broker.rs

1use crate::error::Result;
2use crate::llm::gateway::{CompletionConfig, LlmGateway, StreamChunk};
3use crate::llm::models::{LlmGatewayResponse, LlmMessage, MessageRole};
4use crate::llm::tools::LlmTool;
5use crate::tracer::TracerSystem;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use std::pin::Pin;
9use std::sync::Arc;
10use tracing::{info, warn};
11use uuid::Uuid;
12
/// Main interface for LLM interactions.
///
/// Wraps a model name, a gateway implementation, and an optional tracer.
/// Cloning is cheap: the gateway and tracer are shared behind `Arc`.
#[derive(Clone)]
pub struct LlmBroker {
    // Model name forwarded to every gateway call.
    model: String,
    // Transport used to reach the LLM backend.
    gateway: Arc<dyn LlmGateway>,
    // Optional observability sink; `None` disables all trace recording.
    tracer: Option<Arc<TracerSystem>>,
}
20
21impl LlmBroker {
22    /// Create a new LLM broker
23    ///
24    /// # Arguments
25    ///
26    /// * `model` - The name of the LLM model to use
27    /// * `gateway` - The gateway to use for LLM communication
28    /// * `tracer` - Optional tracer system for observability
29    pub fn new(
30        model: impl Into<String>,
31        gateway: Arc<dyn LlmGateway>,
32        tracer: Option<Arc<TracerSystem>>,
33    ) -> Self {
34        Self {
35            model: model.into(),
36            gateway,
37            tracer,
38        }
39    }
40
41    /// Generate text response from LLM
42    ///
43    /// # Arguments
44    ///
45    /// * `messages` - The messages to send to the LLM
46    /// * `tools` - Optional tools available to the LLM
47    /// * `config` - Optional completion configuration
48    /// * `correlation_id` - Optional correlation ID for tracing (generates UUID if None)
49    pub async fn generate(
50        &self,
51        messages: &[LlmMessage],
52        tools: Option<&[Box<dyn LlmTool>]>,
53        config: Option<CompletionConfig>,
54        correlation_id: Option<String>,
55    ) -> Result<String> {
56        let config = config.unwrap_or_default();
57        let current_messages = messages.to_vec();
58        let correlation_id = correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string());
59
60        // Record LLM call
61        if let Some(tracer) = &self.tracer {
62            let messages_json: Vec<std::collections::HashMap<String, serde_json::Value>> =
63                current_messages
64                    .iter()
65                    .map(|m| {
66                        let mut map = std::collections::HashMap::new();
67                        map.insert("role".to_string(), serde_json::json!(format!("{:?}", m.role)));
68                        if let Some(content) = &m.content {
69                            map.insert("content".to_string(), serde_json::json!(content));
70                        }
71                        map
72                    })
73                    .collect();
74
75            let tools_json = tools.map(|t| {
76                t.iter()
77                    .map(|tool| {
78                        let desc = tool.descriptor();
79                        let mut map = std::collections::HashMap::new();
80                        map.insert("name".to_string(), serde_json::json!(desc.function.name));
81                        map.insert(
82                            "description".to_string(),
83                            serde_json::json!(desc.function.description),
84                        );
85                        map
86                    })
87                    .collect()
88            });
89
90            tracer.record_llm_call(
91                &self.model,
92                messages_json,
93                config.temperature as f64,
94                tools_json,
95                "LlmBroker",
96                &correlation_id,
97            );
98        }
99
100        // Measure call duration
101        let start = std::time::Instant::now();
102
103        // Make initial LLM call
104        let response =
105            self.gateway.complete(&self.model, &current_messages, tools, &config).await?;
106
107        let call_duration_ms = start.elapsed().as_secs_f64() * 1000.0;
108
109        // Record LLM response
110        if let Some(tracer) = &self.tracer {
111            let tool_calls_json = if !response.tool_calls.is_empty() {
112                Some(
113                    response
114                        .tool_calls
115                        .iter()
116                        .map(|tc| {
117                            let mut map = std::collections::HashMap::new();
118                            map.insert("name".to_string(), serde_json::json!(&tc.name));
119                            if let Some(id) = &tc.id {
120                                map.insert("id".to_string(), serde_json::json!(id));
121                            }
122                            map
123                        })
124                        .collect(),
125                )
126            } else {
127                None
128            };
129
130            tracer.record_llm_response(
131                &self.model,
132                response.content.as_ref().unwrap_or(&String::new()),
133                tool_calls_json,
134                Some(call_duration_ms),
135                "LlmBroker",
136                &correlation_id,
137            );
138        }
139
140        // Handle tool calls if present
141        if !response.tool_calls.is_empty() {
142            if let Some(tools) = tools {
143                return self
144                    .handle_tool_calls(current_messages, response, tools, &config, &correlation_id)
145                    .await;
146            }
147        }
148
149        Ok(response.content.unwrap_or_default())
150    }
151
152    fn handle_tool_calls<'a>(
153        &'a self,
154        mut messages: Vec<LlmMessage>,
155        response: LlmGatewayResponse,
156        tools: &'a [Box<dyn LlmTool>],
157        config: &'a CompletionConfig,
158        correlation_id: &'a str,
159    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send + 'a>> {
160        Box::pin(async move {
161            info!("Tool calls requested: {}", response.tool_calls.len());
162
163            for tool_call in &response.tool_calls {
164                // Find matching tool
165                if let Some(tool) = tools.iter().find(|t| t.matches(&tool_call.name)) {
166                    info!("Executing tool: {}", tool_call.name);
167
168                    // Measure tool execution time
169                    let start = std::time::Instant::now();
170                    let output = tool.run(&tool_call.arguments)?;
171                    let tool_duration_ms = start.elapsed().as_secs_f64() * 1000.0;
172
173                    // Record tool call
174                    if let Some(tracer) = &self.tracer {
175                        tracer.record_tool_call(
176                            &tool_call.name,
177                            tool_call.arguments.clone(),
178                            output.clone(),
179                            Some("LlmBroker".to_string()),
180                            Some(tool_duration_ms),
181                            "LlmBroker",
182                            correlation_id,
183                        );
184                    }
185
186                    // Add tool call and response to messages
187                    messages.push(LlmMessage {
188                        role: MessageRole::Assistant,
189                        content: None,
190                        tool_calls: Some(vec![tool_call.clone()]),
191                        image_paths: None,
192                    });
193                    messages.push(LlmMessage {
194                        role: MessageRole::Tool,
195                        content: Some(serde_json::to_string(&output)?),
196                        tool_calls: Some(vec![tool_call.clone()]),
197                        image_paths: None,
198                    });
199
200                    // Recursively call generate with updated messages, passing correlation_id
201                    return self
202                        .generate(
203                            &messages,
204                            Some(tools),
205                            Some(config.clone()),
206                            Some(correlation_id.to_string()),
207                        )
208                        .await;
209                } else {
210                    warn!("Tool not found: {}", tool_call.name);
211                }
212            }
213
214            Ok(response.content.unwrap_or_default())
215        })
216    }
217
218    /// Generate structured object response from LLM
219    ///
220    /// # Arguments
221    ///
222    /// * `messages` - The messages to send to the LLM
223    /// * `config` - Optional completion configuration
224    /// * `correlation_id` - Optional correlation ID for tracing (generates UUID if None)
225    pub async fn generate_object<T>(
226        &self,
227        messages: &[LlmMessage],
228        config: Option<CompletionConfig>,
229        correlation_id: Option<String>,
230    ) -> Result<T>
231    where
232        T: for<'de> Deserialize<'de> + Serialize + schemars::JsonSchema + Send,
233    {
234        let config = config.unwrap_or_default();
235        let correlation_id = correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string());
236
237        // Generate JSON schema for the type
238        let schema = serde_json::to_value(schemars::schema_for!(T))?;
239
240        // Record LLM call
241        if let Some(tracer) = &self.tracer {
242            let messages_json: Vec<std::collections::HashMap<String, serde_json::Value>> = messages
243                .iter()
244                .map(|m| {
245                    let mut map = std::collections::HashMap::new();
246                    map.insert("role".to_string(), serde_json::json!(format!("{:?}", m.role)));
247                    if let Some(content) = &m.content {
248                        map.insert("content".to_string(), serde_json::json!(content));
249                    }
250                    map
251                })
252                .collect();
253
254            tracer.record_llm_call(
255                &self.model,
256                messages_json,
257                config.temperature as f64,
258                None,
259                "LlmBroker::generate_object",
260                &correlation_id,
261            );
262        }
263
264        // Measure call duration
265        let start = std::time::Instant::now();
266
267        // Call the gateway with the schema
268        let json_response =
269            self.gateway.complete_json(&self.model, messages, schema, &config).await?;
270
271        let call_duration_ms = start.elapsed().as_secs_f64() * 1000.0;
272
273        // Deserialize the JSON into the target type
274        let object: T = serde_json::from_value(json_response.clone())?;
275
276        // Record LLM response
277        if let Some(tracer) = &self.tracer {
278            let object_str = serde_json::to_string_pretty(&json_response).unwrap_or_default();
279            tracer.record_llm_response(
280                &self.model,
281                format!("Structured response: {}", object_str),
282                None,
283                Some(call_duration_ms),
284                "LlmBroker::generate_object",
285                &correlation_id,
286            );
287        }
288
289        Ok(object)
290    }
291
    /// Generate streaming text response from LLM
    ///
    /// Returns a stream that yields content chunks as they arrive. When tool calls
    /// are detected, the broker executes them and recursively streams the LLM's
    /// follow-up response.
    ///
    /// # Arguments
    ///
    /// * `messages` - The messages to send to the LLM
    /// * `tools` - Optional tools available to the LLM
    /// * `config` - Optional completion configuration
    /// * `correlation_id` - Optional correlation ID for tracing (generates UUID if None)
    ///
    /// # Example
    ///
    /// ```ignore
    /// use futures::stream::StreamExt;
    ///
    /// let broker = LlmBroker::new("qwen3:32b", gateway, None);
    /// let messages = vec![LlmMessage::user("Tell me a story")];
    ///
    /// let mut stream = broker.generate_stream(&messages, None, None, None);
    /// while let Some(result) = stream.next().await {
    ///     match result {
    ///         Ok(chunk) => print!("{}", chunk),
    ///         Err(e) => eprintln!("Error: {}", e),
    ///     }
    /// }
    /// ```
    pub fn generate_stream<'a>(
        &'a self,
        messages: &'a [LlmMessage],
        tools: Option<&'a [Box<dyn LlmTool>]>,
        config: Option<CompletionConfig>,
        correlation_id: Option<String>,
    ) -> Pin<Box<dyn Stream<Item = Result<String>> + 'a>> {
        let config = config.unwrap_or_default();
        let current_messages = messages.to_vec();
        let correlation_id = correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string());

        // Everything inside the stream! block runs lazily, only once the
        // returned stream is polled.
        Box::pin(async_stream::stream! {
            // Record LLM call
            if let Some(tracer) = &self.tracer {
                let messages_json: Vec<std::collections::HashMap<String, serde_json::Value>> =
                    current_messages
                        .iter()
                        .map(|m| {
                            let mut map = std::collections::HashMap::new();
                            map.insert("role".to_string(), serde_json::json!(format!("{:?}", m.role)));
                            if let Some(content) = &m.content {
                                map.insert("content".to_string(), serde_json::json!(content));
                            }
                            map
                        })
                        .collect();

                let tools_json = tools.map(|t| {
                    t.iter()
                        .map(|tool| {
                            let desc = tool.descriptor();
                            let mut map = std::collections::HashMap::new();
                            map.insert("name".to_string(), serde_json::json!(desc.function.name));
                            map.insert(
                                "description".to_string(),
                                serde_json::json!(desc.function.description),
                            );
                            map
                        })
                        .collect()
                });

                tracer.record_llm_call(
                    &self.model,
                    messages_json,
                    config.temperature as f64,
                    tools_json,
                    "LlmBroker::generate_stream",
                    &correlation_id,
                );
            }

            // Content is both yielded to the caller and accumulated so the
            // full text can be traced and replayed into the tool-call round.
            let mut accumulated_content = String::new();
            let mut accumulated_tool_calls = Vec::new();

            // Measure stream duration
            let start = std::time::Instant::now();

            // Stream from gateway
            let mut stream = self.gateway.complete_stream(
                &self.model,
                &current_messages,
                tools,
                &config,
            );

            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(StreamChunk::Content(content)) => {
                        accumulated_content.push_str(&content);
                        yield Ok(content);
                    }
                    Ok(StreamChunk::ToolCalls(tool_calls)) => {
                        // Later ToolCalls chunks replace earlier ones.
                        accumulated_tool_calls = tool_calls;
                    }
                    Err(e) => {
                        // First gateway error ends the stream.
                        yield Err(e);
                        return;
                    }
                }
            }

            let call_duration_ms = start.elapsed().as_secs_f64() * 1000.0;

            // Record LLM response
            if let Some(tracer) = &self.tracer {
                let tool_calls_json = if !accumulated_tool_calls.is_empty() {
                    Some(
                        accumulated_tool_calls
                            .iter()
                            .map(|tc| {
                                let mut map = std::collections::HashMap::new();
                                map.insert("name".to_string(), serde_json::json!(&tc.name));
                                if let Some(id) = &tc.id {
                                    map.insert("id".to_string(), serde_json::json!(id));
                                }
                                map
                            })
                            .collect(),
                    )
                } else {
                    None
                };

                tracer.record_llm_response(
                    &self.model,
                    &accumulated_content,
                    tool_calls_json,
                    Some(call_duration_ms),
                    "LlmBroker::generate_stream",
                    &correlation_id,
                );
            }

            // Handle tool calls if present
            if !accumulated_tool_calls.is_empty() {
                if let Some(tools) = tools {
                    info!("Processing {} tool call(s) in stream", accumulated_tool_calls.len());

                    // Build new messages with tool results
                    let mut new_messages = current_messages.clone();

                    // Add assistant message with tool calls
                    new_messages.push(LlmMessage {
                        role: MessageRole::Assistant,
                        content: Some(accumulated_content),
                        tool_calls: Some(accumulated_tool_calls.clone()),
                        image_paths: None,
                    });

                    // Execute tools and add results
                    for tool_call in &accumulated_tool_calls {
                        if let Some(tool) = tools.iter().find(|t| t.matches(&tool_call.name)) {
                            info!("Executing tool: {}", tool_call.name);

                            // Measure tool execution time
                            let tool_start = std::time::Instant::now();

                            match tool.run(&tool_call.arguments) {
                                Ok(output) => {
                                    let tool_duration_ms = tool_start.elapsed().as_secs_f64() * 1000.0;

                                    // Record tool call
                                    if let Some(tracer) = &self.tracer {
                                        tracer.record_tool_call(
                                            &tool_call.name,
                                            tool_call.arguments.clone(),
                                            output.clone(),
                                            Some("LlmBroker::generate_stream".to_string()),
                                            Some(tool_duration_ms),
                                            "LlmBroker::generate_stream",
                                            &correlation_id,
                                        );
                                    }

                                    let output_str = match serde_json::to_string(&output) {
                                        Ok(s) => s,
                                        Err(e) => {
                                            yield Err(e.into());
                                            return;
                                        }
                                    };

                                    new_messages.push(LlmMessage {
                                        role: MessageRole::Tool,
                                        content: Some(output_str),
                                        tool_calls: Some(vec![tool_call.clone()]),
                                        image_paths: None,
                                    });
                                }
                                Err(e) => {
                                    // A failing tool aborts the whole stream.
                                    warn!("Tool execution failed: {}", e);
                                    yield Err(e);
                                    return;
                                }
                            }
                        } else {
                            warn!("Tool not found: {}", tool_call.name);
                        }
                    }

                    // Recursively stream with updated messages, passing correlation_id
                    // NOTE(review): this recursion happens even if none of the
                    // requested tools matched, i.e. no Tool result was appended —
                    // confirm the follow-up call cannot loop indefinitely then.
                    let mut recursive_stream = self.generate_stream(&new_messages, Some(tools), Some(config.clone()), Some(correlation_id.clone()));

                    while let Some(result) = recursive_stream.next().await {
                        yield result;
                    }
                } else {
                    warn!("LLM requested tool calls but no tools provided");
                }
            }
        })
    }
514}
515
516#[cfg(test)]
517mod tests {
518    use super::*;
519    use crate::llm::models::LlmToolCall;
520    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
521    use serde::{Deserialize, Serialize};
522    use serde_json::Value;
523    use std::collections::HashMap;
524
    // Mock gateway for testing: replays a fixed script of responses, one per
    // `complete` call, then falls back to a default response.
    struct MockGateway {
        // Scripted responses, consumed in order by call index.
        responses: Vec<LlmGatewayResponse>,
        // Number of `complete` calls so far; Mutex because the trait takes &self.
        call_count: std::sync::Mutex<usize>,
    }

    impl MockGateway {
        // Build a gateway that will return `responses` in sequence.
        fn new(responses: Vec<LlmGatewayResponse>) -> Self {
            Self {
                responses,
                call_count: std::sync::Mutex::new(0),
            }
        }
    }
539
540    #[async_trait::async_trait]
541    impl LlmGateway for MockGateway {
542        async fn complete(
543            &self,
544            _model: &str,
545            _messages: &[LlmMessage],
546            _tools: Option<&[Box<dyn LlmTool>]>,
547            _config: &CompletionConfig,
548        ) -> Result<LlmGatewayResponse> {
549            let mut count = self.call_count.lock().unwrap();
550            let idx = *count;
551            *count += 1;
552
553            if idx < self.responses.len() {
554                Ok(self.responses[idx].clone())
555            } else {
556                Ok(LlmGatewayResponse {
557                    content: Some("default response".to_string()),
558                    object: None,
559                    tool_calls: vec![],
560                    thinking: None,
561                })
562            }
563        }
564
565        async fn complete_json(
566            &self,
567            _model: &str,
568            _messages: &[LlmMessage],
569            _schema: Value,
570            _config: &CompletionConfig,
571        ) -> Result<Value> {
572            Ok(serde_json::json!({"test": "value"}))
573        }
574
575        async fn get_available_models(&self) -> Result<Vec<String>> {
576            Ok(vec!["test-model".to_string()])
577        }
578
579        async fn calculate_embeddings(
580            &self,
581            _text: &str,
582            _model: Option<&str>,
583        ) -> Result<Vec<f32>> {
584            Ok(vec![0.1, 0.2, 0.3])
585        }
586
587        fn complete_stream<'a>(
588            &'a self,
589            _model: &'a str,
590            _messages: &'a [LlmMessage],
591            _tools: Option<&'a [Box<dyn LlmTool>]>,
592            _config: &'a CompletionConfig,
593        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
594            use futures::stream;
595            Box::pin(stream::iter(vec![Ok(StreamChunk::Content("test".to_string()))]))
596        }
597    }
598
    // Mock tool for testing: always succeeds and returns a fixed JSON value.
    struct MockTool {
        // Name the broker matches tool calls against.
        name: String,
        // Value returned by every `run` invocation.
        result: Value,
    }

    impl LlmTool for MockTool {
        // Ignores the arguments and returns the canned result.
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(self.result.clone())
        }

        // Minimal descriptor with an empty parameter schema.
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: self.name.clone(),
                    description: "A mock tool".to_string(),
                    parameters: serde_json::json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(MockTool {
                name: self.name.clone(),
                result: self.result.clone(),
            })
        }
    }
628
629    #[tokio::test]
630    async fn test_broker_new() {
631        let gateway = Arc::new(MockGateway::new(vec![]));
632        let broker = LlmBroker::new("test-model", gateway, None);
633        assert_eq!(broker.model, "test-model");
634    }
635
636    #[tokio::test]
637    async fn test_broker_new_string_conversion() {
638        let gateway = Arc::new(MockGateway::new(vec![]));
639        let broker = LlmBroker::new(String::from("my-model"), gateway, None);
640        assert_eq!(broker.model, "my-model");
641    }
642
643    #[tokio::test]
644    async fn test_generate_simple_response() {
645        let response = LlmGatewayResponse {
646            content: Some("Hello, World!".to_string()),
647            object: None,
648            tool_calls: vec![],
649            thinking: None,
650        };
651
652        let gateway = Arc::new(MockGateway::new(vec![response]));
653        let broker = LlmBroker::new("test-model", gateway, None);
654
655        let messages = vec![LlmMessage::user("Hi")];
656        let result = broker.generate(&messages, None, None, None).await.unwrap();
657
658        assert_eq!(result, "Hello, World!");
659    }
660
661    #[tokio::test]
662    async fn test_generate_with_custom_config() {
663        let response = LlmGatewayResponse {
664            content: Some("Response".to_string()),
665            object: None,
666            tool_calls: vec![],
667            thinking: None,
668        };
669
670        let gateway = Arc::new(MockGateway::new(vec![response]));
671        let broker = LlmBroker::new("test-model", gateway, None);
672
673        let config = CompletionConfig {
674            temperature: 0.5,
675            num_ctx: 2048,
676            max_tokens: 100,
677            num_predict: Some(50),
678            top_p: None,
679            top_k: None,
680            response_format: None,
681            reasoning_effort: None,
682        };
683
684        let messages = vec![LlmMessage::user("Hi")];
685        let result = broker.generate(&messages, None, Some(config), None).await.unwrap();
686
687        assert_eq!(result, "Response");
688    }
689
690    #[tokio::test]
691    async fn test_generate_empty_response_content() {
692        let response = LlmGatewayResponse {
693            content: None,
694            object: None,
695            tool_calls: vec![],
696            thinking: None,
697        };
698
699        let gateway = Arc::new(MockGateway::new(vec![response]));
700        let broker = LlmBroker::new("test-model", gateway, None);
701
702        let messages = vec![LlmMessage::user("Hi")];
703        let result = broker.generate(&messages, None, None, None).await.unwrap();
704
705        assert_eq!(result, "");
706    }
707
708    #[tokio::test]
709    async fn test_generate_with_tool_call() {
710        let tool_call = LlmToolCall {
711            id: Some("call_1".to_string()),
712            name: "test_tool".to_string(),
713            arguments: HashMap::new(),
714        };
715
716        let first_response = LlmGatewayResponse {
717            content: None,
718            object: None,
719            tool_calls: vec![tool_call],
720            thinking: None,
721        };
722
723        let second_response = LlmGatewayResponse {
724            content: Some("After tool execution".to_string()),
725            object: None,
726            tool_calls: vec![],
727            thinking: None,
728        };
729
730        let gateway = Arc::new(MockGateway::new(vec![first_response, second_response]));
731        let broker = LlmBroker::new("test-model", gateway, None);
732
733        let tool = MockTool {
734            name: "test_tool".to_string(),
735            result: serde_json::json!({"result": "success"}),
736        };
737
738        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(tool)];
739
740        let messages = vec![LlmMessage::user("Use the tool")];
741        let result = broker.generate(&messages, Some(&tools), None, None).await.unwrap();
742
743        assert_eq!(result, "After tool execution");
744    }
745
746    #[tokio::test]
747    async fn test_generate_with_tool_call_no_tools_provided() {
748        let tool_call = LlmToolCall {
749            id: Some("call_1".to_string()),
750            name: "test_tool".to_string(),
751            arguments: HashMap::new(),
752        };
753
754        let response = LlmGatewayResponse {
755            content: Some("fallback".to_string()),
756            object: None,
757            tool_calls: vec![tool_call],
758            thinking: None,
759        };
760
761        let gateway = Arc::new(MockGateway::new(vec![response]));
762        let broker = LlmBroker::new("test-model", gateway, None);
763
764        let messages = vec![LlmMessage::user("Use the tool")];
765        let result = broker.generate(&messages, None, None, None).await.unwrap();
766
767        assert_eq!(result, "fallback");
768    }
769
770    #[tokio::test]
771    async fn test_generate_object() {
772        #[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)]
773        struct TestObject {
774            test: String,
775        }
776
777        let gateway = Arc::new(MockGateway::new(vec![]));
778        let broker = LlmBroker::new("test-model", gateway, None);
779
780        let messages = vec![LlmMessage::user("Generate object")];
781        let result: TestObject = broker.generate_object(&messages, None, None).await.unwrap();
782
783        assert_eq!(result.test, "value");
784    }
785
786    #[tokio::test]
787    async fn test_generate_object_with_config() {
788        #[derive(Debug, Serialize, Deserialize, schemars::JsonSchema)]
789        struct TestData {
790            test: String,
791        }
792
793        let gateway = Arc::new(MockGateway::new(vec![]));
794        let broker = LlmBroker::new("test-model", gateway, None);
795
796        let config = CompletionConfig {
797            temperature: 0.1,
798            num_ctx: 1024,
799            max_tokens: 50,
800            num_predict: None,
801            top_p: None,
802            top_k: None,
803            response_format: None,
804            reasoning_effort: None,
805        };
806
807        let messages = vec![LlmMessage::user("Generate")];
808        let result: TestData = broker.generate_object(&messages, Some(config), None).await.unwrap();
809
810        assert_eq!(result.test, "value");
811    }
812
813    #[tokio::test]
814    async fn test_multiple_messages() {
815        let response = LlmGatewayResponse {
816            content: Some("Response to conversation".to_string()),
817            object: None,
818            tool_calls: vec![],
819            thinking: None,
820        };
821
822        let gateway = Arc::new(MockGateway::new(vec![response]));
823        let broker = LlmBroker::new("test-model", gateway, None);
824
825        let messages = vec![
826            LlmMessage::system("You are helpful"),
827            LlmMessage::user("First message"),
828            LlmMessage::assistant("First response"),
829            LlmMessage::user("Second message"),
830        ];
831
832        let result = broker.generate(&messages, None, None, None).await.unwrap();
833        assert_eq!(result, "Response to conversation");
834    }
835
836    #[tokio::test]
837    async fn test_generate_stream_basic() {
838        use futures::stream;
839
840        // Mock gateway that returns a simple stream
841        struct StreamingMockGateway;
842
843        #[async_trait::async_trait]
844        impl LlmGateway for StreamingMockGateway {
845            async fn complete(
846                &self,
847                _model: &str,
848                _messages: &[LlmMessage],
849                _tools: Option<&[Box<dyn LlmTool>]>,
850                _config: &CompletionConfig,
851            ) -> Result<LlmGatewayResponse> {
852                Ok(LlmGatewayResponse {
853                    content: Some("test".to_string()),
854                    object: None,
855                    tool_calls: vec![],
856                    thinking: None,
857                })
858            }
859
860            async fn complete_json(
861                &self,
862                _model: &str,
863                _messages: &[LlmMessage],
864                _schema: Value,
865                _config: &CompletionConfig,
866            ) -> Result<Value> {
867                Ok(serde_json::json!({}))
868            }
869
870            async fn get_available_models(&self) -> Result<Vec<String>> {
871                Ok(vec![])
872            }
873
874            async fn calculate_embeddings(
875                &self,
876                _text: &str,
877                _model: Option<&str>,
878            ) -> Result<Vec<f32>> {
879                Ok(vec![])
880            }
881
882            fn complete_stream<'a>(
883                &'a self,
884                _model: &'a str,
885                _messages: &'a [LlmMessage],
886                _tools: Option<&'a [Box<dyn LlmTool>]>,
887                _config: &'a CompletionConfig,
888            ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
889                Box::pin(stream::iter(vec![
890                    Ok(StreamChunk::Content("Hello".to_string())),
891                    Ok(StreamChunk::Content(" ".to_string())),
892                    Ok(StreamChunk::Content("World".to_string())),
893                ]))
894            }
895        }
896
897        let gateway = Arc::new(StreamingMockGateway);
898        let broker = LlmBroker::new("test-model", gateway, None);
899        let messages = vec![LlmMessage::user("Hello")];
900
901        let mut stream = broker.generate_stream(&messages, None, None, None);
902        let mut result = String::new();
903
904        while let Some(chunk) = stream.next().await {
905            result.push_str(&chunk.unwrap());
906        }
907
908        assert_eq!(result, "Hello World");
909    }
910
911    #[tokio::test]
912    async fn test_generate_stream_with_tool_calls() {
913        use futures::stream;
914
915        // Mock gateway that returns tool calls
916        struct ToolCallMockGateway {
917            call_count: std::sync::Mutex<usize>,
918        }
919
920        impl ToolCallMockGateway {
921            fn new() -> Self {
922                Self {
923                    call_count: std::sync::Mutex::new(0),
924                }
925            }
926        }
927
928        #[async_trait::async_trait]
929        impl LlmGateway for ToolCallMockGateway {
930            async fn complete(
931                &self,
932                _model: &str,
933                _messages: &[LlmMessage],
934                _tools: Option<&[Box<dyn LlmTool>]>,
935                _config: &CompletionConfig,
936            ) -> Result<LlmGatewayResponse> {
937                Ok(LlmGatewayResponse {
938                    content: Some("test".to_string()),
939                    object: None,
940                    tool_calls: vec![],
941                    thinking: None,
942                })
943            }
944
945            async fn complete_json(
946                &self,
947                _model: &str,
948                _messages: &[LlmMessage],
949                _schema: Value,
950                _config: &CompletionConfig,
951            ) -> Result<Value> {
952                Ok(serde_json::json!({}))
953            }
954
955            async fn get_available_models(&self) -> Result<Vec<String>> {
956                Ok(vec![])
957            }
958
959            async fn calculate_embeddings(
960                &self,
961                _text: &str,
962                _model: Option<&str>,
963            ) -> Result<Vec<f32>> {
964                Ok(vec![])
965            }
966
967            fn complete_stream<'a>(
968                &'a self,
969                _model: &'a str,
970                _messages: &'a [LlmMessage],
971                _tools: Option<&'a [Box<dyn LlmTool>]>,
972                _config: &'a CompletionConfig,
973            ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
974                let mut count = self.call_count.lock().unwrap();
975                let call_num = *count;
976                *count += 1;
977
978                if call_num == 0 {
979                    // First call: return content with tool call
980                    Box::pin(stream::iter(vec![
981                        Ok(StreamChunk::Content("Initial ".to_string())),
982                        Ok(StreamChunk::Content("response".to_string())),
983                        Ok(StreamChunk::ToolCalls(vec![LlmToolCall {
984                            id: Some("call_1".to_string()),
985                            name: "test_tool".to_string(),
986                            arguments: HashMap::new(),
987                        }])),
988                    ]))
989                } else {
990                    // Second call (after tool execution): return final content
991                    Box::pin(stream::iter(vec![
992                        Ok(StreamChunk::Content("After ".to_string())),
993                        Ok(StreamChunk::Content("tool".to_string())),
994                    ]))
995                }
996            }
997        }
998
999        let gateway = Arc::new(ToolCallMockGateway::new());
1000        let broker = LlmBroker::new("test-model", gateway, None);
1001
1002        let tool = MockTool {
1003            name: "test_tool".to_string(),
1004            result: serde_json::json!({"result": "success"}),
1005        };
1006        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(tool)];
1007
1008        let messages = vec![LlmMessage::user("Use the tool")];
1009        let mut stream = broker.generate_stream(&messages, Some(&tools), None, None);
1010
1011        let mut result = String::new();
1012        while let Some(chunk) = stream.next().await {
1013            result.push_str(&chunk.unwrap());
1014        }
1015
1016        // Should contain both initial response and post-tool response
1017        assert!(result.contains("Initial response"));
1018        assert!(result.contains("After tool"));
1019    }
1020
1021    #[tokio::test]
1022    async fn test_generate_stream_without_tools() {
1023        use futures::stream;
1024
1025        struct SimpleStreamGateway;
1026
1027        #[async_trait::async_trait]
1028        impl LlmGateway for SimpleStreamGateway {
1029            async fn complete(
1030                &self,
1031                _model: &str,
1032                _messages: &[LlmMessage],
1033                _tools: Option<&[Box<dyn LlmTool>]>,
1034                _config: &CompletionConfig,
1035            ) -> Result<LlmGatewayResponse> {
1036                Ok(LlmGatewayResponse {
1037                    content: Some("test".to_string()),
1038                    object: None,
1039                    tool_calls: vec![],
1040                    thinking: None,
1041                })
1042            }
1043
1044            async fn complete_json(
1045                &self,
1046                _model: &str,
1047                _messages: &[LlmMessage],
1048                _schema: Value,
1049                _config: &CompletionConfig,
1050            ) -> Result<Value> {
1051                Ok(serde_json::json!({}))
1052            }
1053
1054            async fn get_available_models(&self) -> Result<Vec<String>> {
1055                Ok(vec![])
1056            }
1057
1058            async fn calculate_embeddings(
1059                &self,
1060                _text: &str,
1061                _model: Option<&str>,
1062            ) -> Result<Vec<f32>> {
1063                Ok(vec![])
1064            }
1065
1066            fn complete_stream<'a>(
1067                &'a self,
1068                _model: &'a str,
1069                _messages: &'a [LlmMessage],
1070                _tools: Option<&'a [Box<dyn LlmTool>]>,
1071                _config: &'a CompletionConfig,
1072            ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
1073                // Simple stream with no tool calls
1074                Box::pin(stream::iter(vec![
1075                    Ok(StreamChunk::Content("Simple ".to_string())),
1076                    Ok(StreamChunk::Content("stream".to_string())),
1077                ]))
1078            }
1079        }
1080
1081        let gateway = Arc::new(SimpleStreamGateway);
1082        let broker = LlmBroker::new("test-model", gateway, None);
1083
1084        let messages = vec![LlmMessage::user("Test")];
1085        let mut stream = broker.generate_stream(&messages, None, None, None);
1086
1087        let mut result = String::new();
1088        while let Some(chunk) = stream.next().await {
1089            result.push_str(&chunk.unwrap());
1090        }
1091
1092        assert_eq!(result, "Simple stream");
1093    }
1094
1095    #[tokio::test]
1096    async fn test_tracer_integration() {
1097        use crate::tracer::TracerSystem;
1098
1099        let response = LlmGatewayResponse {
1100            content: Some("Test response".to_string()),
1101            object: None,
1102            tool_calls: vec![],
1103            thinking: None,
1104        };
1105
1106        let gateway = Arc::new(MockGateway::new(vec![response]));
1107        let tracer = Arc::new(TracerSystem::default());
1108        let broker = LlmBroker::new("test-model", gateway, Some(tracer.clone()));
1109
1110        let messages = vec![LlmMessage::user("Test")];
1111        let correlation_id = "test-correlation-123";
1112
1113        let result = broker
1114            .generate(&messages, None, None, Some(correlation_id.to_string()))
1115            .await
1116            .unwrap();
1117
1118        assert_eq!(result, "Test response");
1119
1120        // Verify tracer recorded events
1121        assert_eq!(tracer.len(), 2); // One LLM call + one LLM response
1122
1123        // Verify correlation ID is preserved
1124        let summaries = tracer.get_event_summaries(None, None, None);
1125        assert!(summaries[0].contains(correlation_id));
1126        assert!(summaries[1].contains(correlation_id));
1127    }
1128
1129    #[tokio::test]
1130    async fn test_tracer_with_tool_calls() {
1131        use crate::tracer::TracerSystem;
1132
1133        let tool_call = LlmToolCall {
1134            id: Some("call_1".to_string()),
1135            name: "test_tool".to_string(),
1136            arguments: HashMap::new(),
1137        };
1138
1139        let first_response = LlmGatewayResponse {
1140            content: None,
1141            object: None,
1142            tool_calls: vec![tool_call],
1143            thinking: None,
1144        };
1145
1146        let second_response = LlmGatewayResponse {
1147            content: Some("After tool".to_string()),
1148            object: None,
1149            tool_calls: vec![],
1150            thinking: None,
1151        };
1152
1153        let gateway = Arc::new(MockGateway::new(vec![first_response, second_response]));
1154        let tracer = Arc::new(TracerSystem::default());
1155        let broker = LlmBroker::new("test-model", gateway, Some(tracer.clone()));
1156
1157        let tool = MockTool {
1158            name: "test_tool".to_string(),
1159            result: serde_json::json!({"result": "success"}),
1160        };
1161        let tools: Vec<Box<dyn LlmTool>> = vec![Box::new(tool)];
1162
1163        let messages = vec![LlmMessage::user("Use tool")];
1164        let correlation_id = "tool-test-456";
1165
1166        let result = broker
1167            .generate(&messages, Some(&tools), None, Some(correlation_id.to_string()))
1168            .await
1169            .unwrap();
1170
1171        assert_eq!(result, "After tool");
1172
1173        // Should have: 2 LLM calls, 2 LLM responses, 1 tool call
1174        assert_eq!(tracer.len(), 5);
1175
1176        // Verify all events share the same correlation ID
1177        let summaries = tracer.get_event_summaries(None, None, None);
1178        for summary in &summaries {
1179            assert!(summary.contains(correlation_id));
1180        }
1181    }
1182}