Skip to content

Commit

Permalink
Merge pull request EricLBuehler#36 from LLukas22/tokensource
Browse files Browse the repository at this point in the history
Server: Allow `TokenSource` to be configured via cli
  • Loading branch information
EricLBuehler authored Mar 27, 2024
2 parents 27cff95 + 8651b17 commit 3207f84
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
39 changes: 39 additions & 0 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod llama;
mod mistral;
mod mixtral;
use candle_sampling::logits_processor::Logprobs;
use core::fmt;
use either::Either;
pub use gemma::{GemmaLoader, GemmaSpecificConfig, GEMMA_IS_GPTX};
use hf_hub::{
Expand Down Expand Up @@ -83,13 +84,51 @@ pub struct ChatTemplate {
use_default_system_prompt: Option<bool>,
}

/// Where to obtain a Hugging Face Hub authentication token from.
///
/// Parsed from strings of the form `literal:<value>`, `env:<value>`,
/// `path:<value>`, or `cache` — see the [`FromStr`] impl. The [`fmt::Display`]
/// impl renders the same form, so `parse`/`to_string` round-trip.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenSource {
    /// Use the given string verbatim as the token.
    Literal(String),
    /// Read the token from the named environment variable.
    EnvVar(String),
    /// Read the token from the file at the given path.
    Path(String),
    /// Use the locally cached Hugging Face token.
    CacheToken,
}

impl FromStr for TokenSource {
    type Err = String;

    /// Parses `literal:<value>`, `env[:<var>]`, `path:<value>`, or `cache`.
    ///
    /// `env` without a value defaults to the `HUGGING_FACE_HUB_TOKEN`
    /// environment variable. `literal` and `path` require a value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split on the first ':' only, so the value itself may contain ':'
        // (e.g. Windows paths, tokens with colons).
        let parts: Vec<&str> = s.splitn(2, ':').collect();
        match parts[0] {
            "literal" => parts
                .get(1)
                .map(|&value| TokenSource::Literal(value.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            "env" => Ok(TokenSource::EnvVar(
                parts
                    .get(1)
                    .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
                    .to_string(),
            )),
            "path" => parts
                .get(1)
                .map(|&value| TokenSource::Path(value.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            // NOTE(review): anything after "cache:" is silently ignored here;
            // rejecting it would change accepted inputs, so left as-is.
            "cache" => Ok(TokenSource::CacheToken),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}

impl fmt::Display for TokenSource {
    /// Renders the source in the same `kind:value` syntax accepted by
    /// [`FromStr`], so a displayed value re-parses to an equal variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TokenSource::Literal(value) => write!(f, "literal:{}", value),
            TokenSource::EnvVar(value) => write!(f, "env:{}", value),
            TokenSource::Path(value) => write!(f, "path:{}", value),
            TokenSource::CacheToken => write!(f, "cache"),
        }
    }
}

pub enum ModelKind {
Normal,
XLoraNormal,
Expand Down
12 changes: 11 additions & 1 deletion mistralrs-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,10 @@ pub enum ModelSelected {
},
}

fn parse_token_source(s: &str) -> Result<TokenSource, String> {
s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
Expand Down Expand Up @@ -492,6 +496,12 @@ struct Args {
/// Used if the automatic deserialization fails. If this ends with `.json` (ie., it is a file) then that template is loaded.
#[arg(short, long)]
chat_template: Option<String>,

/// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`, or `cache` to use a cached token.

Check warning on line 501 in mistralrs-server/src/main.rs

View workflow job for this annotation

GitHub Actions / Docs

unclosed HTML tag `value`

Check warning on line 501 in mistralrs-server/src/main.rs

View workflow job for this annotation

GitHub Actions / Docs

unclosed HTML tag `value`

Check warning on line 501 in mistralrs-server/src/main.rs

View workflow job for this annotation

GitHub Actions / Docs

unclosed HTML tag `value`

Check warning on line 501 in mistralrs-server/src/main.rs

View workflow job for this annotation

GitHub Actions / Docs

unclosed HTML tag `value`

Check warning on line 501 in mistralrs-server/src/main.rs

View workflow job for this annotation

GitHub Actions / Docs

unclosed HTML tag `value`

Check warning on line 501 in mistralrs-server/src/main.rs

View workflow job for this annotation

GitHub Actions / Docs

unclosed HTML tag `value`
/// Defaults to using a cached token.
#[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
token_source: TokenSource,
}

async fn chatcompletions(
Expand Down Expand Up @@ -981,7 +991,7 @@ async fn main() -> Result<()> {
#[cfg(not(feature = "metal"))]
let device = Device::cuda_if_available(0)?;

let pipeline = loader.load_model(None, TokenSource::CacheToken, None, &device)?;
let pipeline = loader.load_model(None, args.token_source, None, &device)?;
let mistralrs = MistralRs::new(
pipeline,
SchedulerMethod::Fixed(args.max_seqs.try_into().unwrap()),
Expand Down

0 comments on commit 3207f84

Please sign in to comment.