Optimize compute graph for dynamic params
leng-yue committed May 2, 2024
1 parent 1441f50 commit 813868f
Showing 3 changed files with 27 additions and 38 deletions.
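This commit removes top-k sampling entirely and passes the remaining sampling parameters (temperature, top_p, repetition_penalty) into the compiled decode path as 0-dim device tensors instead of Python scalars. torch.compile bakes Python numbers into the traced graph as constants and guards on their values, so every new slider setting used to force a re-trace; tensor arguments are guarded only on shape, dtype, and device, so their values can change freely. A minimal sketch of the difference (illustrative only, not code from this repository):

```python
import torch

@torch.compile
def scale(logits, temperature):
    return logits / temperature

logits = torch.randn(8)

# Python floats are specialized: each distinct value compiles a fresh graph.
scale(logits, 0.7)
scale(logits, 0.8)  # re-traces

# 0-dim tensors are graph inputs: only metadata is guarded, not the value.
scale(logits, torch.tensor(0.7))
scale(logits, torch.tensor(0.8))  # reuses the compiled graph
```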
tools/api.py: 3 changes (0 additions & 3 deletions)
@@ -67,7 +67,6 @@ class InvokeRequest(BaseModel):
     reference_audio: Optional[str] = None
     max_new_tokens: int = 0
     chunk_length: int = 30
-    top_k: int = 0
     top_p: float = 0.7
     repetition_penalty: float = 1.5
     temperature: float = 0.7
@@ -104,7 +103,6 @@ def inference(req: InvokeRequest):
         device=vqgan_model.device,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
-        top_k=int(req.top_k) if req.top_k > 0 else None,
         top_p=req.top_p,
         repetition_penalty=req.repetition_penalty,
         temperature=req.temperature,
@@ -281,7 +279,6 @@ def parse_args():
         reference_audio=None,
         max_new_tokens=0,
         chunk_length=30,
-        top_k=0,
         top_p=0.7,
         repetition_penalty=1.5,
         temperature=0.7,
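For API clients, the visible effect of the api.py change is that top_k drops out of the request schema; everything else is unchanged. A hedged example payload (the field names and defaults come from InvokeRequest above; the URL and endpoint path are placeholders, not taken from this diff):

```python
import requests

payload = {
    "text": "Hello, world.",
    "reference_audio": None,
    "max_new_tokens": 0,
    "chunk_length": 30,
    "top_p": 0.7,
    "repetition_penalty": 1.5,
    "temperature": 0.7,
    # "top_k" is no longer part of the schema after this commit
}
# Placeholder URL: point this at wherever tools/api.py is actually served.
response = requests.post("http://127.0.0.1:8000/v1/invoke", json=payload)
```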
tools/llama/generate.py: 54 changes (27 additions & 27 deletions)
@@ -42,38 +42,31 @@ def multinomial_sample_one_no_sync(
 def logits_to_probs(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    top_p: Optional[int] = None,
-    repetition_penalty: float = 1.0,
-):
-    if previous_tokens is not None and repetition_penalty != 1.0:
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
         previous_tokens = previous_tokens.long()
         score = torch.gather(logits, dim=0, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
         logits.scatter_(dim=0, index=previous_tokens, src=score)
 
-    if top_p is not None and top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
-        )
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
 
     logits = logits / max(temperature, 1e-5)
 
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
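Two things changed above besides the type hints: the data-dependent branches (the None/value checks on top_p and top_k) are gone, so the same static graph runs regardless of parameter values, and top-p masking is now applied unconditionally. Standalone, the masking behaves like this on a toy distribution (a sketch mirroring the lines above, not repository code):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
top_p = torch.tensor(0.7)

# softmax(logits) is roughly [0.61, 0.22, 0.14, 0.03]
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# cum_probs is roughly [0.61, 0.83, 0.97, 1.00]: everything past 0.7 is cut.
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False  # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
    dim=0, index=sorted_indices, src=sorted_indices_to_remove
)
masked = logits.masked_fill(indices_to_remove, -float("Inf"))
print(torch.softmax(masked, dim=-1))  # tensor([1., 0., 0., 0.])
```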

@@ -449,7 +442,6 @@ def generate_long(
     text: str,
     num_samples: int = 1,
     max_new_tokens: int = 0,
-    top_k: int = None,
     top_p: int = 0.7,
     repetition_penalty: float = 1.5,
     temperature: float = 0.7,
@@ -462,6 +454,10 @@
     prompt_tokens: Optional[torch.Tensor] = None,
     is_streaming: bool = False,
 ):
+    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
+    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
+    assert 0 < temperature < 2, "temperature must be in (0, 2)"
+
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
@@ -493,6 +489,14 @@
         )
         logger.info(f"Encoded text: {text}")
 
+    # Move temperature, top_p, repetition_penalty to device
+    # This is important so that changing params doesn't trigger recompile
+    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
+    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
+    repetition_penalty = torch.tensor(
+        repetition_penalty, device=device, dtype=torch.float
+    )
+
     for sample_idx in range(num_samples):
         if torch.cuda.is_available():
             torch.cuda.synchronize()
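The comment in this hunk is the heart of the commit: because the compiled decode step now receives these values as device tensors, changing a value between requests reuses the existing graph instead of recompiling. One way to verify that behavior (a sketch assuming the private torch._dynamo.config.error_on_recompile flag, which recent PyTorch 2.x builds expose; it raises if anything re-traces):

```python
import torch
import torch._dynamo

torch._dynamo.config.error_on_recompile = True  # raise instead of re-tracing

@torch.compile
def decode_step(logits, temperature):
    return torch.softmax(logits / torch.clamp(temperature, min=1e-5), dim=-1)

logits = torch.randn(32)
decode_step(logits, torch.tensor(0.7))
decode_step(logits, torch.tensor(1.3))  # same metadata, new value: no error
```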
@@ -542,7 +546,6 @@
                 im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
-                top_k=top_k,
                 top_p=top_p,
                 repetition_penalty=repetition_penalty,
             )
@@ -660,7 +663,6 @@ def worker():
 )
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-k", type=int, default=None)
 @click.option("--top-p", type=float, default=0.7)
 @click.option("--repetition-penalty", type=float, default=1.5)
 @click.option("--temperature", type=float, default=0.7)
@@ -684,7 +686,6 @@ def main(
     prompt_tokens: Optional[Path],
     num_samples: int,
     max_new_tokens: int,
-    top_k: int,
     top_p: int,
     repetition_penalty: float,
     temperature: float,
@@ -733,7 +734,6 @@ def main(
         text=text,
         num_samples=num_samples,
         max_new_tokens=max_new_tokens,
-        top_k=top_k,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         temperature=temperature,
tools/webui.py: 8 changes (0 additions & 8 deletions)
@@ -66,7 +66,6 @@ def inference(
     reference_text,
     max_new_tokens,
     chunk_length,
-    top_k,
     top_p,
     repetition_penalty,
     temperature,
@@ -107,7 +106,6 @@ def inference(
         device=vqgan_model.device,
         max_new_tokens=max_new_tokens,
         text=text,
-        top_k=int(top_k) if top_k > 0 else None,
         top_p=top_p,
         repetition_penalty=repetition_penalty,
         temperature=temperature,
@@ -193,10 +191,6 @@ def build_app():
                             step=8,
                         )
 
-                        top_k = gr.Slider(
-                            label="Top-K", minimum=0, maximum=100, value=0, step=1
-                        )
-
                         top_p = gr.Slider(
                             label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
                         )
@@ -266,7 +260,6 @@ def build_app():
                     reference_text,
                     max_new_tokens,
                     chunk_length,
-                    top_k,
                     top_p,
                     repetition_penalty,
                     temperature,
@@ -337,7 +330,6 @@ def parse_args():
         reference_text="",
         max_new_tokens=0,
         chunk_length=0,
-        top_k=0,  # 0 means no limit
         top_p=0.7,
         repetition_penalty=1.5,
         temperature=0.7,
