diff --git a/tools/api.py b/tools/api.py
index 3672a7e8..68593bd9 100644
--- a/tools/api.py
+++ b/tools/api.py
@@ -1,5 +1,6 @@
 import base64
 import io
+import queue
 import threading
 import traceback
 from argparse import ArgumentParser
@@ -114,17 +115,26 @@ def inference(req: InvokeRequest):
     )
 
     payload = dict(
-        event=threading.Event(),
+        response_queue=queue.Queue(),
         request=request,
     )
     llama_queue.put(payload)
 
-    # Wait for the result
-    payload["event"].wait()
-    if payload["success"] is False:
-        raise payload["response"]
+    codes = []
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue
 
-    codes = payload["response"][0]
+        if result == "done":
+            if payload["success"] is False:
+                raise payload["response"]
+            break
+
+        codes.append(result)
+
+    codes = torch.cat(codes, dim=1)
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
diff --git a/tools/llama/generate.py b/tools/llama/generate.py
index 36b65962..82d04a45 100644
--- a/tools/llama/generate.py
+++ b/tools/llama/generate.py
@@ -470,16 +470,14 @@ def generate_long(
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
 
     if use_prompt:
-        encoded.append(
-            encode_tokens(
-                tokenizer,
-                prompt_text,
-                prompt_tokens=prompt_tokens,
-                bos=True,
-                device=device,
-                speaker=speaker,
-                num_codebooks=model.config.num_codebooks,
-            )
+        encoded_prompts = encode_tokens(
+            tokenizer,
+            prompt_text,
+            prompt_tokens=prompt_tokens,
+            bos=True,
+            device=device,
+            speaker=speaker,
+            num_codebooks=model.config.num_codebooks,
         )
 
     for idx, text in enumerate(texts):
@@ -501,10 +499,6 @@
         all_codes = []
         seg_idx = 0
 
-        if use_prompt:
-            seg_idx = 1
-            global_encoded.append(encoded[0])
-
         while seg_idx < len(encoded):
             logger.info(
                 f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
@@ -531,6 +525,9 @@
             else:
                 partial_encoded = global_encoded
 
+            if use_prompt:
+                partial_encoded = [encoded_prompts] + partial_encoded
+
             cat_encoded = torch.cat(partial_encoded, dim=1)
             prompt_length = cat_encoded.size(1)
 
@@ -593,7 +590,7 @@
 
         if is_streaming:
             # This indicates the end of the current sample
-            yield None
+            yield "next"
         else:
             all_codes = torch.cat(all_codes, dim=1)
             assert (all_codes >= 0).all(), f"Negative code found: {codes}"
@@ -623,20 +620,21 @@ def worker():
                 break
 
             kwargs = item["request"]
-            event = item["event"]
+            response_queue = item["response_queue"]
 
             try:
                 item["success"] = True
-                item["response"] = list(
-                    generate_long(
-                        model=model, decode_one_token=decode_one_token, **kwargs
-                    )
-                )
+                for chunk in generate_long(
+                    model=model, decode_one_token=decode_one_token, **kwargs
+                ):
+                    response_queue.put(chunk)
+
+                response_queue.put("done")
             except Exception as e:
                 item["success"] = False
                 item["response"] = e
 
-            event.set()
+            response_queue.put("done")
 
     threading.Thread(target=worker, daemon=True).start()
    init_event.wait()
diff --git a/tools/webui.py b/tools/webui.py
index 7c41d845..829dd7fd 100644
--- a/tools/webui.py
+++ b/tools/webui.py
@@ -1,6 +1,7 @@
 import gc
 import html
 import os
+import queue
 import threading
 from argparse import ArgumentParser
 from pathlib import Path
@@ -119,17 +120,26 @@ def inference(
     )
 
     payload = dict(
-        event=threading.Event(),
+        response_queue=queue.Queue(),
         request=request,
     )
     llama_queue.put(payload)
 
-    # Wait for the result
-    payload["event"].wait()
-    if payload["success"] is False:
-        raise payload["response"]
+    codes = []
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue
 
-    codes = payload["response"][0]
+        if result == "done":
+            if payload["success"] is False:
+                raise payload["response"]
+            break
+
+        codes.append(result)
+
+    codes = torch.cat(codes, dim=1)
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
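All three files above share one hand-off: the llama worker streams code chunks into a per-request queue.Queue instead of signalling a single threading.Event, pushing the string sentinel "next" at a sample boundary and "done" once generation finishes (or fails, with payload["success"] set to False and the exception stored in payload["response"]). Below is a minimal, self-contained sketch of that producer/consumer protocol; the worker(payload) helper and the nested lists standing in for the code tensors yielded by generate_long are illustrative only, not part of the patch.

import queue
import threading


def worker(payload):
    # Producer side (mirrors the llama worker): push chunks, then sentinels.
    q = payload["response_queue"]
    try:
        payload["success"] = True
        for sample in ([[1, 2, 3], [4, 5]], [[6, 7]]):  # two fake samples
            for chunk in sample:  # stand-ins for yielded code tensors
                q.put(chunk)
            q.put("next")  # end of the current sample (streaming mode)
    except Exception as e:
        payload["success"] = False
        payload["response"] = e
    q.put("done")  # always emitted last, on success or failure


payload = {"response_queue": queue.Queue()}
threading.Thread(target=worker, args=(payload,), daemon=True).start()

# Consumer side (mirrors tools/api.py and tools/webui.py): drain until "done".
codes = []
while True:
    result = payload["response_queue"].get()
    if result == "next":
        continue  # sample boundary; the patched callers skip it for now
    if result == "done":
        if payload["success"] is False:
            raise payload["response"]
        break
    codes.append(result)

print(codes)  # the real callers concatenate tensors instead: torch.cat(codes, dim=1)

The "next" marker is exactly what both inference functions currently skip (the "TODO: handle next sentence" comment); acting on it is what would turn the collected list into sentence-by-sentence streaming.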