
Commit

Record number of skipped tokens in the response
tgaddair committed Nov 14, 2024
1 parent fc2cebb commit ee24cf4
Showing 8 changed files with 37 additions and 5 deletions.
4 changes: 4 additions & 0 deletions clients/python/lorax/types.py
@@ -276,6 +276,8 @@ class Token(BaseModel):
     special: bool
     # Alternative tokens
     alternative_tokens: Optional[List[AlternativeToken]] = None
+    # If token was skipped due to speculative decoding
+    skipped: bool
 
 
 # Generation finish reason
@@ -312,6 +314,8 @@ class Details(BaseModel):
     prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
+    # Number of skipped tokens
+    skipped_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int] = None
     # Decoder input tokens, empty if decoder_input_details is False
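For orientation, here is a sketch of where these new client-side fields surface at call time. It is not part of the commit: the Client construction and generate call are assumed to follow the existing lorax Python client API, and the URL is a placeholder.

from lorax import Client

# Placeholder endpoint; any running LoRAX deployment would look the same.
client = Client("http://127.0.0.1:8080")
response = client.generate("What is speculative decoding?", max_new_tokens=32)

# Aggregate counter added to Details by this commit.
print(response.details.skipped_tokens)

# Per-token flag added to Token: True for draft-model tokens accepted
# after the first token of a verification step.
for token in response.details.tokens:
    print(token.text, token.skipped)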
6 changes: 4 additions & 2 deletions proto/generate.proto
@@ -206,10 +206,12 @@ message GeneratedText {
     string text = 1;
     /// Number of generated tokens
     uint32 generated_tokens = 2;
+    /// Number of skipped tokens due to speculative decoding hits
+    uint32 skipped_tokens = 3;
     /// Finish reason
-    FinishReason finish_reason = 3;
+    FinishReason finish_reason = 4;
     /// Seed
-    optional uint64 seed = 4;
+    optional uint64 seed = 5;
 }
 
 message PrefillTokens {
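Note that skipped_tokens takes field number 3, which bumps finish_reason and seed to 4 and 5. Renumbering protobuf fields changes the wire format, so this is safe only because the router and the Python server are regenerated and deployed together from the same proto. A minimal sketch of using the regenerated stubs; the generate_pb2 module path and the enum constant are assumptions based on the import style used in server/lorax_server/models/types.py below:

from lorax_server.pb import generate_pb2  # assumed generated-stub path

# skipped_tokens is a plain uint32; unset proto3 scalars decode as 0,
# so older payloads without the field simply read back as zero.
msg = generate_pb2.GeneratedText(
    text="hello world",
    generated_tokens=4,
    skipped_tokens=2,
    finish_reason=generate_pb2.FINISH_REASON_LENGTH,  # assumed enum name
)
print(msg.skipped_tokens)  # 2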
5 changes: 4 additions & 1 deletion router/src/infer.rs
@@ -1398,9 +1398,11 @@ fn send_responses(
             next_tokens.is_special,
             alternative_tokens,
         ))
+        .enumerate()
         .peekable();
 
-    while let Some((id, logprob, text, special, alternative_tokens)) = iterator.next() {
+    while let Some((idx, (id, logprob, text, special, alternative_tokens))) = iterator.next() {
+        let skipped = idx > 0;
         let token = Token {
             id,
             text,
@@ -1416,6 +1418,7 @@
                     .collect(),
                 )
             }),
+            skipped,
         };
 
         match (&generation.generated_text, iterator.peek()) {
4 changes: 4 additions & 0 deletions router/src/lib.rs
@@ -414,6 +414,8 @@ pub struct Token {
     #[schema(nullable = true)]
    #[serde(skip_serializing_if = "Option::is_none")]
     alternative_tokens: Option<Vec<AlternativeToken>>,
+    #[schema(example = "false")]
+    skipped: bool,
 }
 
 #[derive(Debug, Serialize, ToSchema)]
@@ -462,6 +464,8 @@ pub(crate) struct Details {
     pub prompt_tokens: u32,
     #[schema(example = 1)]
     pub generated_tokens: u32,
+    #[schema(example = 1)]
+    pub skipped_tokens: u32,
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
7 changes: 7 additions & 0 deletions router/src/server.rs
@@ -645,6 +645,7 @@ async fn generate(
     };
 
     let generated_tokens = response.generated_text.generated_tokens;
+    let skipped_tokens = response.generated_text.skipped_tokens;
     let prompt_tokens = response.prompt_tokens;
     let total_tokens = prompt_tokens + generated_tokens;
 
@@ -680,6 +681,7 @@
                 finish_reason: FinishReason::from(response.generated_text.finish_reason),
                 prompt_tokens: prompt_tokens,
                 generated_tokens: generated_tokens,
+                skipped_tokens: skipped_tokens,
                 prefill: response.prefill,
                 tokens: response.tokens,
                 seed: response.generated_text.seed,
@@ -705,6 +707,7 @@
     span.record("seed", format!("{:?}", response.generated_text.seed));
     span.record("prompt_tokens", format!("{prompt_tokens:?}"));
     span.record("generated_tokens", format!("{generated_tokens:?}"));
+    span.record("skipped_tokens", format!("{skipped_tokens:?}"));
 
     // Headers
     let mut headers = HeaderMap::new();
@@ -729,6 +732,10 @@
         "x-generated-tokens",
         generated_tokens.to_string().parse().unwrap(),
     );
+    headers.insert(
+        "x-skipped-tokens",
+        skipped_tokens.to_string().parse().unwrap(),
+    );
     headers.insert("x-total-tokens", total_tokens.to_string().parse().unwrap());
     headers.insert(
         "x-validation-time",
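End to end, the new header and Details field can be observed with a plain HTTP call. A sketch, assuming a LoRAX server on localhost:8080 and the standard /generate route with details enabled; only the skipped fields are new in this commit:

import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Hello",
        "parameters": {"max_new_tokens": 16, "details": True},
    },
)
print(resp.headers.get("x-skipped-tokens"))      # new response header
print(resp.json()["details"]["skipped_tokens"])  # new Details field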
6 changes: 6 additions & 0 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1980,6 +1980,8 @@ def generate_token(
             if n_accepted_ids > 1:
                 logger.debug(f"speculated ids {n_accepted_ids - 1}")
 
+            # First token is not skipped, next tokens are
+            skipped = False
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token
@@ -1995,8 +1997,11 @@
                 stop, reason = stopping_criteria(
                     next_token_id,
                     next_token_text,
+                    skipped=skipped,
                 )
 
+                # All subsequent tokens are skipped
+                skipped = True
                 if stop:
                     left = index + n_accepted_ids - j - 1
                     current_stopped = True
@@ -2022,6 +2027,7 @@
                 generated_text = GeneratedText(
                     output_text,
                     stopping_criteria.current_tokens,
+                    stopping_criteria.current_skipped,
                     reason,
                     seed if do_sample else None,
                 )
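The convention the loop above encodes: a single generate_token step can emit several tokens when draft-model proposals are accepted, and only the first of them cost a full forward pass. A toy standalone illustration of the flag's lifecycle, with made-up token ids:

# One decode step that accepted three ids (n_accepted_ids == 3).
accepted_ids = [464, 2068, 7586]

skipped = False  # the first token is a normal decode, not a speculative hit
for token_id in accepted_ids:
    print(token_id, skipped)  # 464 False, then 2068 True, 7586 True
    skipped = True  # every later token in this step was speculated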
2 changes: 2 additions & 0 deletions server/lorax_server/models/types.py
@@ -50,13 +50,15 @@ def __len__(self):
 class GeneratedText:
     text: str
     generated_tokens: int
+    skipped_tokens: int
     finish_reason: FinishReason
     seed: Optional[int]
 
     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
             text=self.text,
             generated_tokens=self.generated_tokens,
+            skipped_tokens=self.skipped_tokens,
             finish_reason=self.finish_reason,
             seed=self.seed,
         )
8 changes: 6 additions & 2 deletions server/lorax_server/utils/tokens.py
@@ -185,9 +185,13 @@ def __init__(
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
         self.current_output = ""
+        self.current_skipped = 0
         self.ignore_eos_token = ignore_eos_token
-
-    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
+
+    def __call__(self, last_token: int, last_output: str, skipped: bool = False) -> Tuple[bool, Optional[str]]:
+        if skipped:
+            self.current_skipped += 1
+
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH
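To see how the counters accumulate, a standalone mock of just the two fields this change touches; the real StoppingCriteria also handles EOS and stop sequences, omitted here:

# Minimal mock of the token/skipped counters in StoppingCriteria.
class MiniCriteria:
    def __init__(self, max_new_tokens: int):
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_skipped = 0

    def __call__(self, skipped: bool = False) -> bool:
        if skipped:
            self.current_skipped += 1
        self.current_tokens += 1
        return self.current_tokens >= self.max_new_tokens

criteria = MiniCriteria(max_new_tokens=4)
criteria()              # first token of a speculative step
criteria(skipped=True)  # accepted draft token
criteria(skipped=True)  # accepted draft token
print(criteria.current_tokens, criteria.current_skipped)  # 3 2

Note that skipped tokens still increment current_tokens, so accepted draft tokens count against max_new_tokens exactly like regular ones.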
