Skip to content

Commit fe820e0

Browse files
committed
Initial commit
1 parent 68af0f5 commit fe820e0

File tree

2 files changed

+93
-5
lines changed

2 files changed

+93
-5
lines changed

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ rust-version = "1.57.0"
66

77
[lib]
88
name = "_tiktoken_core"
9-
crate-type = ["lib"]
9+
crate-type = ["cdylib"]
1010

1111
[dependencies]
12+
mlua = { version = "0.9.1", features = ["lua54", "module"] }
1213
# tiktoken dependencies
1314
fancy-regex = "0.11.0"
1415
regex = "1.8.3"
@@ -17,4 +18,4 @@ bstr = "1.5.0"
1718

1819
[features]
1920
default = []
20-
multithreading = []
21+
multithreading = []

src/lib.rs

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
use std::collections::HashSet;
2-
use std::thread;
3-
41
use fancy_regex::Regex;
2+
use mlua::prelude::*;
53
use rustc_hash::FxHashMap as HashMap;
4+
use std::collections::HashSet;
5+
use std::sync::{Arc, Mutex};
6+
use std::thread;
67

78
#[cfg(feature = "multithreading")]
89
const MAX_NUM_THREADS: usize = 128;
@@ -176,6 +177,92 @@ fn hash_current_thread() -> usize {
176177
u64::from(x) as usize
177178
}
178179

180+
/// Module-level holder for the lazily initialized tokenizer.
///
/// `core_bpe` stays `None` until the Lua-exported `new` function installs a
/// `CoreBPENative`; the `Mutex` serializes access between the `new` and
/// `encode` closures handed to Lua.
struct State {
    // None until `new` is called from Lua.
    core_bpe: Mutex<Option<CoreBPENative>>,
}
183+
184+
#[mlua::lua_module]
185+
pub fn core_module(lua: &mlua::Lua) -> LuaResult<LuaTable> {
186+
let core_bpe = State {
187+
core_bpe: Mutex::new(None),
188+
};
189+
let state = Arc::new(core_bpe);
190+
let state2 = Arc::clone(&state);
191+
192+
let _new = lua.create_function(
193+
move |_,
194+
(encoder, special_tokens_encoder, pattern): (
195+
HashMap<Vec<u8>, usize>,
196+
HashMap<String, usize>,
197+
String,
198+
)| {
199+
new(&*state, encoder, special_tokens_encoder, pattern);
200+
Ok(())
201+
},
202+
)?;
203+
let _encode = lua.create_function(move |_, text: String| encode(&*state2, text))?;
204+
205+
let exports = lua.create_table()?;
206+
exports.set("new", _new)?;
207+
exports.set("encode", _encode)?;
208+
Ok(exports)
209+
}
210+
211+
fn new(
212+
state: &State,
213+
encoder: HashMap<Vec<u8>, usize>,
214+
special_tokens_encoder: HashMap<String, usize>,
215+
pattern: String,
216+
) {
217+
let regex = Regex::new(&pattern)
218+
.map_err(|e| mlua::Error::external(e))
219+
.unwrap();
220+
let special_regex = {
221+
let _parts = special_tokens_encoder
222+
.keys()
223+
.map(|s| fancy_regex::escape(s))
224+
.collect::<Vec<_>>();
225+
Regex::new(&_parts.join("|"))
226+
.map_err(|e| mlua::Error::external(e))
227+
.unwrap()
228+
};
229+
let decoder: HashMap<usize, Vec<u8>> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
230+
assert!(
231+
encoder.len() == decoder.len(),
232+
"Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
233+
);
234+
let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
235+
.iter()
236+
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
237+
.collect();
238+
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
239+
sorted_token_bytes.sort();
240+
let mut core_bpe_lock = state.core_bpe.lock().unwrap();
241+
*core_bpe_lock = Some(CoreBPENative {
242+
encoder,
243+
special_tokens_encoder,
244+
decoder,
245+
special_tokens_decoder,
246+
regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
247+
special_regex_tls: (0..MAX_NUM_THREADS)
248+
.map(|_| special_regex.clone())
249+
.collect(),
250+
sorted_token_bytes,
251+
});
252+
}
253+
254+
fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)> {
255+
let allowed_special = HashSet::new();
256+
let max_tokens = None;
257+
Ok(state
258+
.core_bpe
259+
.lock()
260+
.unwrap()
261+
.as_ref()
262+
.unwrap()
263+
._encode_native(&text, &allowed_special, max_tokens))
264+
}
265+
179266
pub struct CoreBPENative {
180267
encoder: HashMap<Vec<u8>, usize>,
181268
special_tokens_encoder: HashMap<String, usize>,

0 commit comments

Comments
 (0)