From 017bde1ecc55e25cc1e20d61e36a3d161b6ba4da Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Sun, 7 Jan 2024 20:10:20 +0000 Subject: [PATCH 1/2] WIP: Try out tinyllama --- download-model.sh | 3 +++ package.json | 4 ++-- src/quantized_mistral.rs | 3 +++ tests/web.rs | 13 ++++++++++--- 4 files changed, 18 insertions(+), 5 deletions(-) create mode 100755 download-model.sh diff --git a/download-model.sh b/download-model.sh new file mode 100755 index 0000000..062d48f --- /dev/null +++ b/download-model.sh @@ -0,0 +1,3 @@ +#!/bin/bash +hfdownloader -m TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF:q4_k_m -s ./tests/data -t $TOKEN +hfdownloader -m TinyLlama/TinyLlama-1.1B-Chat-v1.0:tokenizer -s ./tests/data -t $TOKEN diff --git a/package.json b/package.json index 88d11c7..d6c2b4e 100644 --- a/package.json +++ b/package.json @@ -10,8 +10,8 @@ "build": "wasm-pack build -s text-yoga --dev", "build:release": "wasm-pack build -s text-yoga --release", "test:server": "npx http-server --cors -p 31300 ./tests/data", - "test:chrome": "RUST_LOG=wasm_bindgen_test_runner wasm-pack -vvv test --chrome --chromedriver \"$(which chromedriver)\" --headless", - "test:gecko": "wasm-pack test --firefox --geckodriver \"$(which geckodriver)\" --headless" + "test:chrome": "wasm-pack -vvv test --chrome --chromedriver \"$(which chromedriver)\"", + "test:firefox": "wasm-pack test --firefox --geckodriver \"$(which geckodriver)\" --headless" }, "keywords": [], "author": "", diff --git a/src/quantized_mistral.rs b/src/quantized_mistral.rs index a1f724f..c331333 100644 --- a/src/quantized_mistral.rs +++ b/src/quantized_mistral.rs @@ -31,8 +31,11 @@ impl Model { const REPEAT_LAST_N: usize = 64; let dev = Device::Cpu; + let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?; + debug!("Starting forward pass..."); let logits = self.inner.forward(&input, tokens.len())?; + debug!("Forward pass done."); let logits = logits.squeeze(0)?; let logits = if self.repeat_penalty == 1. || tokens.is_empty() { logits diff --git a/tests/web.rs b/tests/web.rs index 624cab0..0f57f19 100644 --- a/tests/web.rs +++ b/tests/web.rs @@ -20,8 +20,8 @@ wasm_bindgen_test_configure!(run_in_browser); #[wasm_bindgen_test] async fn pass() -> Result<(), JsValue> { - let tokenizer_url = "http://localhost:31300/tokenizer.json"; - let model_url = "http://localhost:31300/tinymistral-248m.q4_k_m.gguf"; + let tokenizer_url = "http://localhost:31300/TinyLlama_TinyLlama-1.1B-Chat-v1.0/tokenizer.json"; + let model_url = "http://localhost:31300/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"; let tokenizer_blob: Vec = utils::load_binary(&tokenizer_url).await?; let tokenizer_blob_len = format!("{}", &tokenizer_blob.len()); @@ -32,9 +32,16 @@ async fn pass() -> Result<(), JsValue> { log!("model blob size", &model_blob_len); log!("loading model..."); + let mut model = Model::new(model_blob, tokenizer_blob)?; log!("model loaded."); - let prompt: String = String::from("What is a good recipe for onion soup"); + let prompt: String = String::from( + "<|system|> + You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. + <|user|> + What is borrow checking in rust? + <|assistant|>", + ); let temp: f64 = 0.8; let top_p: f64 = 1.; let repeat_penalty: f32 = 1.1; From a2cdb08c17bb38aaefe978041e6509c281bebb79 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Mon, 8 Jan 2024 20:28:33 +0000 Subject: [PATCH 2/2] Doing something, not sure what though --- .vscode/settings.json | 8 ++++++++ Cargo.toml | 7 +++++++ src/utils.rs | 19 ++++++++++++++++++- tests/web.rs | 6 ++++-- 4 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..cd4d37d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,8 @@ +{ + "rust-analyzer.runnables.extraEnv": { + "RUSTFLAGS": "--cfg=web_sys_unstable_apis" + }, + "rust-analyzer.cargo.extraEnv": { + "RUSTFLAGS": "--cfg=web_sys_unstable_apis" + } +} diff --git a/Cargo.toml b/Cargo.toml index 4a01085..f9c73d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,9 @@ repository = "https://github.com/text-yoga/transformers-wasm" [lib] crate-type = ["cdylib", "rlib"] +[build] +rustflags = ["--cfg=web_sys_unstable_apis"] + [features] default = ["console_error_panic_hook"] @@ -45,7 +48,11 @@ features = [ 'RequestMode', 'Response', 'Window', + 'Navigator', + 'Gpu', + 'WgslLanguageFeatures' ] + version = "0.3.64" [dev-dependencies] diff --git a/src/utils.rs b/src/utils.rs index 691ff5f..2ef882a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,7 +3,7 @@ use js_sys::{ArrayBuffer, Uint8Array}; use wasm_bindgen::{prelude::*, JsValue}; use wasm_bindgen_futures::JsFuture; -use web_sys::{Request, RequestInit, RequestMode, Response}; +use web_sys::{Navigator, Request, RequestInit, RequestMode, Response, Window}; pub fn set_panic_hook() { // When the `console_error_panic_hook` feature is enabled, we can call the @@ -48,3 +48,20 @@ pub async fn load_binary(url: &str) -> Result, JsValue> { log!(url, x); Ok(vec) } + +#[cfg(web_sys_unstable_apis)] +pub async fn has_gpu() -> bool { + let window = web_sys::window().expect("no global `window` exists"); + let navigator = window.navigator(); + + let gpu: web_sys::Gpu = navigator.gpu(); + let has_gpu_check = JsFuture::from(gpu.request_adapter()).await; + + let mut has_gpu = false; + match has_gpu_check { + Ok(_) => has_gpu = true, + Err(err) => {} + } + log!("wgsl_language_features"); + has_gpu +} diff --git a/tests/web.rs b/tests/web.rs index 0f57f19..307c865 100644 --- a/tests/web.rs +++ b/tests/web.rs @@ -13,13 +13,15 @@ use transformers_wasm::utils; use wasm_bindgen::{prelude::*, JsValue}; use wasm_bindgen_futures::JsFuture; use wasm_bindgen_test::*; -use web_sys::console; -use web_sys::{Request, RequestInit, RequestMode, Response}; +use web_sys::{console, Request, RequestInit, RequestMode, Response}; wasm_bindgen_test_configure!(run_in_browser); #[wasm_bindgen_test] async fn pass() -> Result<(), JsValue> { + #[cfg(web_sys_unstable_apis)] + log!(utils::has_gpu().await); + let tokenizer_url = "http://localhost:31300/TinyLlama_TinyLlama-1.1B-Chat-v1.0/tokenizer.json"; let model_url = "http://localhost:31300/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf";