Heap based router queue (huggingface#63)
mrs303 authored Feb 26, 2024
1 parent cf05929 commit e1e0457
Showing 3 changed files with 127 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -87,6 +87,7 @@ Environment Variables Added:
| SKIP_TOKENIZER_IN_TGI | True/False | False | Skip tokenizer for input/output processing | add -e in docker run command |
| TGI_PROFILER_ENABLED | True/False | False | Collect high-level server tracing events | add -e in docker run command |
| WARMUP_ENABLED | True/False | True | Enable warmup during server initialization to recompile all graphs. This can increase TGI setup time. | add -e in docker run command |
| QUEUE_THRESHOLD_MS | integer | 120 | Controls the threshold (in milliseconds) beyond which requests are considered overdue and handled with priority; otherwise, shorter requests are prioritized. | add -e in docker run command |
</div>


124 changes: 114 additions & 10 deletions router/src/queue.rs
@@ -3,7 +3,10 @@ use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use std::cmp::{Eq, Ord, PartialEq, PartialOrd};
use std::collections::BinaryHeap;
use std::env;
use std::time::Duration;
use text_generation_client::{Batch, Request};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
@@ -132,11 +135,104 @@ async fn queue_task(
}
}

#[derive(Debug)]
struct IdentifiableEntry(u64, Entry);

impl Eq for IdentifiableEntry {}

impl PartialEq for IdentifiableEntry {
    fn eq(&self, other: &Self) -> bool {
        // Entries are identified solely by their unique queue id.
        self.0 == other.0
    }
}

impl Ord for IdentifiableEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let ordering = match self
            .1
            .request
            .input_length
            .cmp(&other.1.request.input_length)
        {
            std::cmp::Ordering::Equal => self.0.cmp(&other.0),
            any => any,
        };

        // invert to get a min-heap
        ordering.reverse()
    }
}

impl PartialOrd for IdentifiableEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
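
Since Rust's `BinaryHeap` is a max-heap, the ordering above is reversed so the entry with the smallest `input_length` pops first, with the monotonically increasing id as a FIFO tie-break. A minimal standalone sketch of the same inversion, using plain `(input_length, id)` tuples and `std::cmp::Reverse` instead of the TGI types:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    // Wrapping (input_length, id) in Reverse turns the max-heap into a
    // min-heap: shorter requests pop first, ties pop FIFO by id.
    let mut heap = BinaryHeap::new();
    heap.push(Reverse((30u32, 0u64)));
    heap.push(Reverse((10, 1)));
    heap.push(Reverse((10, 2)));

    let order: Vec<u64> =
        std::iter::from_fn(|| heap.pop().map(|Reverse((_, id))| id)).collect();
    assert_eq!(order, vec![1, 2, 0]);
}
```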

#[derive(Debug)]
struct QueueImpl {
    regular_entries: BinaryHeap<IdentifiableEntry>,
    overdue_entries: BinaryHeap<IdentifiableEntry>,
    overdue_threshold: Duration,
}

impl QueueImpl {
    fn new(capacity: usize, overdue_threshold: Duration) -> Self {
        Self {
            regular_entries: BinaryHeap::with_capacity(capacity),
            overdue_entries: BinaryHeap::with_capacity(capacity),
            overdue_threshold,
        }
    }

    /// Migrate regular entries whose wait time has crossed the overdue
    /// threshold into the overdue heap.
    fn update(&mut self) {
        if self.regular_entries.is_empty() {
            return;
        }

        let mut left = BinaryHeap::with_capacity(self.regular_entries.capacity());

        for entry in self.regular_entries.drain() {
            if entry.1.queue_time.elapsed() > self.overdue_threshold {
                self.overdue_entries.push(entry);
            } else {
                left.push(entry);
            }
        }

        self.regular_entries = left;
    }

    fn push(&mut self, entry: IdentifiableEntry) {
        if entry.1.queue_time.elapsed() > self.overdue_threshold {
            self.overdue_entries.push(entry);
        } else {
            self.regular_entries.push(entry);
        }
    }

    /// Overdue entries are always served before regular ones.
    fn pop(&mut self) -> Option<IdentifiableEntry> {
        if !self.overdue_entries.is_empty() {
            self.overdue_entries.pop()
        } else {
            self.regular_entries.pop()
        }
    }

    fn is_empty(&self) -> bool {
        self.regular_entries.is_empty() && self.overdue_entries.is_empty()
    }

    fn len(&self) -> usize {
        self.regular_entries.len() + self.overdue_entries.len()
    }
}
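
The two-heap policy can be exercised in isolation. Below is a minimal sketch with a simplified stand-in for `IdentifiableEntry` (the `MiniQueue` name and field layout are hypothetical, not the TGI types): a long request that has waited past the threshold is served before a fresh short one, while regular traffic still pops shortest-first.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::time::{Duration, Instant};

// Simplified entry: (input_length, id, queued_at). The unique id decides
// ties, so the Instant never influences the ordering.
type Item = Reverse<(u32, u64, Instant)>;

struct MiniQueue {
    regular: BinaryHeap<Item>,
    overdue: BinaryHeap<Item>,
    threshold: Duration,
}

impl MiniQueue {
    // Mirrors QueueImpl::push: route by the wait time seen at push.
    fn push(&mut self, len: u32, id: u64, queued_at: Instant) {
        let item = Reverse((len, id, queued_at));
        if queued_at.elapsed() > self.threshold {
            self.overdue.push(item);
        } else {
            self.regular.push(item);
        }
    }

    // Mirrors QueueImpl::update: migrate entries that crossed the threshold.
    fn update(&mut self) {
        let mut keep = BinaryHeap::new();
        for item in self.regular.drain() {
            if (item.0).2.elapsed() > self.threshold {
                self.overdue.push(item);
            } else {
                keep.push(item);
            }
        }
        self.regular = keep;
    }

    // Overdue entries always pop before regular ones.
    fn pop(&mut self) -> Option<u64> {
        self.overdue
            .pop()
            .or_else(|| self.regular.pop())
            .map(|Reverse((_, id, _))| id)
    }
}

fn main() {
    let mut q = MiniQueue {
        regular: BinaryHeap::new(),
        overdue: BinaryHeap::new(),
        threshold: Duration::from_millis(120), // the QUEUE_THRESHOLD_MS default
    };
    let now = Instant::now();
    // A long request queued 100 ms ago (not yet overdue) and a fresh short one.
    q.push(500, 0, now - Duration::from_millis(100));
    q.push(10, 1, now);

    std::thread::sleep(Duration::from_millis(50));
    q.update(); // request 0 has now waited ~150 ms > 120 ms: moved to overdue

    // The overdue long request is served before the fresh short one.
    assert_eq!(q.pop(), Some(0));
    assert_eq!(q.pop(), Some(1));
}
```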

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a Vec
    entries: VecDeque<(u64, Entry)>,
    /// Queue entries
    entries: QueueImpl,

    /// Id of the next entry
    next_id: u64,
@@ -166,10 +262,16 @@ impl State {
        max_input_length: u32,
        max_total_tokens: u32,
        block_size: u32,
        window_size: Option<u32>
        window_size: Option<u32>,
    ) -> Self {
        let default_threshold: u64 = 120;
        let threshold: u64 = match env::var("QUEUE_THRESHOLD_MS") {
            Ok(val) => val.parse().unwrap_or(default_threshold),
            Err(_) => default_threshold,
        };

        Self {
            entries: VecDeque::with_capacity(128),
            entries: QueueImpl::new(128, Duration::from_millis(threshold)),
            next_id: 0,
            next_batch_id: 0,
            requires_padding,
@@ -187,7 +289,7 @@
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.entries.push(IdentifiableEntry(self.next_id, entry));
        self.next_id += 1;
    }

@@ -209,6 +311,8 @@
}
}

        self.entries.update();

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(&Span::current());
@@ -221,7 +325,7 @@
        let mut decode_tokens: u32 = 0;

        // Pop entries starting from the front of the queue
        while let Some((id, mut entry)) = self.entries.pop_front() {
        while let Some(IdentifiableEntry(id, mut entry)) = self.entries.pop() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
@@ -263,7 +367,7 @@
            {
                // Entry is over budget
                // Add it back to the front
                self.entries.push_front((id, entry));
                self.entries.push(IdentifiableEntry(id, entry));
                break;
            }

@@ -303,7 +407,7 @@
        for r in batch_requests.into_iter().rev() {
            let id = r.id;
            let entry = batch_entries.remove(&id).unwrap();
            self.entries.push_front((id, entry));
            self.entries.push(IdentifiableEntry(id, entry));
        }

        return None;
@@ -399,7 +503,7 @@ mod tests {

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        let id = state.entries.pop().unwrap().0;
        assert_eq!(id, 0);
}

13 changes: 12 additions & 1 deletion router/src/validation.rs
@@ -2,6 +2,7 @@
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest};
use rand::{thread_rng, Rng};
use std::env;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
@@ -21,6 +22,7 @@ pub struct Validation {
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
    skip_tokenizer_in_tgi: bool,
}

impl Validation {
@@ -59,13 +61,18 @@ impl Validation {
            None
        };

        let skip_tokenizer_in_tgi = env::var("SKIP_TOKENIZER_IN_TGI")
            .ok()
            .map_or(false, |value| value.to_lowercase() == "true");

        Self {
            max_best_of,
            sender,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
            skip_tokenizer_in_tgi,
        }
    }

@@ -130,7 +137,11 @@ impl Validation {
        } else {
            return Err(ValidationError::UnsetMaxNewTokens);
        };
        let input_length = truncate.unwrap_or(self.max_input_length);
        let input_length = if self.skip_tokenizer_in_tgi {
            inputs.chars().filter(|&c| c == ',').count() + 1
        } else {
            truncate.unwrap_or(self.max_input_length)
        };

        // Validate MaxNewTokens
        if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
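With `SKIP_TOKENIZER_IN_TGI` enabled, the validation path above derives the input length by counting commas plus one, which suggests the input is expected to be a comma-separated list of pre-tokenized ids. A standalone sketch of that heuristic (`pretokenized_input_length` is a hypothetical name, not part of the codebase):

```rust
// With SKIP_TOKENIZER_IN_TGI, inputs are treated as comma-separated token
// ids, so the token count is the number of commas plus one.
fn pretokenized_input_length(inputs: &str) -> usize {
    inputs.chars().filter(|&c| c == ',').count() + 1
}

fn main() {
    assert_eq!(pretokenized_input_length("15496,11,995"), 3);
    assert_eq!(pretokenized_input_length("42"), 1);
}
```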
