// mojentic/llm/chat_session.rs

1//! Chat session management with context window tracking.
2//!
3//! This module provides a chat session abstraction that manages conversation history
4//! and automatically handles context window limits using token counting.
5
6use crate::error::Result;
7use crate::llm::broker::LlmBroker;
8use crate::llm::gateway::CompletionConfig;
9use crate::llm::gateways::TokenizerGateway;
10use crate::llm::models::{LlmMessage, MessageRole};
11use crate::llm::tools::LlmTool;
12use futures::stream::{Stream, StreamExt};
13use serde::{Deserialize, Serialize};
14use std::pin::Pin;
15
/// An LLM message with token count metadata.
///
/// This extends the standard `LlmMessage` with token length information
/// for context window management.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SizedLlmMessage {
    /// The wrapped message. `#[serde(flatten)]` inlines its fields, so the
    /// serialized form looks like a plain `LlmMessage` plus `token_length`.
    #[serde(flatten)]
    pub message: LlmMessage,
    /// Token count of `message.content` (0 when the message has no content).
    pub token_length: usize,
}
26
27impl SizedLlmMessage {
28    /// Create a new sized message
29    pub fn new(message: LlmMessage, token_length: usize) -> Self {
30        Self {
31            message,
32            token_length,
33        }
34    }
35
36    /// Get the role of the message
37    pub fn role(&self) -> MessageRole {
38        self.message.role
39    }
40
41    /// Get the content of the message
42    pub fn content(&self) -> Option<&str> {
43        self.message.content.as_deref()
44    }
45}
46
/// A chat session that manages conversation history with context window limits.
///
/// `ChatSession` maintains a list of messages and automatically trims old messages
/// when the total token count exceeds the configured maximum context size. The system
/// prompt (first message) is always preserved.
///
/// # Examples
///
/// ```ignore
/// use mojentic::llm::{ChatSession, LlmBroker};
/// use mojentic::llm::gateways::OllamaGateway;
/// use std::sync::Arc;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let gateway = Arc::new(OllamaGateway::default());
///     let broker = LlmBroker::new("qwen3:32b", gateway, None);
///     let mut session = ChatSession::new(broker);
///
///     let response = session.send("What is Rust?").await?;
///     println!("Response: {}", response);
///
///     Ok(())
/// }
/// ```
pub struct ChatSession {
    /// Broker used to generate completions.
    broker: LlmBroker,
    /// Conversation history; index 0 is always the system prompt.
    messages: Vec<SizedLlmMessage>,
    /// Tools exposed to the LLM, if any.
    tools: Option<Vec<Box<dyn LlmTool>>>,
    /// Maximum total token budget for `messages` before trimming kicks in.
    max_context: usize,
    /// Tokenizer used to size messages for context accounting.
    tokenizer_gateway: TokenizerGateway,
    /// Sampling temperature passed through to the broker on every call.
    temperature: f32,
}
80
impl ChatSession {
    /// Create a new chat session with default settings.
    ///
    /// # Arguments
    ///
    /// * `broker` - The LLM broker to use for generating responses
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use mojentic::llm::{ChatSession, LlmBroker};
    /// use mojentic::llm::gateways::OllamaGateway;
    /// use std::sync::Arc;
    ///
    /// let gateway = Arc::new(OllamaGateway::default());
    /// let broker = LlmBroker::new("qwen3:32b", gateway, None);
    /// let session = ChatSession::new(broker);
    /// ```
    pub fn new(broker: LlmBroker) -> Self {
        Self::builder(broker).build()
    }

    /// Create a chat session builder for custom configuration.
    ///
    /// # Arguments
    ///
    /// * `broker` - The LLM broker to use for generating responses
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use mojentic::llm::ChatSession;
    ///
    /// let session = ChatSession::builder(broker)
    ///     .system_prompt("You are a helpful coding assistant.")
    ///     .temperature(0.7)
    ///     .max_context(16384)
    ///     .build();
    /// ```
    pub fn builder(broker: LlmBroker) -> ChatSessionBuilder {
        ChatSessionBuilder::new(broker)
    }

