diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index 7f3386ab..52db0854 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -1,7 +1,8 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; -use crate::models::{ - GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling, +use crate::models::gte::{ + ClassificationHead, GTEClassificationHead, GTEConfig, Model, NTKScaling, PositionEmbeddingType, + RopeScaling, }; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder};