diff --git a/Cargo.toml b/Cargo.toml index 37108b3..7ad22b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ serde_json = "1.0.108" js-sys = "0.3.64" wasm-bindgen-futures = "0.4.39" anyhow = "1.0" +gloo = "0.11.0" [dependencies.web-sys] features = [ diff --git a/src/lib.rs b/src/lib.rs index 5e7cedd..76252c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ pub mod quantized_mistral; -mod utils; +pub mod utils; use wasm_bindgen::prelude::*; #[wasm_bindgen(start)] diff --git a/src/quantized_mistral.rs b/src/quantized_mistral.rs index 5feee36..926e6b6 100644 --- a/src/quantized_mistral.rs +++ b/src/quantized_mistral.rs @@ -11,6 +11,7 @@ use tokenizers::Tokenizer; use wasm_bindgen::prelude::*; use web_time as time; +use gloo::console::log; use web_sys::console; #[wasm_bindgen] @@ -26,7 +27,7 @@ impl Model { fn process(&mut self, tokens: &[u32]) -> candle::Result { let u32_array = Uint32Array::new_with_length(tokens.len() as u32); u32_array.copy_from(tokens); - console::log_2(&"Processing tokens".into(), &u32_array.into()); + log!("Processing tokens", u32_array); const REPEAT_LAST_N: usize = 64; let dev = Device::Cpu; @@ -68,21 +69,7 @@ impl Model { impl Model { #[wasm_bindgen(constructor)] pub fn new(weights: Vec, tokenizer: Vec) -> Result { - println!("Initialising model..."); - // let model = M::load(ModelData { - // tokenizer, - // model: weights, - // }); - // let logits_processor = LogitsProcessor::new(299792458, None, None); - // match model { - // Ok(inner) => Ok(Self { - // inner, - // logits_processor, - // tokens: vec![], - // repeat_penalty: 1., - // }), - // Err(e) => Err(JsError::new(&e.to_string())), - // } + log!("Initialising model..."); let seed = 299792458; let temperature: Option = Some(0.8); let top_p: Option = None; @@ -91,19 +78,21 @@ impl Model { let mut cursor = Cursor::new(&weights); let mut cursor2 = Cursor::new(&weights); let model: ModelWeights = { + log!("Loading gguf file..."); let model = gguf_file::Content::read(&mut cursor)?; + log!("gguf file loaded."); let mut total_size_in_bytes = 0; for (_, tensor) in model.tensor_infos.iter() { let elem_count = tensor.shape.elem_count(); total_size_in_bytes += elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size(); } - println!( + log!(format!( "loaded {:?} tensors ({}) in {:.2}s", model.tensor_infos.len(), - &format_size(total_size_in_bytes), - start.elapsed().as_secs_f32(), - ); + format_size(total_size_in_bytes), + start.elapsed().as_secs_f32() + )); ModelWeights::from_gguf(model, &mut cursor2)? }; println!("model built"); diff --git a/src/utils.rs b/src/utils.rs index b1d7929..691ff5f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,10 @@ +use gloo::console::log; +use js_sys::{ArrayBuffer, Uint8Array}; + +use wasm_bindgen::{prelude::*, JsValue}; +use wasm_bindgen_futures::JsFuture; +use web_sys::{Request, RequestInit, RequestMode, Response}; + pub fn set_panic_hook() { // When the `console_error_panic_hook` feature is enabled, we can call the // `set_panic_hook` function at least once during initialization, and then @@ -8,3 +15,36 @@ pub fn set_panic_hook() { #[cfg(feature = "console_error_panic_hook")] console_error_panic_hook::set_once(); } + +async fn fetch(url: &str) -> Result { + let mut opts = RequestInit::new(); + opts.method("GET"); + opts.mode(RequestMode::Cors); + + let request = Request::new_with_str_and_init(&url, &opts)?; + + let window = web_sys::window().unwrap(); + let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; + + assert!(resp_value.is_instance_of::()); + let resp: Response = resp_value.dyn_into().unwrap(); + + Ok(resp) +} + +pub async fn load_json(url: &str) -> Result { + let response = fetch(url).await?; + let json = JsFuture::from(response.json()?).await?; + Ok(json) +} + +pub async fn load_binary(url: &str) -> Result, JsValue> { + let response = fetch(url).await?; + let ab = JsFuture::from(response.array_buffer()?).await?; + + let vec = Uint8Array::new(&ab).to_vec(); + let bla = (&vec.iter().take(10).map(|x| x.clone()).collect::>()).clone(); + let x = js_sys::Uint8Array::from(bla.as_slice()); + log!(url, x); + Ok(vec) +} diff --git a/tests/web.rs b/tests/web.rs index 9eff23e..d81b9c1 100644 --- a/tests/web.rs +++ b/tests/web.rs @@ -6,8 +6,10 @@ extern crate wasm_bindgen_test; use std::fmt::format; use std::println; +use gloo::console::log; use js_sys::Uint8Array; use transformers_wasm::quantized_mistral::Model; +use transformers_wasm::utils; use wasm_bindgen::{prelude::*, JsValue}; use wasm_bindgen_futures::JsFuture; use wasm_bindgen_test::*; @@ -16,50 +18,22 @@ use web_sys::{Request, RequestInit, RequestMode, Response}; wasm_bindgen_test_configure!(run_in_browser); -async fn fetch(url: &str) -> Result { - let mut opts = RequestInit::new(); - opts.method("GET"); - opts.mode(RequestMode::Cors); - - let request = Request::new_with_str_and_init(&url, &opts)?; - - let window = web_sys::window().unwrap(); - let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; - - assert!(resp_value.is_instance_of::()); - let resp: Response = resp_value.dyn_into().unwrap(); - - Ok(resp) -} - -async fn load_json(url: &str) -> Result { - let response = fetch(url).await?; - let json = JsFuture::from(response.json()?).await?; - Ok(json) -} - -async fn load_binary(url: &str) -> Result, JsValue> { - let response = fetch(url).await?; - let ab = JsFuture::from(response.array_buffer()?).await?; - let vec = Uint8Array::new(&ab).to_vec(); - Ok(vec) -} - #[wasm_bindgen_test] async fn pass() -> Result<(), JsValue> { let tokenizer_url = "http://localhost:45678/tokenizer.json"; let model_url = "http://localhost:45678/tinymistral-248m.q4_k_m.gguf"; - let tokenizer_blob: Vec = load_binary(&tokenizer_url).await?; + let tokenizer_blob: Vec = utils::load_binary(&tokenizer_url).await?; let tokenizer_blob_len = format!("{}", &tokenizer_blob.len()); - console::log_2(&"tokenizer blob size".into(), &tokenizer_blob_len.into()); + log!("tokenizer blob size", &tokenizer_blob_len); - let model_blob: Vec = load_binary(&model_url).await?; + let model_blob: Vec = utils::load_binary(&model_url).await?; let model_blob_len = format!("{}", &model_blob.len()); - console::log_2(&"model blob size".into(), &model_blob_len.into()); - - let mut model = Model::new(tokenizer_blob, model_blob)?; + log!("model blob size", &model_blob_len); + log!("loading model..."); + let mut model = Model::new(model_blob, tokenizer_blob)?; + log!("model loaded."); let prompt: String = String::from("What is a good recipe for onion soup"); let temp: f64 = 0.8; let top_p: f64 = 1.; @@ -67,7 +41,7 @@ async fn pass() -> Result<(), JsValue> { let seed: u64 = 203948203948; let first_result: String = model.init_with_prompt(prompt, temp, top_p, repeat_penalty, seed)?; - console::log_2(&"first prompt result".into(), &first_result.into()); + log!("first prompt result", &first_result); assert_eq!(1 + 1, 2); Ok(()) }