    /// Send a message to the LLM and get a response.
    ///
    /// This method:
    /// 1. Adds the user message to the conversation history
    /// 2. Generates a response using the LLM
    /// 3. Adds the assistant's response to the history
    /// 4. Automatically trims old messages if context window is exceeded
    ///
    /// # Arguments
    ///
    /// * `query` - The user's message
    ///
    /// # Returns
    ///
    /// The LLM's response as a string
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let response = session.send("What is 2 + 2?").await?;
    /// println!("Answer: {}", response);
    /// ```
    pub async fn send(&mut self, query: &str) -> Result<String> {
        // Add user message (sized and trimmed by insert_message).
        self.insert_message(LlmMessage::user(query));

        // Generate response. History is cloned so the broker call doesn't
        // hold a borrow of `self.messages` across the await.
        let messages: Vec<LlmMessage> = self.messages.iter().map(|m| m.message.clone()).collect();
        let config = CompletionConfig {
            temperature: self.temperature,
            ..Default::default()
        };

        let response = self
            .broker
            .generate(&messages, self.tools.as_deref(), Some(config), None)
            .await?;

        // Ensure all messages in history have token counts.
        // NOTE(review): insert_message already sizes every message it adds, so
        // this only matters if something else mutated history — presumably a
        // defensive re-sync; confirm whether the broker can alter history.
        self.ensure_all_messages_are_sized();

        // Add assistant response
        self.insert_message(LlmMessage::assistant(&response));

        Ok(response)
    }

