// mojentic/llm/models.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4/// Message role in LLM conversation
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "lowercase")]
7pub enum MessageRole {
8    System,
9    User,
10    Assistant,
11    Tool,
12}
13
14/// Tool call from LLM
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct LlmToolCall {
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub id: Option<String>,
19    pub name: String,
20    pub arguments: HashMap<String, serde_json::Value>,
21}
22
23/// Message in LLM conversation
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct LlmMessage {
26    #[serde(default = "default_role")]
27    pub role: MessageRole,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub content: Option<String>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub tool_calls: Option<Vec<LlmToolCall>>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub image_paths: Option<Vec<String>>,
34}
35
36fn default_role() -> MessageRole {
37    MessageRole::User
38}
39
40/// Response from LLM gateway
41#[derive(Debug, Clone)]
42pub struct LlmGatewayResponse<T = ()> {
43    pub content: Option<String>,
44    pub object: Option<T>,
45    pub tool_calls: Vec<LlmToolCall>,
46    pub thinking: Option<String>,
47}
48
49impl LlmMessage {
50    /// Create a user message
51    pub fn user(content: impl Into<String>) -> Self {
52        Self {
53            role: MessageRole::User,
54            content: Some(content.into()),
55            tool_calls: None,
56            image_paths: None,
57        }
58    }
59
60    /// Create a system message
61    pub fn system(content: impl Into<String>) -> Self {
62        Self {
63            role: MessageRole::System,
64            content: Some(content.into()),
65            tool_calls: None,
66            image_paths: None,
67        }
68    }
69
70    /// Create an assistant message
71    pub fn assistant(content: impl Into<String>) -> Self {
72        Self {
73            role: MessageRole::Assistant,
74            content: Some(content.into()),
75            tool_calls: None,
76            image_paths: None,
77        }
78    }
79
80    /// Add image paths to this message
81    pub fn with_images(mut self, paths: Vec<String>) -> Self {
82        self.image_paths = Some(paths);
83        self
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_message_role_serialization() {
        assert_eq!(serde_json::to_string(&MessageRole::System).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&MessageRole::User).unwrap(), "\"user\"");
        assert_eq!(serde_json::to_string(&MessageRole::Assistant).unwrap(), "\"assistant\"");
        assert_eq!(serde_json::to_string(&MessageRole::Tool).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_message_role_deserialization() {
        assert_eq!(serde_json::from_str::<MessageRole>("\"system\"").unwrap(), MessageRole::System);
        assert_eq!(serde_json::from_str::<MessageRole>("\"user\"").unwrap(), MessageRole::User);
        assert_eq!(
            serde_json::from_str::<MessageRole>("\"assistant\"").unwrap(),
            MessageRole::Assistant
        );
        assert_eq!(serde_json::from_str::<MessageRole>("\"tool\"").unwrap(), MessageRole::Tool);
    }

    #[test]
    fn test_message_role_rejects_non_lowercase() {
        // `rename_all = "lowercase"` accepts only the lowercase spelling.
        assert!(serde_json::from_str::<MessageRole>("\"System\"").is_err());
        assert!(serde_json::from_str::<MessageRole>("\"ASSISTANT\"").is_err());
    }

    #[test]
    fn test_user_message() {
        let msg = LlmMessage::user("Hello");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, Some("Hello".to_string()));
        assert!(msg.tool_calls.is_none());
        assert!(msg.image_paths.is_none());
    }

    #[test]
    fn test_system_message() {
        let msg = LlmMessage::system("You are a helpful assistant");
        assert_eq!(msg.role, MessageRole::System);
        assert_eq!(msg.content, Some("You are a helpful assistant".to_string()));
        assert!(msg.tool_calls.is_none());
        assert!(msg.image_paths.is_none());
    }

    #[test]
    fn test_assistant_message() {
        let msg = LlmMessage::assistant("I can help with that");
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.content, Some("I can help with that".to_string()));
        assert!(msg.tool_calls.is_none());
        assert!(msg.image_paths.is_none());
    }

    #[test]
    fn test_message_with_images() {
        let msg = LlmMessage::user("Describe this image")
            .with_images(vec!["/path/to/image.jpg".to_string()]);
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, Some("Describe this image".to_string()));
        assert_eq!(msg.image_paths, Some(vec!["/path/to/image.jpg".to_string()]));
    }

    #[test]
    fn test_llm_tool_call_serialization() {
        let mut args = HashMap::new();
        args.insert("key".to_string(), json!("value"));

        let tool_call = LlmToolCall {
            id: Some("call_123".to_string()),
            name: "test_tool".to_string(),
            arguments: args,
        };

        // Compare structurally so every value is checked under its own key,
        // rather than merely appearing somewhere in the output string.
        let value = serde_json::to_value(&tool_call).unwrap();
        assert_eq!(
            value,
            json!({"id": "call_123", "name": "test_tool", "arguments": {"key": "value"}})
        );
    }

    #[test]
    fn test_llm_tool_call_without_id() {
        let tool_call = LlmToolCall {
            id: None,
            name: "test_tool".to_string(),
            arguments: HashMap::new(),
        };

        // `id` must be omitted entirely (not serialized as null) when None.
        let value = serde_json::to_value(&tool_call).unwrap();
        assert_eq!(value, json!({"name": "test_tool", "arguments": {}}));
    }

    #[test]
    fn test_llm_tool_call_deserialization_without_id() {
        let raw = r#"{"name":"test_tool","arguments":{"count":2}}"#;
        let call: LlmToolCall = serde_json::from_str(raw).unwrap();

        assert_eq!(call.id, None);
        assert_eq!(call.name, "test_tool");
        assert_eq!(call.arguments.get("count"), Some(&json!(2)));
    }

    #[test]
    fn test_llm_message_serialization() {
        let msg = LlmMessage::user("test content");
        let value = serde_json::to_value(&msg).unwrap();

        // None-valued fields (`tool_calls`, `image_paths`) must be omitted entirely.
        assert_eq!(value, json!({"role": "user", "content": "test content"}));
    }

    #[test]
    fn test_llm_message_deserialization() {
        let raw = r#"{"role":"assistant","content":"response"}"#;
        let msg: LlmMessage = serde_json::from_str(raw).unwrap();

        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.content, Some("response".to_string()));
    }

    #[test]
    fn test_llm_message_deserialization_with_tool_calls() {
        let raw = r#"{"role":"assistant","tool_calls":[{"id":"call_1","name":"lookup","arguments":{"q":"rust"}}]}"#;
        let msg: LlmMessage = serde_json::from_str(raw).unwrap();

        assert_eq!(msg.role, MessageRole::Assistant);
        assert!(msg.content.is_none());
        assert!(msg.image_paths.is_none());

        let calls = msg.tool_calls.expect("tool_calls should be present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id.as_deref(), Some("call_1"));
        assert_eq!(calls[0].name, "lookup");
        assert_eq!(calls[0].arguments.get("q"), Some(&json!("rust")));
    }

    #[test]
    fn test_llm_message_round_trip_with_images() {
        let original = LlmMessage::user("Describe this image")
            .with_images(vec!["/path/to/image.jpg".to_string()]);

        let text = serde_json::to_string(&original).unwrap();
        let parsed: LlmMessage = serde_json::from_str(&text).unwrap();

        assert_eq!(parsed.role, original.role);
        assert_eq!(parsed.content, original.content);
        assert_eq!(parsed.image_paths, original.image_paths);
        assert!(parsed.tool_calls.is_none());
    }

    #[test]
    fn test_llm_message_default_role() {
        let raw = r#"{"content":"test"}"#;
        let msg: LlmMessage = serde_json::from_str(raw).unwrap();

        // A missing `role` falls back to User via `default_role`.
        assert_eq!(msg.role, MessageRole::User);
    }
}