Skip to content

Commit

Permalink
fix: Download model.onnx_data (#343)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr authored Jul 15, 2024
1 parent 661a77f commit ce1edf4
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions backends/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,19 +310,16 @@ pub async fn download_weights(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
}
}
} else if cfg!(feature = "ort") {
tracing::info!("Downloading `model.onnx`");
match api.get("model.onnx").await {
Ok(p) => vec![p],
match download_onnx(api).await {
Ok(p) => p,
Err(err) => {
tracing::warn!("Could not download `model.onnx`: {err}");
tracing::info!("Downloading `onnx/model.onnx`");
let p = api.get("onnx/model.onnx").await?;
vec![p.parent().unwrap().to_path_buf()]
panic!("failed to download `model.onnx` or `model.onnx_data`. Check the onnx file exists in the repository. {err}");
}
}
} else {
unreachable!()
};

Ok(model_files)
}

Expand Down Expand Up @@ -364,3 +361,34 @@ async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {

Ok(safetensors_files)
}

async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
let mut model_files: Vec<PathBuf> = Vec::new();

tracing::info!("Downloading `model.onnx`");
match api.get("model.onnx").await {
Ok(p) => model_files.push(p),
Err(err) => {
tracing::warn!("Could not download `model.onnx`: {err}");
tracing::info!("Downloading `onnx/model.onnx`");
let p = api.get("onnx/model.onnx").await?;
model_files.push(p.parent().unwrap().to_path_buf())
}
};

tracing::info!("Downloading `model.onnx_data`");
match api.get("model.onnx_data").await {
Ok(p) => model_files.push(p),
Err(err) => {
tracing::warn!("Could not download `model.onnx_data`: {err}");
tracing::info!("Downloading `onnx/model.onnx_data`");

match api.get("onnx/model.onnx_data").await {
Ok(p) => model_files.push(p.parent().unwrap().to_path_buf()),
Err(err) => tracing::warn!("Could not download `onnx/model.onnx_data`: {err}"),
}
}
}

Ok(model_files)
}

0 comments on commit ce1edf4

Please sign in to comment.