    /// Send a message to the LLM and get a streaming response.
    ///
    /// This method:
    /// 1. Adds the user message to the conversation history
    /// 2. Streams the response from the LLM, yielding chunks as they arrive
    /// 3. After the stream is fully consumed, adds the assembled response to history
    /// 4. Automatically trims old messages if context window is exceeded
    ///
    /// The user message is recorded *before* the stream is polled; if the
    /// caller drops the returned stream before it finishes, the assistant
    /// response is never appended, leaving the user message unanswered in
    /// history. The returned stream mutably borrows the session for `'a`
    /// (and carries no `Send` bound).
    ///
    /// # Arguments
    ///
    /// * `query` - The user's message
    ///
    /// # Returns
    ///
    /// A stream of string chunks from the LLM response
    ///
    /// # Examples
    ///
    /// ```ignore
    /// use futures::stream::StreamExt;
    ///
    /// let mut stream = session.send_stream("Tell me a story");
    /// while let Some(result) = stream.next().await {
    ///     print!("{}", result?);
    /// }
    /// ```
    pub fn send_stream<'a>(
        &'a mut self,
        query: &str,
    ) -> Pin<Box<dyn Stream<Item = Result<String>> + 'a>> {
        // Add user message eagerly, before the stream is constructed.
        self.insert_message(LlmMessage::user(query));

        // Clone messages for the broker call so the inner stream doesn't
        // borrow `self.messages` directly.
        let messages: Vec<LlmMessage> = self.messages.iter().map(|m| m.message.clone()).collect();
        let config = CompletionConfig {
            temperature: self.temperature,
            ..Default::default()
        };

        Box::pin(async_stream::stream! {
            let mut accumulated = Vec::new();
            // Both of these borrow `self` immutably; the explicit drop below
            // ends the borrow before the mutable finalization calls.
            let tools_ref = self.tools.as_deref();
            let mut inner_stream = self.broker.generate_stream(&messages, tools_ref, Some(config), None);

            while let Some(result) = inner_stream.next().await {
                match &result {
                    Ok(chunk) => {
                        accumulated.push(chunk.clone());
                        yield result;
                    }
                    Err(_) => {
                        // On error: forward it and abandon the stream without
                        // recording a (partial) assistant response.
                        yield result;
                        return;
                    }
                }
            }

            // Stream consumed — finalize. Dropping the inner stream releases
            // its immutable borrow of `self` so we may mutate history.
            drop(inner_stream);
            self.ensure_all_messages_are_sized();
            let full_response = accumulated.join("");
            self.insert_message(LlmMessage::assistant(&full_response));
        })
    }

    /// Insert a message into the conversation history.
    ///
    /// If the total token count exceeds `max_context`, the oldest messages
    /// are removed until the total is under the limit. The system prompt
    /// (index 0) is always preserved.
    ///
    /// Note: if the newly inserted message alone exceeds `max_context`, the
    /// trimming loop will eventually remove it too — only the system prompt
    /// is guaranteed to survive.
    ///
    /// # Arguments
    ///
    /// * `message` - The message to add
    pub fn insert_message(&mut self, message: LlmMessage) {
        let sized_message = self.build_sized_message(message);
        self.messages.push(sized_message);

        // Trim messages if over context limit
        let mut total_length: usize = self.messages.iter().map(|m| m.token_length).sum();

        while total_length > self.max_context && self.messages.len() > 1 {
            // Remove the oldest message (index 1 to preserve system prompt at 0)
            let removed = self.messages.remove(1);
            total_length -= removed.token_length;
        }
    }

    /// Get the current conversation history
    pub fn messages(&self) -> &[SizedLlmMessage] {
        &self.messages
    }

    /// Get the total token count of the current conversation
    pub fn total_tokens(&self) -> usize {
        self.messages.iter().map(|m| m.token_length).sum()
    }

    /// Build a sized message from a regular message.
    /// Messages without content get a token count of 0.
    fn build_sized_message(&self, message: LlmMessage) -> SizedLlmMessage {
        let token_length = if let Some(content) = &message.content {
            self.tokenizer_gateway.encode(content).len()
        } else {
            0
        };

        SizedLlmMessage::new(message, token_length)
    }

    /// Ensure all messages in history have token counts.
    ///
    /// Uses `token_length == 0` as the "unsized" marker, so a message whose
    /// content genuinely encodes to 0 tokens (e.g. empty string) will be
    /// re-encoded on every call — harmless, but repeated work.
    fn ensure_all_messages_are_sized(&mut self) {
        for i in 0..self.messages.len() {
            if self.messages[i].token_length == 0 && self.messages[i].message.content.is_some() {
                let content = self.messages[i].message.content.clone().unwrap();
                let token_length = self.tokenizer_gateway.encode(&content).len();
                self.messages[i].token_length = token_length;
            }
        }
    }
}
292
/// Builder for constructing a `ChatSession` with custom configuration.
pub struct ChatSessionBuilder {
    /// Broker used by the built session.
    broker: LlmBroker,
    /// System prompt text (default: "You are a helpful assistant.").
    system_prompt: String,
    /// Tools exposed to the LLM (default: none).
    tools: Option<Vec<Box<dyn LlmTool>>>,
    /// Maximum context window in tokens (default: 32768).
    max_context: usize,
    /// Tokenizer override; `None` means the gateway's `Default` at build time.
    tokenizer_gateway: Option<TokenizerGateway>,
    /// Sampling temperature (default: 1.0).
    temperature: f32,
}
302
303impl ChatSessionBuilder {
304    /// Create a new builder
305    fn new(broker: LlmBroker) -> Self {
306        Self {
307            broker,
308            system_prompt: "You are a helpful assistant.".to_string(),
309            tools: None,
310            max_context: 32768,
311            tokenizer_gateway: None,
312            temperature: 1.0,
313        }
314    }
315
316    /// Set the system prompt (default: "You are a helpful assistant.")
317    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
318        self.system_prompt = prompt.into();
319        self
320    }
321
322    /// Set the tools available to the LLM
323    pub fn tools(mut self, tools: Vec<Box<dyn LlmTool>>) -> Self {
324        self.tools = Some(tools);
325        self
326    }
327
328    /// Set the maximum context window in tokens (default: 32768)
329    pub fn max_context(mut self, max_context: usize) -> Self {
330        self.max_context = max_context;
331        self
332    }
333
334    /// Set a custom tokenizer gateway (default: cl100k_base)
335    pub fn tokenizer_gateway(mut self, gateway: TokenizerGateway) -> Self {
336        self.tokenizer_gateway = Some(gateway);
337        self
338    }
339
340    /// Set the temperature for generation (default: 1.0)
341    pub fn temperature(mut self, temperature: f32) -> Self {
342        self.temperature = temperature;
343        self
344    }
345
346    /// Build the chat session
347    pub fn build(self) -> ChatSession {
348        let tokenizer_gateway = self.tokenizer_gateway.unwrap_or_default();
349        let system_message = LlmMessage::system(&self.system_prompt);
350        let token_length = tokenizer_gateway.encode(&self.system_prompt).len();
351
352        ChatSession {
353            broker: self.broker,
354            messages: vec![SizedLlmMessage::new(system_message, token_length)],
355            tools: self.tools,
356            max_context: self.max_context,
357            tokenizer_gateway,
358            temperature: self.temperature,
359        }
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{LlmGateway, StreamChunk};
    use crate::llm::models::LlmGatewayResponse;
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use futures::stream::{self, Stream};
    use serde_json::{json, Value};
    use std::collections::HashMap;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    // Mock gateway for testing: `complete` returns the canned `responses` in
    // order (tracked by `call_count`), falling back to "default response".
    struct MockGateway {
        responses: Vec<String>,
        call_count: Mutex<usize>,
    }

    impl MockGateway {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: Mutex::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            let mut count = self.call_count.lock().unwrap();
            let idx = *count;
            *count += 1;

            let content = if idx < self.responses.len() {
                self.responses[idx].clone()
            } else {
                "default response".to_string()
            };

            Ok(LlmGatewayResponse {
                content: Some(content),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content("test".to_string()))]))
        }
    }

    // Mock tool for testing: trivially succeeds and exposes a fixed descriptor.
    #[derive(Clone)]
    struct MockTool {
        name: String,
    }

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(json!({"result": "success"}))
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: self.name.clone(),
                    description: "A mock tool".to_string(),
                    parameters: json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(self.clone())
        }
    }

    // --- Construction and builder configuration ---

    #[tokio::test]
    async fn test_new_session_has_system_message() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::new(broker);

        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role(), MessageRole::System);
        assert_eq!(session.messages[0].content(), Some("You are a helpful assistant."));
    }

    #[tokio::test]
    async fn test_builder_custom_system_prompt() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::builder(broker).system_prompt("Custom system prompt").build();

        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].content(), Some("Custom system prompt"));
    }

    #[tokio::test]
    async fn test_builder_custom_temperature() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::builder(broker).temperature(0.5).build();

        assert_eq!(session.temperature, 0.5);
    }

    #[tokio::test]
    async fn test_builder_custom_max_context() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::builder(broker).max_context(16384).build();

        assert_eq!(session.max_context, 16384);
    }

    // --- Non-streaming send ---

    #[tokio::test]
    async fn test_send_adds_messages_to_history() {
        let gateway = Arc::new(MockGateway::new(vec!["Hello, World!".to_string()]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let response = session.send("Hi").await.unwrap();

        assert_eq!(response, "Hello, World!");
        // Should have: system, user, assistant
        assert_eq!(session.messages.len(), 3);
        assert_eq!(session.messages[1].role(), MessageRole::User);
        assert_eq!(session.messages[1].content(), Some("Hi"));
        assert_eq!(session.messages[2].role(), MessageRole::Assistant);
        assert_eq!(session.messages[2].content(), Some("Hello, World!"));
    }

    #[tokio::test]
    async fn test_send_multiple_turns() {
        let gateway = Arc::new(MockGateway::new(vec![
            "First response".to_string(),
            "Second response".to_string(),
        ]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        session.send("First query").await.unwrap();
        session.send("Second query").await.unwrap();

        // Should have: system, user1, assistant1, user2, assistant2
        assert_eq!(session.messages.len(), 5);
        assert_eq!(session.messages[3].content(), Some("Second query"));
        assert_eq!(session.messages[4].content(), Some("Second response"));
    }

    // --- Context accounting and trimming ---

    #[tokio::test]
    async fn test_insert_message_calculates_token_length() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        session.insert_message(LlmMessage::user("Hello"));

        assert_eq!(session.messages.len(), 2);
        assert!(session.messages[1].token_length > 0);
    }

    #[tokio::test]
    async fn test_context_window_trimming() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        // Create session with very small context window
        let mut session = ChatSession::builder(broker).max_context(50).build();

        // Add several messages with longer content to force trimming
        for i in 0..10 {
            session.insert_message(LlmMessage::user(format!(
                "This is a longer message number {} with more content to increase token count",
                i
            )));
        }

        // Should have trimmed old messages
        assert!(session.messages.len() < 11); // Less than 1 system + 10 user messages

        // System prompt should still be first
        assert_eq!(session.messages[0].role(), MessageRole::System);

        // Total tokens should be under limit
        assert!(session.total_tokens() <= 50);
    }

    #[tokio::test]
    async fn test_context_window_preserves_system_prompt() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let mut session = ChatSession::builder(broker)
            .system_prompt("Important system prompt")
            .max_context(50)
            .build();

        // Add many messages to force trimming
        for i in 0..20 {
            session.insert_message(LlmMessage::user(format!("Message {}", i)));
        }

        // System prompt should still be first
        assert_eq!(session.messages[0].role(), MessageRole::System);
        assert_eq!(session.messages[0].content(), Some("Important system prompt"));
    }

    #[tokio::test]
    async fn test_total_tokens() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let initial_tokens = session.total_tokens();
        assert!(initial_tokens > 0); // System prompt has tokens

        session.insert_message(LlmMessage::user("Hello"));

        assert!(session.total_tokens() > initial_tokens);
    }

    #[tokio::test]
    async fn test_messages_accessor() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        session.insert_message(LlmMessage::user("Test"));

        let messages = session.messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role(), MessageRole::System);
        assert_eq!(messages[1].role(), MessageRole::User);
    }

    #[tokio::test]
    async fn test_builder_with_tools() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let tool: Box<dyn LlmTool> = Box::new(MockTool {
            name: "test_tool".to_string(),
        });

        let session = ChatSession::builder(broker).tools(vec![tool]).build();

        assert!(session.tools.is_some());
        assert_eq!(session.tools.as_ref().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_sized_message_creation() {
        let message = LlmMessage::user("Test content");
        let sized = SizedLlmMessage::new(message, 5);

        assert_eq!(sized.token_length, 5);
        assert_eq!(sized.role(), MessageRole::User);
        assert_eq!(sized.content(), Some("Test content"));
    }

    // Streaming mock gateway: each `complete_stream` call yields the next
    // Vec of chunks from `stream_chunks`, falling back to ["default"].
    struct StreamingMockGateway {
        stream_chunks: Vec<Vec<String>>,
        call_count: Mutex<usize>,
    }

    impl StreamingMockGateway {
        fn new(stream_chunks: Vec<Vec<String>>) -> Self {
            Self {
                stream_chunks,
                call_count: Mutex::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for StreamingMockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            Ok(LlmGatewayResponse {
                content: Some("default".to_string()),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            let mut count = self.call_count.lock().unwrap();
            let idx = *count;
            *count += 1;

            let chunks = if idx < self.stream_chunks.len() {
                self.stream_chunks[idx].clone()
            } else {
                vec!["default".to_string()]
            };

            Box::pin(stream::iter(
                chunks.into_iter().map(|c| Ok(StreamChunk::Content(c))).collect::<Vec<_>>(),
            ))
        }
    }

    // --- Streaming send ---

    #[tokio::test]
    async fn test_send_stream_yields_content_chunks() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec![
            "Hello".to_string(),
            " world".to_string(),
        ]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let mut chunks = Vec::new();
        let mut stream = session.send_stream("Hi");
        while let Some(result) = stream.next().await {
            chunks.push(result.unwrap());
        }

        assert_eq!(chunks, vec!["Hello", " world"]);
    }

    #[tokio::test]
    async fn test_send_stream_grows_message_history() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec!["Response".to_string()]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        // Scope the stream so its mutable borrow of `session` ends before we
        // inspect history.
        {
            let mut stream = session.send_stream("Hi");
            while stream.next().await.is_some() {}
        }

        // system + user + assistant
        assert_eq!(session.messages.len(), 3);
    }

    #[tokio::test]
    async fn test_send_stream_records_assembled_response() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec![
            "Hello".to_string(),
            " world".to_string(),
        ]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        {
            let mut stream = session.send_stream("Hi");
            while stream.next().await.is_some() {}
        }

        assert_eq!(session.messages[2].content(), Some("Hello world"));
        assert_eq!(session.messages[2].role(), MessageRole::Assistant);
    }

    #[tokio::test]
    async fn test_send_stream_records_user_message() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec!["Response".to_string()]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        {
            let mut stream = session.send_stream("My question");
            while stream.next().await.is_some() {}
        }

        assert_eq!(session.messages[1].role(), MessageRole::User);
        assert_eq!(session.messages[1].content(), Some("My question"));
    }

    #[tokio::test]
    async fn test_send_stream_respects_context_capacity() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![
            vec!["This is a longer response to consume tokens in the context window".to_string()],
            vec!["Another longer response that also consumes many tokens in context".to_string()],
        ]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::builder(broker).max_context(50).build();

        {
            let mut stream = session.send_stream("First longer query message with extra words");
            while stream.next().await.is_some() {}
        }
        {
            let mut stream = session.send_stream("Second longer query message with extra words");
            while stream.next().await.is_some() {}
        }

        // System prompt should still be first
        assert_eq!(session.messages[0].role(), MessageRole::System);
        // Total tokens should be under limit
        assert!(session.total_tokens() <= 50);
    }

    #[tokio::test]
    async fn test_message_with_no_content_has_zero_tokens() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let message = LlmMessage {
            role: MessageRole::Assistant,
            content: None,
            tool_calls: None,
            image_paths: None,
        };

        session.insert_message(message);

        // Should have system + the message with no content
        assert_eq!(session.messages.len(), 2);
        assert_eq!(session.messages[1].token_length, 0);
    }
}