From e0e3bc6ca833fb49e2438f3d6d28bd6805104786 Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 30 Nov 2024 12:28:33 +0900
Subject: [PATCH] update: enable GTEModel for cuda

---
 backends/candle/src/lib.rs        |  3 ++-
 backends/candle/src/models/gte.rs | 11 +----------
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index c8f762dd..c52cd776 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -349,7 +349,8 @@ impl CandleBackend {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
-                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
+                    tracing::info!("Starting GTE model on {:?}", device);
+                    return Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?));
                 }
                 tracing::info!("Starting FlashGTE model on {:?}", device);
                 Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs
index ba360eaf..c5830c39 100644
--- a/backends/candle/src/models/gte.rs
+++ b/backends/candle/src/models/gte.rs
@@ -1,6 +1,6 @@
 use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
 use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType};
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -339,15 +339,6 @@ pub struct GTEModel {
 
 impl GTEModel {
     pub fn load(vb: VarBuilder, config: &GTEConfig, model_type: ModelType) -> Result<Self> {
-        match vb.device() {
-            Device::Cuda(_) => {}
-            _ => candle::bail!("FlashGTE requires Cuda"),
-        }
-
-        if vb.dtype() != DType::F16 {
-            candle::bail!("FlashGTE requires DType::F16")
-        }
-
         if config.logn_attention_clip1 {
             candle::bail!("`logn_attention_clip1` is not supported");
         }