Commit 9138c32

Fix formatting

g-eoj committed May 30, 2024
1 parent 7d85cef commit 9138c32
Showing 3 changed files with 36 additions and 28 deletions.
6 changes: 3 additions & 3 deletions vllm/engine/output_processor/single_step.py
@@ -109,9 +109,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             # We reuse the parent sequence here to reduce redundant memory
             # copies, especially when using non-beam search sampling methods.
             last_child_sample = child_samples[-1]
-            parent.append_token_id(last_child_sample.output_token,
-                                   last_child_sample.logprobs,
-                                   last_child_sample.output_classification_probs)
+            parent.append_token_id(
+                last_child_sample.output_token, last_child_sample.logprobs,
+                last_child_sample.output_classification_probs)
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
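For orientation, a minimal runnable sketch of the parent-reuse pattern this hunk touches, under simplified assumptions (`Sample`, `Seq`, `fork`, and `process` are hypothetical stand-ins, not vLLM's types): every sampled child except the last gets a copy of the parent sequence, while the last child appends into the parent object itself, which is the redundant-copy saving the diff's comment describes.

    from dataclasses import dataclass, field
    from typing import List, Tuple

    @dataclass
    class Sample:  # hypothetical stand-in for vLLM's SequenceOutput
        output_token: int
        logprob: float
        classification_prob: float

    @dataclass
    class Seq:  # hypothetical stand-in for vLLM's Sequence
        token_ids: List[int] = field(default_factory=list)

        def append_token_id(self, token: int, logprob: float,
                            classification_prob: float) -> None:
            self.token_ids.append(token)

    def fork(parent: Seq) -> Seq:
        # Copying the token list models the memory cost the real code avoids.
        return Seq(token_ids=list(parent.token_ids))

    def process(parent: Seq,
                child_samples: List[Sample]) -> List[Tuple[Seq, Seq]]:
        child_seqs = []
        # Every sample except the last gets its own copy of the parent ...
        for s in child_samples[:-1]:
            child = fork(parent)
            child.append_token_id(s.output_token, s.logprob,
                                  s.classification_prob)
            child_seqs.append((child, parent))
        # ... while the last sample appends into the parent in place.
        last = child_samples[-1]
        parent.append_token_id(last.output_token, last.logprob,
                               last.classification_prob)
        child_seqs.append((parent, parent))
        return child_seqs
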
25 changes: 17 additions & 8 deletions vllm/model_executor/layers/sampler.py
Expand Up @@ -48,7 +48,8 @@ def __init__(self):
self.include_gpu_probs_tensor = False

self.classification_head = torch.nn.Linear(1, 1, bias=False).to("cuda")
self.classification_head.weight.data = torch.load("classification_head.pth", map_location="cuda").bfloat16()
self.classification_head.weight.data = torch.load(
"classification_head.pth", map_location="cuda").bfloat16()

def forward(
self,
@@ -66,8 +67,7 @@ def forward(
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
         classification_probs = torch.nn.functional.sigmoid(
-            self.classification_head(logits)
-        ).flatten().tolist()
+            self.classification_head(logits)).flatten().tolist()
 
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
@@ -1018,20 +1018,29 @@ def _build_sampler_output(
     """
 
     sampler_output = []
-    for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs):
+    for (seq_group, sample_result, group_prompt_logprobs,
+         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
+                                       sample_results, prompt_logprobs,
+                                       sample_logprobs):
         seq_ids = seq_group.seq_ids
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
-        for parent_id, next_token_id, logprobs, sample_idx in zip(parent_ids, next_token_ids, group_sample_logprobs, seq_group.sample_indices):
-            seq_outputs.append(SequenceOutput(seq_ids[parent_id], next_token_id, logprobs, classification_probs[sample_idx]))
-        sampler_output.append(CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
+        for parent_id, next_token_id, logprobs, sample_idx in zip(
+                parent_ids, next_token_ids, group_sample_logprobs,
+                seq_group.sample_indices):
+            seq_outputs.append(
+                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
+                               classification_probs[sample_idx]))
+        sampler_output.append(
+            CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
 
     # If not specified, store None values in SamplerOutput.
     if on_device_tensors is not None:
         (sampled_token_probs, logprobs_tensor,
          sampled_token_ids) = on_device_tensors
     else:
-        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None)
+        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
+                                                                   None)
     return SamplerOutput(
         outputs=sampler_output,
         sampled_token_probs=sampled_token_probs,
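A minimal sketch of the classification-head pattern these hunks format, under stated assumptions: the `Linear(1, 1)` construction in the diff is only a placeholder whose weight is replaced by the loaded tensor, and the sketch assumes the weight saved in `classification_head.pth` has shape `[1, vocab_size]` so that applying it to logits yields one scalar per sampled token. The vocab size, batch shape, and CPU device below are illustrative, not from the source.

    # Minimal sketch, not the real module: a 1-output linear head over the
    # logits produces one classification probability per sampled token.
    import torch

    vocab_size = 32000  # illustrative assumption
    head = torch.nn.Linear(1, 1, bias=False)  # placeholder dims, as in the diff
    # Stand-in for torch.load("classification_head.pth", ...); assumed shape
    # [1, vocab_size] so the head maps [num_tokens, vocab_size] -> [num_tokens, 1].
    head.weight.data = torch.randn(1, vocab_size)

    logits = torch.randn(4, vocab_size)  # e.g. 4 sampled tokens this step
    classification_probs = torch.sigmoid(head(logits)).flatten().tolist()
    # classification_probs[i] is later read back as
    # classification_probs[sample_idx] in _build_sampler_output.
    assert len(classification_probs) == 4
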
33 changes: 16 additions & 17 deletions vllm/sequence.py
@@ -131,9 +131,10 @@ def __init__(
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
 
-    def append_token_id(self, token_id: int, logprob: float, classification_probs: List[float]) -> None:
+    def append_token_id(self, token_id: int, logprob: float,
+                        classification_prob: float) -> None:
         self.output_token_ids.append(token_id)
-        self.output_classification_probs.append(classification_probs)
+        self.output_classification_probs.append(classification_prob)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
@@ -237,7 +238,7 @@ def __init__(
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
         self.data = SequenceData(self.prompt_token_ids)
-        self.output_classification_probs = []
+        self.output_classification_probs: List[float] = []
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
@@ -319,13 +320,14 @@ def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
-        classification_probs: List[float],
+        classification_prob: float,
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.output_classification_probs.append(classification_probs)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob, classification_probs)
+        self.output_classification_probs.append(classification_prob)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob,
+                                  classification_prob)
 
     def get_len(self) -> int:
         return self.data.get_len()
@@ -719,23 +721,20 @@ class SequenceOutput:
         (Token id -> logP(x_i+1 | x_0, ..., x_i))
     """
 
-    def __init__(
-        self,
-        parent_seq_id: int,
-        output_token: int,
-        logprobs: Dict[int, Logprob],
-        classification_probs: List[float]
-    ) -> None:
+    def __init__(self, parent_seq_id: int, output_token: int,
+                 logprobs: Dict[int, Logprob],
+                 classification_probs: List[float]) -> None:
         self.parent_seq_id = parent_seq_id
         self.output_classification_probs = classification_probs
         self.output_token = output_token
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
-        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
-                f"output_classification_probs={self.output_classification_probs}, "
-                f"output_token={self.output_token}, "
-                f"logprobs={self.logprobs})")
+        return (
+            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
+            f"output_classification_probs={self.output_classification_probs}, "
+            f"output_token={self.output_token}, "
+            f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
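To make the signature change in this file concrete, a small self-contained sketch (`TinySequenceData` is a hypothetical stand-in, not vLLM's `SequenceData`): `append_token_id` now takes one `classification_prob: float` per generated token instead of a `List[float]`, so `output_classification_probs` stays a flat, per-token list.

    # Hypothetical, simplified stand-in for vllm.sequence.SequenceData,
    # illustrating the List[float] -> float change in append_token_id.
    from typing import List

    class TinySequenceData:
        def __init__(self) -> None:
            self.output_token_ids: List[int] = []
            self.output_classification_probs: List[float] = []
            self.cumulative_logprob = 0.0

        def append_token_id(self, token_id: int, logprob: float,
                            classification_prob: float) -> None:
            self.output_token_ids.append(token_id)
            self.output_classification_probs.append(classification_prob)
            self.cumulative_logprob += logprob

    data = TinySequenceData()
    data.append_token_id(42, -0.7, 0.93)
    data.append_token_id(7, -1.2, 0.11)
    assert data.output_classification_probs == [0.93, 0.11]  # one per token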
