From ce1edf4e1ae8e5ab95474b228d2fc55fbdca9fca Mon Sep 17 00:00:00 2001 From: Hyeongchan Kim Date: Tue, 16 Jul 2024 00:53:11 +0900 Subject: [PATCH] fix: Download `model.onnx_data` (#343) --- backends/src/lib.rs | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 2ee63279..9d44cdbd 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -310,19 +310,16 @@ pub async fn download_weights(api: &ApiRepo) -> Result, ApiError> { } } } else if cfg!(feature = "ort") { - tracing::info!("Downloading `model.onnx`"); - match api.get("model.onnx").await { - Ok(p) => vec![p], + match download_onnx(api).await { + Ok(p) => p, Err(err) => { - tracing::warn!("Could not download `model.onnx`: {err}"); - tracing::info!("Downloading `onnx/model.onnx`"); - let p = api.get("onnx/model.onnx").await?; - vec![p.parent().unwrap().to_path_buf()] + panic!("failed to download `model.onnx` or `model.onnx_data`. Check the onnx file exists in the repository. {err}"); } } } else { unreachable!() }; + Ok(model_files) } @@ -364,3 +361,34 @@ async fn download_safetensors(api: &ApiRepo) -> Result, ApiError> { Ok(safetensors_files) } + +async fn download_onnx(api: &ApiRepo) -> Result, ApiError> { + let mut model_files: Vec = Vec::new(); + + tracing::info!("Downloading `model.onnx`"); + match api.get("model.onnx").await { + Ok(p) => model_files.push(p), + Err(err) => { + tracing::warn!("Could not download `model.onnx`: {err}"); + tracing::info!("Downloading `onnx/model.onnx`"); + let p = api.get("onnx/model.onnx").await?; + model_files.push(p.parent().unwrap().to_path_buf()) + } + }; + + tracing::info!("Downloading `model.onnx_data`"); + match api.get("model.onnx_data").await { + Ok(p) => model_files.push(p), + Err(err) => { + tracing::warn!("Could not download `model.onnx_data`: {err}"); + tracing::info!("Downloading `onnx/model.onnx_data`"); + + match api.get("onnx/model.onnx_data").await { + Ok(p) => model_files.push(p.parent().unwrap().to_path_buf()), + Err(err) => tracing::warn!("Could not download `onnx/model.onnx_data`: {err}"), + } + } + } + + Ok(model_files) +}