Add queue to support streaming
leng-yue committed May 1, 2024
1 parent 89e2aa9 commit dcbe986
Showing 3 changed files with 52 additions and 34 deletions.
22 changes: 16 additions & 6 deletions tools/api.py
@@ -1,5 +1,6 @@
import base64
import io
import queue
import threading
import traceback
from argparse import ArgumentParser
@@ -114,17 +115,26 @@ def inference(req: InvokeRequest):
)

payload = dict(
event=threading.Event(),
response_queue=queue.Queue(),
request=request,
)
llama_queue.put(payload)

# Wait for the result
payload["event"].wait()
if payload["success"] is False:
raise payload["response"]
codes = []
while True:
result = payload["response_queue"].get()
if result == "next":
# TODO: handle next sentence
continue

codes = payload["response"][0]
if result == "done":
if payload["success"] is False:
raise payload["response"]
break

codes.append(result)

codes = torch.cat(codes, dim=1)

# VQGAN Inference
feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
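Both tools/api.py and (further down) tools/webui.py now drain a per-request queue.Queue instead of blocking on a threading.Event: code tensors are appended as they arrive, the "next" marker is skipped for now (per the TODO), and the "done" sentinel ends the loop, re-raising any exception the worker recorded. Below is a minimal, self-contained sketch of that hand-off, with a stand-in worker in place of the real llama queue; fake_worker, the integer chunks, and the single trailing put("done") are illustrative simplifications, not code from this commit.

import queue
import threading

# Stand-in producer: stream chunks, then signal completion.
def fake_worker(payload: dict) -> None:
    try:
        payload["success"] = True
        for chunk in (1, 2, 3):                  # stand-ins for code tensors
            payload["response_queue"].put(chunk)
    except Exception as e:
        payload["success"] = False
        payload["response"] = e
    payload["response_queue"].put("done")        # always wake the consumer

payload = dict(response_queue=queue.Queue())
threading.Thread(target=fake_worker, args=(payload,), daemon=True).start()

codes = []
while True:                                      # same shape as the loop above
    result = payload["response_queue"].get()
    if result == "next":                         # sentence boundary, unused here
        continue
    if result == "done":
        if payload.get("success") is False:
            raise payload["response"]
        break
    codes.append(result)

print(codes)                                     # [1, 2, 3]
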
42 changes: 20 additions & 22 deletions tools/llama/generate.py
@@ -470,16 +470,14 @@ def generate_long(
texts = split_text(text, chunk_length) if iterative_prompt else [text]

if use_prompt:
encoded.append(
encode_tokens(
tokenizer,
prompt_text,
prompt_tokens=prompt_tokens,
bos=True,
device=device,
speaker=speaker,
num_codebooks=model.config.num_codebooks,
)
encoded_prompts = encode_tokens(
tokenizer,
prompt_text,
prompt_tokens=prompt_tokens,
bos=True,
device=device,
speaker=speaker,
num_codebooks=model.config.num_codebooks,
)

for idx, text in enumerate(texts):
@@ -501,10 +499,6 @@ def generate_long(
all_codes = []
seg_idx = 0

if use_prompt:
seg_idx = 1
global_encoded.append(encoded[0])

while seg_idx < len(encoded):
logger.info(
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
@@ -531,6 +525,9 @@
else:
partial_encoded = global_encoded

if use_prompt:
partial_encoded = [encoded_prompts] + partial_encoded

cat_encoded = torch.cat(partial_encoded, dim=1)
prompt_length = cat_encoded.size(1)

@@ -593,7 +590,7 @@ def generate_long(

if is_streaming:
# This indicates the end of the current sample
yield None
yield "next"
else:
all_codes = torch.cat(all_codes, dim=1)
assert (all_codes >= 0).all(), f"Negative code found: {codes}"
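In streaming mode, generate_long now yields the string "next" at the end of each sample instead of None. The API and web UI consumers skip it for now (hence the TODO), but a caller could use it to split the stream per sample; group_by_sample below is a hypothetical helper, not part of this commit:

def group_by_sample(stream):
    # Collect streamed codes until a "next" marker closes the current sample.
    current = []
    for item in stream:
        if item == "next":
            if current:
                yield current
            current = []
        else:
            current.append(item)
    if current:
        yield current

# list(group_by_sample(["c1", "c2", "next", "c3", "next"]))
# -> [["c1", "c2"], ["c3"]]
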
@@ -623,20 +620,21 @@ def worker():
break

kwargs = item["request"]
event = item["event"]
response_queue = item["response_queue"]

try:
item["success"] = True
item["response"] = list(
generate_long(
model=model, decode_one_token=decode_one_token, **kwargs
)
)
for chunk in generate_long(
model=model, decode_one_token=decode_one_token, **kwargs
):
response_queue.put(chunk)

response_queue.put("done")
except Exception as e:
item["success"] = False
item["response"] = e

event.set()
response_queue.put("done")

threading.Thread(target=worker, daemon=True).start()
init_event.wait()
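On the producer side, worker() now forwards each chunk from generate_long into the response queue as soon as it is yielded, and queues the "done" sentinel on both the success path and the exception path; without that, the caller's blocking get() would hang forever after a failure. A small sketch of that guarantee, using made-up names (forward, boom) and a finally clause in place of the two explicit puts in the diff:

import queue

def forward(gen, response_queue, item):
    try:
        item["success"] = True
        for chunk in gen:
            response_queue.put(chunk)      # stream each result as it is ready
    except Exception as e:
        item["success"] = False
        item["response"] = e               # consumer re-raises this after "done"
    finally:
        response_queue.put("done")         # unblock the consumer in every case

def boom():
    yield "chunk-1"
    raise RuntimeError("generation failed")

item, q = {}, queue.Queue()
forward(boom(), q, item)
print(q.get(), q.get(), item["success"])   # chunk-1 done False
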
22 changes: 16 additions & 6 deletions tools/webui.py
@@ -1,6 +1,7 @@
import gc
import html
import os
import queue
import threading
from argparse import ArgumentParser
from pathlib import Path
@@ -119,17 +120,26 @@ def inference(
)

payload = dict(
event=threading.Event(),
response_queue=queue.Queue(),
request=request,
)
llama_queue.put(payload)

# Wait for the result
payload["event"].wait()
if payload["success"] is False:
raise payload["response"]
codes = []
while True:
result = payload["response_queue"].get()
if result == "next":
# TODO: handle next sentence
continue

codes = payload["response"][0]
if result == "done":
if payload["success"] is False:
raise payload["response"]
break

codes.append(result)

codes = torch.cat(codes, dim=1)

# VQGAN Inference
feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
