// mojentic/llm/gateways/tokenizer_gateway.rs

use tiktoken_rs::CoreBPE;
11pub struct TokenizerGateway {
29 tokenizer: CoreBPE,
30}
32impl TokenizerGateway {
33 pub fn new(model: &str) -> Result<Self, Box<dyn std::error::Error>> {
54 let tokenizer = match model {
55 "cl100k_base" => tiktoken_rs::cl100k_base()?,
56 "p50k_base" => tiktoken_rs::p50k_base()?,
57 "r50k_base" => tiktoken_rs::r50k_base()?,
58 _ => return Err(format!("Unsupported encoding model: {}", model).into()),
59 };
60 Ok(Self { tokenizer })
61 }
62
63 pub fn encode(&self, text: &str) -> Vec<usize> {
83 tracing::debug!("Encoding text: {}", text);
84 self.tokenizer.encode_with_special_tokens(text)
85 }
86
87 pub fn decode(&self, tokens: &[usize]) -> String {
108 tracing::debug!("Decoding {} tokens", tokens.len());
109 self.tokenizer.decode(tokens.to_vec()).unwrap_or_else(|e| {
110 tracing::error!("Failed to decode tokens: {}", e);
111 String::new()
112 })
113 }
114
115 pub fn count_tokens(&self, text: &str) -> usize {
138 self.encode(text).len()
139 }
140}
142impl Default for TokenizerGateway {
143 fn default() -> Self {
144 Self::new("cl100k_base").expect("cl100k_base should always be available")
146 }
147}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_basic() {
        let tokenizer = TokenizerGateway::default();
        let text = "Hello, world!";
        let tokens = tokenizer.encode(text);

        assert!(!tokens.is_empty());
        // Fixed: this line previously duplicated the assertion above;
        // check encode/count consistency instead.
        assert_eq!(tokens.len(), tokenizer.count_tokens(text));
    }

    #[test]
    fn test_encode_empty() {
        let tokenizer = TokenizerGateway::default();
        let tokens = tokenizer.encode("");
        assert_eq!(tokens.len(), 0);
    }

    #[test]
    fn test_encode_consistent() {
        let tokenizer = TokenizerGateway::default();
        let text = "The quick brown fox";
        let tokens1 = tokenizer.encode(text);
        let tokens2 = tokenizer.encode(text);

        assert_eq!(tokens1, tokens2);
    }

    #[test]
    fn test_decode_basic() {
        let tokenizer = TokenizerGateway::default();
        let original = "Hello, world!";
        let tokens = tokenizer.encode(original);
        let decoded = tokenizer.decode(&tokens);

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_decode_empty() {
        let tokenizer = TokenizerGateway::default();
        let text = tokenizer.decode(&[]);
        assert_eq!(text, "");
    }

    #[test]
    fn test_round_trip() {
        let tokenizer = TokenizerGateway::default();
        let test_cases = vec![
            "Simple text",
            "Text with numbers: 123456",
            "Special characters: !@#$%^&*()",
            "Multi-line\ntext\nwith\nnewlines",
            "Unicode: δ½ ε₯½δΈη π",
        ];

        for original in test_cases {
            let tokens = tokenizer.encode(original);
            let decoded = tokenizer.decode(&tokens);
            assert_eq!(original, decoded, "Round-trip failed for: {}", original);
        }
    }

    #[test]
    fn test_different_encodings() {
        let tokenizer_cl100k = TokenizerGateway::default();

        let tokenizer_p50k = TokenizerGateway::new("p50k_base").unwrap();

        let text = "Hello, world!";
        let tokens_cl100k = tokenizer_cl100k.encode(text);
        let tokens_p50k = tokenizer_p50k.encode(text);

        assert!(!tokens_cl100k.is_empty());
        assert!(!tokens_p50k.is_empty());

        assert_eq!(tokenizer_cl100k.decode(&tokens_cl100k), text);
        assert_eq!(tokenizer_p50k.decode(&tokens_p50k), text);
    }

    #[test]
    fn test_count_tokens() {
        let tokenizer = TokenizerGateway::default();
        let text = "What is the capital of France?";
        let count = tokenizer.count_tokens(text);

        assert!(count > 5);
        assert!(count < 15);
    }

    #[test]
    fn test_count_tokens_matches_encode() {
        let tokenizer = TokenizerGateway::default();
        let text = "The quick brown fox jumps over the lazy dog.";

        let tokens = tokenizer.encode(text);
        let count = tokenizer.count_tokens(text);

        assert_eq!(tokens.len(), count);
    }

    #[test]
    fn test_long_text() {
        let tokenizer = TokenizerGateway::default();
        let long_text = "word ".repeat(1000);
        let tokens = tokenizer.encode(&long_text);

        assert!(tokens.len() > 1000);

        let decoded = tokenizer.decode(&tokens);
        assert_eq!(long_text, decoded);
    }

    #[test]
    fn test_unicode_handling() {
        let tokenizer = TokenizerGateway::default();
        let unicode_text = "Hello δΈη! π Special chars: @#$%";
        let tokens = tokenizer.encode(unicode_text);
        let decoded = tokenizer.decode(&tokens);

        assert_eq!(unicode_text, decoded);
        assert!(!tokens.is_empty());
    }
}