Commit e0e3bc6
update: enable GTEModel for cuda
kozistr committed Nov 30, 2024
Parent: 9a3676c · Commit: e0e3bc6
Showing 2 changed files with 3 additions and 11 deletions.
backends/candle/src/lib.rs — 2 additions & 1 deletion
@@ -349,7 +349,8 @@ impl CandleBackend
             if dtype != DType::F16
                 || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
             {
-                return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
+                tracing::info!("Starting GTE model on {:?}", device);
+                Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
             }
             tracing::info!("Starting FlashGTE model on {:?}", device);
             Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
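Taken together with the context the diff collapses, the dispatch presumably reads as below after this change: instead of erroring out, the backend falls back to the plain `GTEModel` whenever fp16 plus flash attention is unavailable. This is a sketch, not the verbatim file; the `else` joining the two branches is assumed from the shape of the surrounding match arm:

    // Inside the CandleBackend match arm for GTE on CUDA (assumed context).
    if dtype != DType::F16
        || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
    {
        // No fp16 + flash attention: fall back to the plain implementation.
        tracing::info!("Starting GTE model on {:?}", device);
        Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
    } else {
        tracing::info!("Starting FlashGTE model on {:?}", device);
        Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
    }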
backends/candle/src/models/gte.rs — 1 addition & 10 deletions
@@ -1,6 +1,6 @@
 use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
 use crate::models::{apply_rotary, cos_sin, inv_freqs, Model, PositionEmbeddingType};
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -339,15 +339,6 @@ pub struct GTEModel
 
 impl GTEModel {
     pub fn load(vb: VarBuilder, config: &GTEConfig, model_type: ModelType) -> Result<Self> {
-        match vb.device() {
-            Device::Cuda(_) => {}
-            _ => candle::bail!("FlashGTE requires Cuda"),
-        }
-
-        if vb.dtype() != DType::F16 {
-            candle::bail!("FlashGTE requires DType::F16")
-        }
-
         if config.logn_attention_clip1 {
             candle::bail!("`logn_attention_clip1` is not supported");
         }
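With the device and dtype guards deleted, `GTEModel::load` no longer insists on CUDA + F16; of the checks shown here, only the unsupported `logn_attention_clip1` option is still rejected. A minimal caller sketch under that assumption — the safetensors path, `config`, and `model_type` values are hypothetical placeholders, and the in-crate types `GTEModel`, `GTEConfig`, and `ModelType` are assumed to be in scope:

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;

    // Hypothetical caller (not part of this commit): with the guards gone,
    // GTEModel::load accepts device/dtype pairs the old code rejected,
    // e.g. CUDA + F32 without flash attention.
    fn load_gte(config: &GTEConfig, model_type: ModelType) -> Result<GTEModel> {
        let device = Device::new_cuda(0)?;
        // Memory-map the weights in fp32 on the CUDA device.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
        };
        GTEModel::load(vb, config, model_type)
    }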
