diff --git a/router/tests/test_http_predict.rs b/router/tests/test_http_predict.rs index 6869f721..850e0874 100644 --- a/router/tests/test_http_predict.rs +++ b/router/tests/test_http_predict.rs @@ -16,12 +16,13 @@ pub struct SnapshotPrediction { #[tokio::test] #[cfg(feature = "http")] async fn test_predict() -> Result<()> { - start_server( - "SamLowe/roberta-base-go_emotions".to_string(), - None, - DType::Float32, - ) - .await?; + let model_id = if cfg!(feature = "ort") { + "SamLowe/roberta-base-go_emotions-onnx" + } else { + "SamLowe/roberta-base-go_emotions" + }; + + start_server(model_id.to_string(), None, DType::Float32).await?; let request = json!({ "inputs": "test" diff --git a/router/tests/test_http_rerank.rs b/router/tests/test_http_rerank.rs index 0c5b6079..920eca4d 100644 --- a/router/tests/test_http_rerank.rs +++ b/router/tests/test_http_rerank.rs @@ -17,12 +17,7 @@ pub struct SnapshotRank { #[tokio::test] #[cfg(feature = "http")] async fn test_rerank() -> Result<()> { - start_server( - "BAAI/bge-reranker-base".to_string(), - Some("refs/pr/5".to_string()), - DType::Float32, - ) - .await?; + start_server("BAAI/bge-reranker-base".to_string(), None, DType::Float32).await?; let request = json!({ "query": "test",