Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tiny llama #1

Merged
merged 2 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"rust-analyzer.runnables.extraEnv": {
"RUSTFLAGS": "--cfg=web_sys_unstable_apis"
},
"rust-analyzer.cargo.extraEnv": {
"RUSTFLAGS": "--cfg=web_sys_unstable_apis"
}
}
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ repository = "https://github.com/text-yoga/transformers-wasm"
[lib]
crate-type = ["cdylib", "rlib"]

# NOTE(review): `[build]` is a `.cargo/config.toml` key, not a Cargo.toml
# manifest key — Cargo ignores this section in Cargo.toml (it emits an
# "unused manifest key: build" warning) and these rustflags never apply.
# Move this section into `.cargo/config.toml` for it to take effect.
[build]
rustflags = ["--cfg=web_sys_unstable_apis"]

[features]
default = ["console_error_panic_hook"]

Expand Down Expand Up @@ -45,7 +48,11 @@ features = [
'RequestMode',
'Response',
'Window',
'Navigator',
'Gpu',
'WgslLanguageFeatures'
]

version = "0.3.64"

[dev-dependencies]
Expand Down
3 changes: 3 additions & 0 deletions download-model.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash
# Download the TinyLlama chat model weights (GGUF, q4_k_m) and its tokenizer
# into ./tests/data via hfdownloader.
# Requires: hfdownloader on PATH, and $TOKEN set to a Hugging Face access token.
set -euo pipefail

# Fail fast with a clear message instead of passing an empty token through.
: "${TOKEN:?Set TOKEN to a Hugging Face access token}"

hfdownloader -m TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF:q4_k_m -s ./tests/data -t "$TOKEN"
hfdownloader -m TinyLlama/TinyLlama-1.1B-Chat-v1.0:tokenizer -s ./tests/data -t "$TOKEN"
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"build": "wasm-pack build -s text-yoga --dev",
"build:release": "wasm-pack build -s text-yoga --release",
"test:server": "npx http-server --cors -p 31300 ./tests/data",
"test:chrome": "RUST_LOG=wasm_bindgen_test_runner wasm-pack -vvv test --chrome --chromedriver \"$(which chromedriver)\" --headless",
"test:gecko": "wasm-pack test --firefox --geckodriver \"$(which geckodriver)\" --headless"
"test:chrome": "wasm-pack -vvv test --chrome --chromedriver \"$(which chromedriver)\"",
"test:firefox": "wasm-pack test --firefox --geckodriver \"$(which geckodriver)\" --headless"
},
"keywords": [],
"author": "",
Expand Down
3 changes: 3 additions & 0 deletions src/quantized_mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ impl Model {

const REPEAT_LAST_N: usize = 64;
let dev = Device::Cpu;

let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
debug!("Starting forward pass...");
let logits = self.inner.forward(&input, tokens.len())?;
debug!("Forward pass done.");
let logits = logits.squeeze(0)?;
let logits = if self.repeat_penalty == 1. || tokens.is_empty() {
logits
Expand Down
19 changes: 18 additions & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use js_sys::{ArrayBuffer, Uint8Array};

use wasm_bindgen::{prelude::*, JsValue};
use wasm_bindgen_futures::JsFuture;
use web_sys::{Request, RequestInit, RequestMode, Response};
use web_sys::{Navigator, Request, RequestInit, RequestMode, Response, Window};

pub fn set_panic_hook() {
// When the `console_error_panic_hook` feature is enabled, we can call the
Expand Down Expand Up @@ -48,3 +48,20 @@ pub async fn load_binary(url: &str) -> Result<Vec<u8>, JsValue> {
log!(url, x);
Ok(vec)
}

/// Returns `true` if the browser exposes a usable WebGPU adapter.
///
/// Awaits `navigator.gpu.requestAdapter()`. Per the WebGPU spec the promise
/// *fulfills with `null`* (rather than rejecting) when no suitable adapter is
/// available, so both a rejection and a null/undefined result are treated as
/// "no GPU". Only compiled when the `web_sys_unstable_apis` cfg is set, since
/// the `web_sys` WebGPU bindings are gated behind it.
#[cfg(web_sys_unstable_apis)]
pub async fn has_gpu() -> bool {
    let window = web_sys::window().expect("no global `window` exists");
    let navigator = window.navigator();

    let gpu: web_sys::Gpu = navigator.gpu();
    let has_gpu = match JsFuture::from(gpu.request_adapter()).await {
        // A fulfilled promise may still carry `null` (no adapter found) —
        // a bare `Ok(_)` would falsely report GPU support in that case.
        Ok(adapter) => !adapter.is_null() && !adapter.is_undefined(),
        Err(_) => false,
    };
    log!("wgsl_language_features");
    has_gpu
}
19 changes: 14 additions & 5 deletions tests/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ use transformers_wasm::utils;
use wasm_bindgen::{prelude::*, JsValue};
use wasm_bindgen_futures::JsFuture;
use wasm_bindgen_test::*;
use web_sys::console;
use web_sys::{Request, RequestInit, RequestMode, Response};
use web_sys::{console, Request, RequestInit, RequestMode, Response};

wasm_bindgen_test_configure!(run_in_browser);

#[wasm_bindgen_test]
async fn pass() -> Result<(), JsValue> {
let tokenizer_url = "http://localhost:31300/tokenizer.json";
let model_url = "http://localhost:31300/tinymistral-248m.q4_k_m.gguf";
#[cfg(web_sys_unstable_apis)]
log!(utils::has_gpu().await);

let tokenizer_url = "http://localhost:31300/TinyLlama_TinyLlama-1.1B-Chat-v1.0/tokenizer.json";
let model_url = "http://localhost:31300/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf";

let tokenizer_blob: Vec<u8> = utils::load_binary(&tokenizer_url).await?;
let tokenizer_blob_len = format!("{}", &tokenizer_blob.len());
Expand All @@ -32,9 +34,16 @@ async fn pass() -> Result<(), JsValue> {
log!("model blob size", &model_blob_len);

log!("loading model...");

let mut model = Model::new(model_blob, tokenizer_blob)?;
log!("model loaded.");
let prompt: String = String::from("What is a good recipe for onion soup");
let prompt: String = String::from(
"<|system|>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. </s>
<|user|>
What is borrow checking in rust?</s>
<|assistant|>",
);
let temp: f64 = 0.8;
let top_p: f64 = 1.;
let repeat_penalty: f32 = 1.1;
Expand Down