use crate::error::Result;
use crate::llm::broker::LlmBroker;
use crate::llm::gateway::CompletionConfig;
use crate::llm::gateways::TokenizerGateway;
use crate::llm::models::{LlmMessage, MessageRole};
use crate::llm::tools::LlmTool;
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::pin::Pin;

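/// An [`LlmMessage`] together with the token count of its content, as measured
/// by the session's tokenizer gateway. [`ChatSession`] uses these counts to
/// keep the history within its context window.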
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SizedLlmMessage {
    #[serde(flatten)]
    pub message: LlmMessage,
    pub token_length: usize,
}

impl SizedLlmMessage {
    pub fn new(message: LlmMessage, token_length: usize) -> Self {
        Self {
            message,
            token_length,
        }
    }

    pub fn role(&self) -> MessageRole {
        self.message.role
    }

    pub fn content(&self) -> Option<&str> {
        self.message.content.as_deref()
    }
}

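/// A multi-turn chat session on top of an [`LlmBroker`].
///
/// The session keeps the full message history (system prompt first), sizes
/// every message with a tokenizer, and evicts the oldest non-system messages
/// once the total token count exceeds `max_context`.
///
/// A minimal usage sketch, assuming an `LlmBroker` has already been
/// constructed for your gateway and model:
///
/// ```ignore
/// let mut session = ChatSession::builder(broker)
///     .system_prompt("You are a terse assistant.")
///     .max_context(8192)
///     .temperature(0.2)
///     .build();
///
/// let reply = session.send("Summarize the last message.").await?;
/// println!("{reply}");
/// ```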
pub struct ChatSession {
    broker: LlmBroker,
    messages: Vec<SizedLlmMessage>,
    tools: Option<Vec<Box<dyn LlmTool>>>,
    max_context: usize,
    tokenizer_gateway: TokenizerGateway,
    temperature: f32,
}

impl ChatSession {
    /// Creates a session with the default builder settings.
    pub fn new(broker: LlmBroker) -> Self {
        Self::builder(broker).build()
    }

    /// Returns a builder for configuring a session before constructing it.
    pub fn builder(broker: LlmBroker) -> ChatSessionBuilder {
        ChatSessionBuilder::new(broker)
    }

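    /// Sends `query` as a user message, waits for the complete response,
    /// records both messages in the history, and returns the response text.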
    pub async fn send(&mut self, query: &str) -> Result<String> {
        self.insert_message(LlmMessage::user(query));

        let messages: Vec<LlmMessage> = self.messages.iter().map(|m| m.message.clone()).collect();
        let config = CompletionConfig {
            temperature: self.temperature,
            ..Default::default()
        };

        let response = self
            .broker
            .generate(&messages, self.tools.as_deref(), Some(config), None)
            .await?;

        self.ensure_all_messages_are_sized();

        self.insert_message(LlmMessage::assistant(&response));

        Ok(response)
    }

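    /// Streaming variant of [`send`](Self::send). Yields response chunks as
    /// they arrive; once the stream completes, the assembled response is
    /// recorded in the history as an assistant message. If a chunk is an
    /// error, it is yielded and the stream ends without recording a response.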
    pub fn send_stream<'a>(
        &'a mut self,
        query: &str,
    ) -> Pin<Box<dyn Stream<Item = Result<String>> + 'a>> {
        self.insert_message(LlmMessage::user(query));

        let messages: Vec<LlmMessage> = self.messages.iter().map(|m| m.message.clone()).collect();
        let config = CompletionConfig {
            temperature: self.temperature,
            ..Default::default()
        };

        Box::pin(async_stream::stream! {
            let mut accumulated = Vec::new();
            let tools_ref = self.tools.as_deref();
            let mut inner_stream = self.broker.generate_stream(&messages, tools_ref, Some(config), None);

            while let Some(result) = inner_stream.next().await {
                match &result {
                    Ok(chunk) => {
                        accumulated.push(chunk.clone());
                        yield result;
                    }
                    Err(_) => {
                        yield result;
                        return;
                    }
                }
            }

            // Release the borrow held by the inner stream before mutating the
            // message history below.
            drop(inner_stream);
            self.ensure_all_messages_are_sized();
            let full_response = accumulated.join("");
            self.insert_message(LlmMessage::assistant(&full_response));
        })
    }

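    /// Appends `message` to the history with its computed token length, then
    /// evicts the oldest non-system messages until the total token count fits
    /// within `max_context`.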
    pub fn insert_message(&mut self, message: LlmMessage) {
        let sized_message = self.build_sized_message(message);
        self.messages.push(sized_message);

        let mut total_length: usize = self.messages.iter().map(|m| m.token_length).sum();

        // Index 0 holds the system prompt and is never evicted.
        while total_length > self.max_context && self.messages.len() > 1 {
            let removed = self.messages.remove(1);
            total_length -= removed.token_length;
        }
    }

    /// Returns the full message history, including the system prompt.
    pub fn messages(&self) -> &[SizedLlmMessage] {
        &self.messages
    }

    /// Returns the total token count of the current history.
    pub fn total_tokens(&self) -> usize {
        self.messages.iter().map(|m| m.token_length).sum()
    }

    fn build_sized_message(&self, message: LlmMessage) -> SizedLlmMessage {
        let token_length = if let Some(content) = &message.content {
            self.tokenizer_gateway.encode(content).len()
        } else {
            0
        };

        SizedLlmMessage::new(message, token_length)
    }

    /// Recomputes the token length of any message that has content but is
    /// still recorded with a zero token count.
    fn ensure_all_messages_are_sized(&mut self) {
        for i in 0..self.messages.len() {
            if self.messages[i].token_length == 0 && self.messages[i].message.content.is_some() {
                let content = self.messages[i].message.content.clone().unwrap();
                let token_length = self.tokenizer_gateway.encode(&content).len();
                self.messages[i].token_length = token_length;
            }
        }
    }
}

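/// Builder for [`ChatSession`]. Defaults to a generic system prompt, a
/// 32768-token context window, no tools, the default tokenizer gateway, and a
/// temperature of 1.0.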
pub struct ChatSessionBuilder {
    broker: LlmBroker,
    system_prompt: String,
    tools: Option<Vec<Box<dyn LlmTool>>>,
    max_context: usize,
    tokenizer_gateway: Option<TokenizerGateway>,
    temperature: f32,
}

impl ChatSessionBuilder {
    fn new(broker: LlmBroker) -> Self {
        Self {
            broker,
            system_prompt: "You are a helpful assistant.".to_string(),
            tools: None,
            max_context: 32768,
            tokenizer_gateway: None,
            temperature: 1.0,
        }
    }

    /// Sets the system prompt inserted as the first message of the session.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the tools made available to the model.
    pub fn tools(mut self, tools: Vec<Box<dyn LlmTool>>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets the maximum number of tokens kept in the message history.
    pub fn max_context(mut self, max_context: usize) -> Self {
        self.max_context = max_context;
        self
    }

    /// Sets the tokenizer gateway used to size messages.
    pub fn tokenizer_gateway(mut self, gateway: TokenizerGateway) -> Self {
        self.tokenizer_gateway = Some(gateway);
        self
    }

    /// Sets the sampling temperature passed to the model.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    /// Builds the [`ChatSession`], sizing the system prompt with the
    /// configured (or default) tokenizer gateway.
    pub fn build(self) -> ChatSession {
        let tokenizer_gateway = self.tokenizer_gateway.unwrap_or_default();
        let system_message = LlmMessage::system(&self.system_prompt);
        let token_length = tokenizer_gateway.encode(&self.system_prompt).len();

        ChatSession {
            broker: self.broker,
            messages: vec![SizedLlmMessage::new(system_message, token_length)],
            tools: self.tools,
            max_context: self.max_context,
            tokenizer_gateway,
            temperature: self.temperature,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::gateway::{LlmGateway, StreamChunk};
    use crate::llm::models::LlmGatewayResponse;
    use crate::llm::tools::{FunctionDescriptor, ToolDescriptor};
    use futures::stream::{self, Stream};
    use serde_json::{json, Value};
    use std::collections::HashMap;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};

    struct MockGateway {
        responses: Vec<String>,
        call_count: Mutex<usize>,
    }

    impl MockGateway {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses,
                call_count: Mutex::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for MockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            let mut count = self.call_count.lock().unwrap();
            let idx = *count;
            *count += 1;

            let content = if idx < self.responses.len() {
                self.responses[idx].clone()
            } else {
                "default response".to_string()
            };

            Ok(LlmGatewayResponse {
                content: Some(content),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            Box::pin(stream::iter(vec![Ok(StreamChunk::Content("test".to_string()))]))
        }
    }

    #[derive(Clone)]
    struct MockTool {
        name: String,
    }

    impl LlmTool for MockTool {
        fn run(&self, _args: &HashMap<String, Value>) -> Result<Value> {
            Ok(json!({"result": "success"}))
        }

        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor {
                r#type: "function".to_string(),
                function: FunctionDescriptor {
                    name: self.name.clone(),
                    description: "A mock tool".to_string(),
                    parameters: json!({}),
                },
            }
        }

        fn clone_box(&self) -> Box<dyn LlmTool> {
            Box::new(self.clone())
        }
    }

    #[tokio::test]
    async fn test_new_session_has_system_message() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::new(broker);

        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role(), MessageRole::System);
        assert_eq!(session.messages[0].content(), Some("You are a helpful assistant."));
    }

    #[tokio::test]
    async fn test_builder_custom_system_prompt() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::builder(broker).system_prompt("Custom system prompt").build();

        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].content(), Some("Custom system prompt"));
    }

    #[tokio::test]
    async fn test_builder_custom_temperature() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::builder(broker).temperature(0.5).build();

        assert_eq!(session.temperature, 0.5);
    }

    #[tokio::test]
    async fn test_builder_custom_max_context() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let session = ChatSession::builder(broker).max_context(16384).build();

        assert_eq!(session.max_context, 16384);
    }

    #[tokio::test]
    async fn test_send_adds_messages_to_history() {
        let gateway = Arc::new(MockGateway::new(vec!["Hello, World!".to_string()]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let response = session.send("Hi").await.unwrap();

        assert_eq!(response, "Hello, World!");
        assert_eq!(session.messages.len(), 3);
        assert_eq!(session.messages[1].role(), MessageRole::User);
        assert_eq!(session.messages[1].content(), Some("Hi"));
        assert_eq!(session.messages[2].role(), MessageRole::Assistant);
        assert_eq!(session.messages[2].content(), Some("Hello, World!"));
    }

    #[tokio::test]
    async fn test_send_multiple_turns() {
        let gateway = Arc::new(MockGateway::new(vec![
            "First response".to_string(),
            "Second response".to_string(),
        ]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        session.send("First query").await.unwrap();
        session.send("Second query").await.unwrap();

        assert_eq!(session.messages.len(), 5);
        assert_eq!(session.messages[3].content(), Some("Second query"));
        assert_eq!(session.messages[4].content(), Some("Second response"));
    }

    #[tokio::test]
    async fn test_insert_message_calculates_token_length() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        session.insert_message(LlmMessage::user("Hello"));

        assert_eq!(session.messages.len(), 2);
        assert!(session.messages[1].token_length > 0);
    }

    #[tokio::test]
    async fn test_context_window_trimming() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let mut session = ChatSession::builder(broker).max_context(50).build();

        for i in 0..10 {
            session.insert_message(LlmMessage::user(format!(
                "This is a longer message number {} with more content to increase token count",
                i
            )));
        }

        assert!(session.messages.len() < 11);
        assert_eq!(session.messages[0].role(), MessageRole::System);

        assert!(session.total_tokens() <= 50);
    }

    #[tokio::test]
    async fn test_context_window_preserves_system_prompt() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let mut session = ChatSession::builder(broker)
            .system_prompt("Important system prompt")
            .max_context(50)
            .build();

        for i in 0..20 {
            session.insert_message(LlmMessage::user(format!("Message {}", i)));
        }

        assert_eq!(session.messages[0].role(), MessageRole::System);
        assert_eq!(session.messages[0].content(), Some("Important system prompt"));
    }

    #[tokio::test]
    async fn test_total_tokens() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let initial_tokens = session.total_tokens();
        assert!(initial_tokens > 0);

        session.insert_message(LlmMessage::user("Hello"));

        assert!(session.total_tokens() > initial_tokens);
    }

    #[tokio::test]
    async fn test_messages_accessor() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        session.insert_message(LlmMessage::user("Test"));

        let messages = session.messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role(), MessageRole::System);
        assert_eq!(messages[1].role(), MessageRole::User);
    }

    #[tokio::test]
    async fn test_builder_with_tools() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);

        let tool: Box<dyn LlmTool> = Box::new(MockTool {
            name: "test_tool".to_string(),
        });

        let session = ChatSession::builder(broker).tools(vec![tool]).build();

        assert!(session.tools.is_some());
        assert_eq!(session.tools.as_ref().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_sized_message_creation() {
        let message = LlmMessage::user("Test content");
        let sized = SizedLlmMessage::new(message, 5);

        assert_eq!(sized.token_length, 5);
        assert_eq!(sized.role(), MessageRole::User);
        assert_eq!(sized.content(), Some("Test content"));
    }

    struct StreamingMockGateway {
        stream_chunks: Vec<Vec<String>>,
        call_count: Mutex<usize>,
    }

    impl StreamingMockGateway {
        fn new(stream_chunks: Vec<Vec<String>>) -> Self {
            Self {
                stream_chunks,
                call_count: Mutex::new(0),
            }
        }
    }

    #[async_trait::async_trait]
    impl LlmGateway for StreamingMockGateway {
        async fn complete(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _tools: Option<&[Box<dyn LlmTool>]>,
            _config: &CompletionConfig,
        ) -> Result<LlmGatewayResponse> {
            Ok(LlmGatewayResponse {
                content: Some("default".to_string()),
                object: None,
                tool_calls: vec![],
                thinking: None,
            })
        }

        async fn complete_json(
            &self,
            _model: &str,
            _messages: &[LlmMessage],
            _schema: Value,
            _config: &CompletionConfig,
        ) -> Result<Value> {
            Ok(json!({}))
        }

        async fn get_available_models(&self) -> Result<Vec<String>> {
            Ok(vec!["test-model".to_string()])
        }

        async fn calculate_embeddings(
            &self,
            _text: &str,
            _model: Option<&str>,
        ) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn complete_stream<'a>(
            &'a self,
            _model: &'a str,
            _messages: &'a [LlmMessage],
            _tools: Option<&'a [Box<dyn LlmTool>]>,
            _config: &'a CompletionConfig,
        ) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + 'a>> {
            let mut count = self.call_count.lock().unwrap();
            let idx = *count;
            *count += 1;

            let chunks = if idx < self.stream_chunks.len() {
                self.stream_chunks[idx].clone()
            } else {
                vec!["default".to_string()]
            };

            Box::pin(stream::iter(
                chunks.into_iter().map(|c| Ok(StreamChunk::Content(c))).collect::<Vec<_>>(),
            ))
        }
    }

    #[tokio::test]
    async fn test_send_stream_yields_content_chunks() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec![
            "Hello".to_string(),
            " world".to_string(),
        ]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let mut chunks = Vec::new();
        let mut stream = session.send_stream("Hi");
        while let Some(result) = stream.next().await {
            chunks.push(result.unwrap());
        }

        assert_eq!(chunks, vec!["Hello", " world"]);
    }

    #[tokio::test]
    async fn test_send_stream_grows_message_history() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec!["Response".to_string()]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        {
            let mut stream = session.send_stream("Hi");
            while stream.next().await.is_some() {}
        }

        assert_eq!(session.messages.len(), 3);
    }

    #[tokio::test]
    async fn test_send_stream_records_assembled_response() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec![
            "Hello".to_string(),
            " world".to_string(),
        ]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        {
            let mut stream = session.send_stream("Hi");
            while stream.next().await.is_some() {}
        }

        assert_eq!(session.messages[2].content(), Some("Hello world"));
        assert_eq!(session.messages[2].role(), MessageRole::Assistant);
    }

    #[tokio::test]
    async fn test_send_stream_records_user_message() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![vec!["Response".to_string()]]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        {
            let mut stream = session.send_stream("My question");
            while stream.next().await.is_some() {}
        }

        assert_eq!(session.messages[1].role(), MessageRole::User);
        assert_eq!(session.messages[1].content(), Some("My question"));
    }

    #[tokio::test]
    async fn test_send_stream_respects_context_capacity() {
        let gateway = Arc::new(StreamingMockGateway::new(vec![
            vec!["This is a longer response to consume tokens in the context window".to_string()],
            vec!["Another longer response that also consumes many tokens in context".to_string()],
        ]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::builder(broker).max_context(50).build();

        {
            let mut stream = session.send_stream("First longer query message with extra words");
            while stream.next().await.is_some() {}
        }
        {
            let mut stream = session.send_stream("Second longer query message with extra words");
            while stream.next().await.is_some() {}
        }

        assert_eq!(session.messages[0].role(), MessageRole::System);
        assert!(session.total_tokens() <= 50);
    }

    #[tokio::test]
    async fn test_message_with_no_content_has_zero_tokens() {
        let gateway = Arc::new(MockGateway::new(vec![]));
        let broker = LlmBroker::new("test-model", gateway, None);
        let mut session = ChatSession::new(broker);

        let message = LlmMessage {
            role: MessageRole::Assistant,
            content: None,
            tool_calls: None,
            image_paths: None,
        };

        session.insert_message(message);

        assert_eq!(session.messages.len(), 2);
        assert_eq!(session.messages[1].token_length, 0);
    }
}