Skip to content

Commit

Permalink
One tiny step further
Browse files Browse the repository at this point in the history
  • Loading branch information
sigma-andex committed Jan 5, 2024
1 parent c620277 commit 49b4010
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 57 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ serde_json = "1.0.108"
js-sys = "0.3.64"
wasm-bindgen-futures = "0.4.39"
anyhow = "1.0"
gloo = "0.11.0"

[dependencies.web-sys]
features = [
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub mod quantized_mistral;
mod utils;
pub mod utils;
use wasm_bindgen::prelude::*;

#[wasm_bindgen(start)]
Expand Down
29 changes: 9 additions & 20 deletions src/quantized_mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
use web_time as time;

use gloo::console::log;
use web_sys::console;

#[wasm_bindgen]
Expand All @@ -26,7 +27,7 @@ impl Model {
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
let u32_array = Uint32Array::new_with_length(tokens.len() as u32);
u32_array.copy_from(tokens);
console::log_2(&"Processing tokens".into(), &u32_array.into());
log!("Processing tokens", u32_array);

const REPEAT_LAST_N: usize = 64;
let dev = Device::Cpu;
Expand Down Expand Up @@ -68,21 +69,7 @@ impl Model {
impl Model {
#[wasm_bindgen(constructor)]
pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
println!("Initialising model...");
// let model = M::load(ModelData {
// tokenizer,
// model: weights,
// });
// let logits_processor = LogitsProcessor::new(299792458, None, None);
// match model {
// Ok(inner) => Ok(Self {
// inner,
// logits_processor,
// tokens: vec![],
// repeat_penalty: 1.,
// }),
// Err(e) => Err(JsError::new(&e.to_string())),
// }
log!("Initialising model...");
let seed = 299792458;
let temperature: Option<f64> = Some(0.8);
let top_p: Option<f64> = None;
Expand All @@ -91,19 +78,21 @@ impl Model {
let mut cursor = Cursor::new(&weights);
let mut cursor2 = Cursor::new(&weights);
let model: ModelWeights = {
log!("Loading gguf file...");
let model = gguf_file::Content::read(&mut cursor)?;
log!("gguf file loaded.");
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
}
println!(
log!(format!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
format_size(total_size_in_bytes),
start.elapsed().as_secs_f32()
));
ModelWeights::from_gguf(model, &mut cursor2)?
};
println!("model built");
Expand Down
40 changes: 40 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
use gloo::console::log;
use js_sys::{ArrayBuffer, Uint8Array};

use wasm_bindgen::{prelude::*, JsValue};
use wasm_bindgen_futures::JsFuture;
use web_sys::{Request, RequestInit, RequestMode, Response};

pub fn set_panic_hook() {
// When the `console_error_panic_hook` feature is enabled, we can call the
// `set_panic_hook` function at least once during initialization, and then
Expand All @@ -8,3 +15,36 @@ pub fn set_panic_hook() {
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
}

async fn fetch(url: &str) -> Result<Response, JsValue> {
let mut opts = RequestInit::new();
opts.method("GET");
opts.mode(RequestMode::Cors);

let request = Request::new_with_str_and_init(&url, &opts)?;

let window = web_sys::window().unwrap();
let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;

assert!(resp_value.is_instance_of::<Response>());
let resp: Response = resp_value.dyn_into().unwrap();

Ok(resp)
}

/// Fetch `url` and parse the response body as JSON, yielding the parsed
/// value as an opaque `JsValue` (or the JS error on failure).
pub async fn load_json(url: &str) -> Result<JsValue, JsValue> {
    // `Response::json()` returns a JS Promise; await it through JsFuture.
    let body_promise = fetch(url).await?.json()?;
    JsFuture::from(body_promise).await
}

/// Download `url` and return its body as raw bytes.
///
/// Errors from the fetch, the `arrayBuffer()` promise, or the request
/// itself are surfaced as the underlying JS error `JsValue`.
pub async fn load_binary(url: &str) -> Result<Vec<u8>, JsValue> {
    let response = fetch(url).await?;
    let ab = JsFuture::from(response.array_buffer()?).await?;

    let vec = Uint8Array::new(&ab).to_vec();

    // Debug aid: log the URL with the first few bytes of the payload.
    // The original built this preview via `iter().map(clone)` plus two
    // redundant `.clone()`s; copying a length-guarded sub-slice is enough
    // (and does not panic when the body is shorter than 10 bytes).
    let preview_len = vec.len().min(10);
    let preview = Uint8Array::from(&vec[..preview_len]);
    log!(url, preview);

    Ok(vec)
}
46 changes: 10 additions & 36 deletions tests/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ extern crate wasm_bindgen_test;
use std::fmt::format;
use std::println;

use gloo::console::log;
use js_sys::Uint8Array;
use transformers_wasm::quantized_mistral::Model;
use transformers_wasm::utils;
use wasm_bindgen::{prelude::*, JsValue};
use wasm_bindgen_futures::JsFuture;
use wasm_bindgen_test::*;
Expand All @@ -16,58 +18,30 @@ use web_sys::{Request, RequestInit, RequestMode, Response};

wasm_bindgen_test_configure!(run_in_browser);

// Perform a GET request against `url` via the browser Fetch API and return
// the raw `Response`, or the JS error as a `JsValue`.
// NOTE(review): this is the pre-commit copy shown as removed by this diff;
// the commit appears to move the helper into `src/utils.rs`.
async fn fetch(url: &str) -> Result<Response, JsValue> {
let mut opts = RequestInit::new();
opts.method("GET");
opts.mode(RequestMode::Cors);

let request = Request::new_with_str_and_init(&url, &opts)?;

// Panics if there is no global `window` (non-browser context).
let window = web_sys::window().unwrap();
let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;

// Panics if the fetched value is not a `Response`; `dyn_into` repeats the
// same check, so the unwrap cannot fire after the assert.
assert!(resp_value.is_instance_of::<Response>());
let resp: Response = resp_value.dyn_into().unwrap();

Ok(resp)
}

// Fetch `url` and parse the response body as JSON into a `JsValue`.
// NOTE(review): pre-commit copy shown as removed by this diff, presumably
// superseded by `utils::load_json`.
async fn load_json(url: &str) -> Result<JsValue, JsValue> {
let response = fetch(url).await?;
let json = JsFuture::from(response.json()?).await?;
Ok(json)
}

// Download `url` and return the response body as raw bytes.
// NOTE(review): pre-commit copy shown as removed by this diff; the test now
// calls `utils::load_binary` instead.
async fn load_binary(url: &str) -> Result<Vec<u8>, JsValue> {
let response = fetch(url).await?;
let ab = JsFuture::from(response.array_buffer()?).await?;
let vec = Uint8Array::new(&ab).to_vec();
Ok(vec)
}

// Integration test: downloads a tokenizer and a quantized GGUF model from a
// local server, constructs a `Model`, and runs one prompt through it.
//
// NOTE(review): this span is a rendered diff with the +/- markers lost —
// for several statements BOTH the old and the new line appear back to back
// (`load_binary` vs `utils::load_binary`, `console::log_2` vs `log!`, and
// two different `Model::new` argument orders). As plain Rust the doubled
// `let`s are legal shadowing, but the second `Model::new` call reuses blob
// `Vec`s already moved by the first, so this text is NOT the compilable
// post-commit function; reconstruct it from the repository, not from here.
#[wasm_bindgen_test]
async fn pass() -> Result<(), JsValue> {
let tokenizer_url = "http://localhost:45678/tokenizer.json";
let model_url = "http://localhost:45678/tinymistral-248m.q4_k_m.gguf";

// Old line (pre-commit) followed by its replacement (diff artifact).
let tokenizer_blob: Vec<u8> = load_binary(&tokenizer_url).await?;
let tokenizer_blob: Vec<u8> = utils::load_binary(&tokenizer_url).await?;
let tokenizer_blob_len = format!("{}", &tokenizer_blob.len());
// Old `console::log_2` call followed by the new `gloo` `log!` macro.
console::log_2(&"tokenizer blob size".into(), &tokenizer_blob_len.into());
log!("tokenizer blob size", &tokenizer_blob_len);

// Same old/new pairing for the model download.
let model_blob: Vec<u8> = load_binary(&model_url).await?;
let model_blob: Vec<u8> = utils::load_binary(&model_url).await?;
let model_blob_len = format!("{}", &model_blob.len());
console::log_2(&"model blob size".into(), &model_blob_len.into());

// Old constructor call — note the argument order differs from the new one
// below; the commit appears to swap to (weights, tokenizer). TODO confirm
// against `Model::new`'s signature.
let mut model = Model::new(tokenizer_blob, model_blob)?;
log!("model blob size", &model_blob_len);

log!("loading model...");
let mut model = Model::new(model_blob, tokenizer_blob)?;
log!("model loaded.");
let prompt: String = String::from("What is a good recipe for onion soup");
let temp: f64 = 0.8;
let top_p: f64 = 1.;
let repeat_penalty: f32 = 1.1;
let seed: u64 = 203948203948;
let first_result: String = model.init_with_prompt(prompt, temp, top_p, repeat_penalty, seed)?;

// Old/new logging pair for the generation result (diff artifact).
console::log_2(&"first prompt result".into(), &first_result.into());
log!("first prompt result", &first_result);
// Trivial sanity assertion; the real check is that nothing above errored.
assert_eq!(1 + 1, 2);
Ok(())
}

0 comments on commit 49b4010

Please sign in to comment.