diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs
index e7d2e913a3..9bcaafaf5d 100644
--- a/mistralrs-core/src/pipeline/mod.rs
+++ b/mistralrs-core/src/pipeline/mod.rs
@@ -3,6 +3,7 @@ mod llama;
 mod mistral;
 mod mixtral;
 use candle_sampling::logits_processor::Logprobs;
+use core::fmt;
 use either::Either;
 pub use gemma::{GemmaLoader, GemmaSpecificConfig, GEMMA_IS_GPTX};
 use hf_hub::{
@@ -83,6 +84,7 @@ pub struct ChatTemplate {
     use_default_system_prompt: Option<bool>,
 }
 
+#[derive(Debug, Clone)]
 pub enum TokenSource {
     Literal(String),
     EnvVar(String),
@@ -90,6 +92,43 @@ pub enum TokenSource {
     CacheToken,
 }
 
+impl FromStr for TokenSource {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let parts: Vec<&str> = s.splitn(2, ':').collect();
+        match parts[0] {
+            "literal" => parts
+                .get(1)
+                .map(|&value| TokenSource::Literal(value.to_string()))
+                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
+            "env" => Ok(TokenSource::EnvVar(
+                parts
+                    .get(1)
+                    .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
+                    .to_string(),
+            )),
+            "path" => parts
+                .get(1)
+                .map(|&value| TokenSource::Path(value.to_string()))
+                .ok_or_else(|| "Expected a value for 'path'".to_string()),
+            "cache" => Ok(TokenSource::CacheToken),
+            _ => Err("Invalid token source format".to_string()),
+        }
+    }
+}
+
+impl fmt::Display for TokenSource {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            TokenSource::Literal(value) => write!(f, "literal:{}", value),
+            TokenSource::EnvVar(value) => write!(f, "env:{}", value),
+            TokenSource::Path(value) => write!(f, "path:{}", value),
+            TokenSource::CacheToken => write!(f, "cache"),
+        }
+    }
+}
+
 pub enum ModelKind {
     Normal,
     XLoraNormal,
diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs
index da302e3614..b644a329cf 100644
--- a/mistralrs-server/src/main.rs
+++ b/mistralrs-server/src/main.rs
@@ -455,6 +455,10 @@ pub enum ModelSelected {
     },
 }
 
+fn parse_token_source(s: &str) -> Result<TokenSource, String> {
+    s.parse()
+}
+
 #[derive(Parser)]
 #[command(version, about, long_about = None)]
 struct Args {
@@ -492,6 +496,12 @@ struct Args {
     /// Used if the automatic deserialization fails. If this ends with `.json` (ie., it is a file) then that template is loaded.
     #[arg(short, long)]
     chat_template: Option<String>,
+
+    /// Source of the token for authentication.
+    /// Can be in the formats: "literal:<value>", "env:<value>", "path:<value>", or "cache" to use a cached token.
+    /// Defaults to using a cached token.
+    #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
+    token_source: TokenSource,
 }
 
 async fn chatcompletions(
@@ -981,7 +991,7 @@ async fn main() -> Result<()> {
     #[cfg(not(feature = "metal"))]
     let device = Device::cuda_if_available(0)?;
 
-    let pipeline = loader.load_model(None, TokenSource::CacheToken, None, &device)?;
+    let pipeline = loader.load_model(None, args.token_source, None, &device)?;
     let mistralrs = MistralRs::new(
         pipeline,
         SchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),