mojentic/llm/gateways/tokenizer_gateway.rs

//! Tokenizer gateway for encoding and decoding text using tiktoken.
//!
//! This module provides token-counting functionality, which is useful for:
//! - Managing context window limits
//! - Estimating API costs
//! - Debugging tokenization issues
//! - Optimizing prompt engineering
//!
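//! # Examples
//!
//! A minimal sketch of the cost-estimation use above; the per-token
//! price here is purely illustrative, not any real model's pricing:
//!
//! ```
//! use mojentic::llm::gateways::TokenizerGateway;
//!
//! let tokenizer = TokenizerGateway::default();
//! let count = tokenizer.count_tokens("What is the capital of France?");
//! let price_per_token = 0.000_002; // hypothetical price in USD per token
//! println!("Estimated cost: ${:.6}", count as f64 * price_per_token);
//! ```
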
use tiktoken_rs::CoreBPE;

/// Gateway for tokenizing and detokenizing text using tiktoken.
///
/// The tokenizer gateway provides encoding and decoding functionality,
/// allowing you to convert text to tokens and back. This is essential
/// for understanding token usage and managing context windows.
///
/// # Examples
///
/// ```
/// use mojentic::llm::gateways::TokenizerGateway;
///
/// let tokenizer = TokenizerGateway::new("cl100k_base").unwrap();
/// let text = "Hello, world!";
/// let tokens = tokenizer.encode(text);
/// let decoded = tokenizer.decode(&tokens);
/// assert_eq!(text, decoded);
/// ```
pub struct TokenizerGateway {
    tokenizer: CoreBPE,
}

impl TokenizerGateway {
    /// Creates a new TokenizerGateway with the specified encoding model.
    ///
    /// # Arguments
    ///
    /// * `model` - The encoding model to use. Common options:
    ///   - "cl100k_base" - Used by GPT-4 and GPT-3.5-turbo (the default)
    ///   - "p50k_base" - Used by older GPT-3 models such as text-davinci-003
    ///   - "r50k_base" - Used by the original GPT-3 base models
    ///
    /// # Errors
    ///
    /// Returns an error if the specified model is not available.
    ///
    /// # Examples
    ///
    /// ```
    /// use mojentic::llm::gateways::TokenizerGateway;
    ///
    /// let tokenizer = TokenizerGateway::new("cl100k_base").unwrap();
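    ///
    /// // Unknown encoding names are rejected with an error.
    /// assert!(TokenizerGateway::new("not_a_real_encoding").is_err());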
    /// ```
    pub fn new(model: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let tokenizer = match model {
            "cl100k_base" => tiktoken_rs::cl100k_base()?,
            "p50k_base" => tiktoken_rs::p50k_base()?,
            "r50k_base" => tiktoken_rs::r50k_base()?,
            _ => return Err(format!("Unsupported encoding model: {}", model).into()),
        };
        Ok(Self { tokenizer })
    }

    /// Encodes text into tokens.
    ///
    /// # Arguments
    ///
    /// * `text` - The text to encode
    ///
    /// # Returns
    ///
    /// A vector of token IDs representing the encoded text.
    ///
    /// # Examples
    ///
    /// ```
    /// use mojentic::llm::gateways::TokenizerGateway;
    ///
    /// let tokenizer = TokenizerGateway::default();
    /// let tokens = tokenizer.encode("Hello, world!");
    /// println!("Token count: {}", tokens.len());
    /// ```
    pub fn encode(&self, text: &str) -> Vec<usize> {
        tracing::debug!("Encoding text: {}", text);
        self.tokenizer.encode_with_special_tokens(text)
    }

    /// Decodes tokens back into text.
    ///
    /// # Arguments
    ///
    /// * `tokens` - The slice of token IDs to decode
    ///
    /// # Returns
    ///
    /// The decoded text. If decoding fails, the error is logged and an
    /// empty string is returned.
    ///
    /// # Examples
    ///
    /// ```
    /// use mojentic::llm::gateways::TokenizerGateway;
    ///
    /// let tokenizer = TokenizerGateway::default();
    /// let tokens = vec![9906, 11, 1917, 0];
    /// let text = tokenizer.decode(&tokens);
    /// println!("Decoded: {}", text);
    /// ```
    pub fn decode(&self, tokens: &[usize]) -> String {
        tracing::debug!("Decoding {} tokens", tokens.len());
        self.tokenizer.decode(tokens.to_vec()).unwrap_or_else(|e| {
            tracing::error!("Failed to decode tokens: {}", e);
            String::new()
        })
    }

    /// Counts the number of tokens in a text string.
    ///
    /// This is a convenience method that encodes the text and returns
    /// the length of the resulting token sequence.
    ///
    /// # Arguments
    ///
    /// * `text` - The text to count tokens for
    ///
    /// # Returns
    ///
    /// The number of tokens in the text.
    ///
    /// # Examples
    ///
    /// ```
    /// use mojentic::llm::gateways::TokenizerGateway;
    ///
    /// let tokenizer = TokenizerGateway::default();
    /// let count = tokenizer.count_tokens("Hello, world!");
    /// println!("Token count: {}", count);
    /// ```
    pub fn count_tokens(&self, text: &str) -> usize {
        self.encode(text).len()
    }
}

impl Default for TokenizerGateway {
    fn default() -> Self {
        // Use cl100k_base as the default tokenizer
        Self::new("cl100k_base").expect("cl100k_base should always be available")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_basic() {
        let tokenizer = TokenizerGateway::default();
        let text = "Hello, world!";
        let tokens = tokenizer.encode(text);

        // A non-empty result is the meaningful check here: token IDs are
        // plain indices, and 0 can itself be a valid token.
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_encode_empty() {
        let tokenizer = TokenizerGateway::default();
        let tokens = tokenizer.encode("");
        assert_eq!(tokens.len(), 0);
    }

    #[test]
    fn test_encode_consistent() {
        let tokenizer = TokenizerGateway::default();
        let text = "The quick brown fox";
        let tokens1 = tokenizer.encode(text);
        let tokens2 = tokenizer.encode(text);

        assert_eq!(tokens1, tokens2);
    }

    #[test]
    fn test_decode_basic() {
        let tokenizer = TokenizerGateway::default();
        let original = "Hello, world!";
        let tokens = tokenizer.encode(original);
        let decoded = tokenizer.decode(&tokens);

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_decode_empty() {
        let tokenizer = TokenizerGateway::default();
        let text = tokenizer.decode(&[]);
        assert_eq!(text, "");
    }

    #[test]
    fn test_round_trip() {
        let tokenizer = TokenizerGateway::default();
        let test_cases = vec![
            "Simple text",
            "Text with numbers: 123456",
            "Special characters: !@#$%^&*()",
            "Multi-line\ntext\nwith\nnewlines",
            "Unicode: δ½ ε₯½δΈ–η•Œ 🌍",
        ];

        for original in test_cases {
            let tokens = tokenizer.encode(original);
            let decoded = tokenizer.decode(&tokens);
            assert_eq!(original, decoded, "Round-trip failed for: {}", original);
        }
    }

    #[test]
    fn test_different_encodings() {
        // cl100k_base is the default and most common
        let tokenizer_cl100k = TokenizerGateway::default();

        // p50k_base for older models
        let tokenizer_p50k = TokenizerGateway::new("p50k_base").unwrap();

        let text = "Hello, world!";
        let tokens_cl100k = tokenizer_cl100k.encode(text);
        let tokens_p50k = tokenizer_p50k.encode(text);

        // Both should work
        assert!(!tokens_cl100k.is_empty());
        assert!(!tokens_p50k.is_empty());

        // Both should decode correctly
        assert_eq!(tokenizer_cl100k.decode(&tokens_cl100k), text);
        assert_eq!(tokenizer_p50k.decode(&tokens_p50k), text);
    }

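    // The constructor's error path: names outside the supported set are
    // rejected with a descriptive error rather than silently falling back.
    #[test]
    fn test_unsupported_model() {
        let result = TokenizerGateway::new("nonexistent_base");
        assert!(result.is_err());
        let message = result.err().unwrap().to_string();
        assert!(message.contains("Unsupported encoding model"));
    }
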
    #[test]
    fn test_count_tokens() {
        let tokenizer = TokenizerGateway::default();
        let text = "What is the capital of France?";
        let count = tokenizer.count_tokens(text);

        // This sentence encodes to roughly 7-8 tokens with cl100k_base.
        assert!(count > 5);
        assert!(count < 15);
    }

    #[test]
    fn test_count_tokens_matches_encode() {
        let tokenizer = TokenizerGateway::default();
        let text = "The quick brown fox jumps over the lazy dog.";

        let tokens = tokenizer.encode(text);
        let count = tokenizer.count_tokens(text);

        assert_eq!(tokens.len(), count);
    }

    #[test]
    fn test_long_text() {
        let tokenizer = TokenizerGateway::default();
        let long_text = "word ".repeat(1000);
        let tokens = tokenizer.encode(&long_text);

        assert!(tokens.len() > 1000);

        let decoded = tokenizer.decode(&tokens);
        assert_eq!(long_text, decoded);
    }
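
    // Sketch of the debugging use case from the module docs: decode each
    // token individually to inspect how the text was split. For this ASCII
    // input, the per-token pieces concatenate back to the original text.
    #[test]
    fn test_per_token_breakdown() {
        let tokenizer = TokenizerGateway::default();
        let text = "Hello, world!";
        let tokens = tokenizer.encode(text);
        let pieces: Vec<String> = tokens.iter().map(|&t| tokenizer.decode(&[t])).collect();

        assert_eq!(pieces.concat(), text);
    }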

    #[test]
    fn test_unicode_handling() {
        let tokenizer = TokenizerGateway::default();
        let unicode_text = "Hello δΈ–η•Œ! 🌍 Special chars: @#$%";
        let tokens = tokenizer.encode(unicode_text);
        let decoded = tokenizer.decode(&tokens);

        assert_eq!(unicode_text, decoded);
        assert!(!tokens.is_empty());
    }
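
    // Sketch of the context-window use case from the module docs: cap the
    // token sequence at an illustrative budget (not a real model limit),
    // then decode the truncated prefix.
    #[test]
    fn test_truncate_to_token_budget() {
        let tokenizer = TokenizerGateway::default();
        let text = "The quick brown fox jumps over the lazy dog.";
        let budget = 5; // illustrative token budget

        let mut tokens = tokenizer.encode(text);
        tokens.truncate(budget);
        assert!(tokens.len() <= budget);

        // BPE token bytes concatenate, so decoding a prefix of the token
        // sequence yields a prefix of this ASCII text.
        let trimmed = tokenizer.decode(&tokens);
        assert!(text.starts_with(&trimmed));
    }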
}