implement inference webui using mlora
mikecovlee committed Jan 5, 2024
1 parent a667ee1 commit b767497
Showing 3 changed files with 44 additions and 156 deletions.
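
For orientation before the diff: this commit swaps the transformers/PEFT loading path for m-LoRA's own APIs. Pieced together from the calls visible in the diff below, the non-streaming flow looks roughly like the following sketch. The model and adapter paths are placeholders, the sampling values mirror the new UI defaults, and only the API and keyword names (e.g. adapter_name_, max_gen_len) are taken verbatim from the diff.

# Minimal sketch of the mlora-based inference path added here (not the shipped script).
import mlora

base_model = "/path/to/llama"            # placeholder
lora_weights = "/path/to/lora-adapter"   # placeholder

# Load base model and tokenizer through m-LoRA (bits=8 corresponds to --load_8bit).
model = mlora.LlamaModel.from_pretrained(base_model, device="cuda:0", bits=None)
tokenizer = mlora.Tokenizer(base_model, device="cuda:0")

# Attach the adapter and fetch its generation config; without an adapter the
# script falls back to mlora.GenerateConfig(adapter_name_="m-LoRA").
model.load_adapter_weight(lora_weights, "m-LoRA")
generation_config = model.get_generate_paramas()["m-LoRA"]  # spelling as in the diff
generation_config.prompt_template_ = None  # when unset, the input field is dropped (see mlora/generate.py below)

# Build the prompt list and generate; results come back keyed by adapter name.
generation_config.prompts_ = [
    generation_config.generate_prompt("Could you provide an introduction to m-LoRA?", None)
]
output = mlora.generate(
    llm_model=model,
    tokenizer=tokenizer,
    configs=[generation_config],
    temperature=0.2,
    top_p=0.9,
    max_gen_len=128,
)
print(output["m-LoRA"][0])
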
191 changes: 37 additions & 154 deletions inference.py
@@ -1,28 +1,10 @@
import sys
import fire
import json
import torch
import mlora
import traceback
import transformers

import gradio as gr
import os.path as osp

from queue import Queue
from typing import Union
from peft import PeftModel
from threading import Thread
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer


class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func

def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False


class Iteratorize:
@@ -40,10 +22,10 @@ def __init__(self, func, kwargs={}, callback=None):
self.kwargs = kwargs
self.stop_now = False

def _callback(val):
def _callback(seq_pos, output):
if self.stop_now:
raise ValueError
self.q.put(val)
self.q.put(output["m-LoRA"][0])

def gentask():
try:
@@ -78,138 +60,60 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_now = True


class Prompter(object):
__slots__ = ("template", "_verbose")

def __init__(self, template_name: str = "", verbose: bool = False):
self._verbose = verbose
if not template_name:
# Enforce the default here, so the constructor can be called with '' and will not break.
template_name = "alpaca"
file_name = osp.join("template", f"{template_name}.json")
if not osp.exists(file_name):
raise ValueError(f"Can't read {file_name}")
with open(file_name) as fp:
self.template = json.load(fp)
if self._verbose:
print(
f"Using prompt template {template_name}: {self.template['description']}"
)

def generate_prompt(
self,
instruction: str,
input: Union[None, str] = None,
label: Union[None, str] = None,
) -> str:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.template["prompt"].format(
instruction=instruction, input=input, output=""
)
else:
res = self.template["prompt_no_input"].format(
instruction=instruction, output=""
)
if label:
res = f"{res}{label}"
if self._verbose:
print(res)
return res

def get_response(self, output: str) -> str:
if self.template["response_split"] in output:
return output.split(self.template["response_split"])[1].strip()
else:
return ""


def main(
load_8bit: bool = False,
base_model: str = "",
lora_weights: str = "",
prompt_template: str = "alpaca",
target_device: str = "cuda:0",
prompt_template: str = None,
server_name: str = "0.0.0.0",
share_gradio: bool = False,
):
assert (
base_model
), "Please specify a --base_model"

assert (
lora_weights
), "Please specify a --lora_weights"

prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="cuda:0",
)
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
model = mlora.LlamaModel.from_pretrained(
base_model, device=target_device, bits=8 if load_8bit else None)
tokenizer = mlora.Tokenizer(base_model, device=target_device)

# unwind broken decapoda-research config
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
if lora_weights:
model.load_adapter_weight(lora_weights, "m-LoRA")
generation_config = model.get_generate_paramas()["m-LoRA"]
else:
generation_config = mlora.GenerateConfig(adapter_name_="m-LoRA")

if not load_8bit:
model.half() # seems to fix bugs for some users.

model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
generation_config.prompt_template_ = prompt_template

def evaluate(
instruction,
input=None,
input="",
temperature=0.1,
top_p=0.75,
top_k=40,
num_beams=4,
max_new_tokens=128,
stream_output=False,
**kwargs,
):
prompt = prompter.generate_prompt(instruction, input)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to("cuda:0")
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
**kwargs,
)
input = input.strip()
if len(input) == 0:
input = None

generation_config.prompts_ = [
generation_config.generate_prompt(instruction, input)]

generate_params = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": max_new_tokens,
"llm_model": model,
"tokenizer": tokenizer,
"configs": [generation_config],
"temperature": temperature,
"top_p": top_p,
"max_gen_len": max_new_tokens
}

if stream_output:
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator.

def generate_with_callback(callback=None, **kwargs):
kwargs.setdefault(
"stopping_criteria", transformers.StoppingCriteriaList()
)
kwargs["stopping_criteria"].append(
Stream(callback_func=callback)
)
with torch.no_grad():
model.generate(**kwargs)
mlora.generate(stream_callback=callback, **kwargs)

def generate_with_streaming(**kwargs):
return Iteratorize(
@@ -218,65 +122,44 @@ def generate_with_streaming(**kwargs):

with generate_with_streaming(**generate_params) as generator:
for output in generator:
# new_tokens = len(output) - len(input_ids[0])
decoded_output = tokenizer.decode(output)

if output[-1] in [tokenizer.eos_token_id]:
break

yield prompter.get_response(decoded_output)
yield output
return # early return for stream_output

# Without streaming
with torch.no_grad():
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s)
yield prompter.get_response(output)
output = mlora.generate(**generate_params)
yield output["m-LoRA"][0]

gr.Interface(
fn=evaluate,
inputs=[
gr.components.Textbox(
lines=2,
label="Instruction",
placeholder="Tell me about Sichuan University.",
placeholder="Could you provide an introduction to m-LoRA?",
),
gr.components.Textbox(lines=2, label="Input", placeholder="none"),
gr.components.Slider(
minimum=0, maximum=1, value=1, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=1, value=0.9, label="Top p"
minimum=0, maximum=1, value=0.2, label="Temperature"
),
gr.components.Slider(
minimum=0, maximum=100, step=1, value=40, label="Top k"
minimum=0, maximum=1, value=0.9, label="Top-p"
),
gr.components.Slider(
minimum=1, maximum=4, step=1, value=4, label="Beams"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=512, label="Max tokens"
minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
),
gr.components.Checkbox(
label="Stream output", value=True
),
],
outputs=[
gr.inputs.Textbox(
gr.components.Textbox(
lines=5,
label="Output",
)
],
title="m-LoRA LLM Evaluator",
description="Evaluate basic LLaMA model and LoRA weights", # noqa: E501
).queue().launch(server_name="0.0.0.0", share=share_gradio)
).queue().launch(server_name=server_name, share=share_gradio)


if __name__ == "__main__":
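Stepping back from inference.py: streaming now relies on mlora.generate invoking a stream_callback(seq_pos, output) with the text decoded so far, keyed by adapter name, and on the Iteratorize helper turning that callback into a generator for Gradio. The sketch below compresses the same pattern with the Gradio layer removed; the callback signature and the output["m-LoRA"][0] access follow the diff, while the helper name and the sentinel-based shutdown (instead of Iteratorize's exception-based stop) are illustrative only.

# Sketch: turning m-LoRA's stream_callback into a plain Python generator.
# Assumes model, tokenizer and generation_config were built as in the sketch above.
from queue import Queue
from threading import Thread

import mlora

def stream_generate(model, tokenizer, generation_config, max_gen_len=128):
    q = Queue()
    done = object()  # sentinel marking the end of generation

    def callback(seq_pos, output):
        # m-LoRA reports partial decodings keyed by adapter name.
        q.put(output["m-LoRA"][0])

    def worker():
        mlora.generate(
            llm_model=model,
            tokenizer=tokenizer,
            configs=[generation_config],
            temperature=0.2,
            top_p=0.9,
            max_gen_len=max_gen_len,
            stream_callback=callback,
        )
        q.put(done)

    Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not done:
        yield item  # each item is the full text generated so far

# for partial in stream_generate(model, tokenizer, generation_config):
#     print(partial)
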
5 changes: 4 additions & 1 deletion mlora/generate.py
@@ -5,6 +5,7 @@

from dataclasses import dataclass
from typing import List
import logging
import torch


@@ -24,7 +25,9 @@ def generate_prompt(self, instruction: str, input: str = None) -> str:
return instruction

if self.prompter_ is None:
assert self.prompt_template_ is not None
if self.prompt_template_ is None:
logging.warn("Drop input when prompt template is not set.")
return instruction
self.prompter_ = Prompter(self.prompt_template_)

return self.prompter_.generate_prompt(instruction=instruction, input=input)
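The change above replaces a hard assertion with a fallback: when no prompt template is configured, generate_prompt now logs a warning and returns the bare instruction, dropping the optional input instead of crashing. A rough illustration, assuming GenerateConfig is constructed the same way the webui does:

# Illustration of the new fallback in GenerateConfig.generate_prompt (a sketch).
import mlora

config = mlora.GenerateConfig(adapter_name_="m-LoRA")

# No prompt_template_ set: before this commit an assertion fired; now it logs
# "Drop input when prompt template is not set." and returns the instruction.
print(config.generate_prompt("Summarize m-LoRA.", input="some extra context"))
# -> "Summarize m-LoRA."

# With a template configured, a Prompter is built lazily and formats
# instruction + input together (how the template value is resolved is an
# assumption not shown in this hunk).
# config.prompt_template_ = "alpaca"
# print(config.generate_prompt("Summarize m-LoRA.", input="some extra context"))
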
4 changes: 3 additions & 1 deletion requirements.txt
@@ -14,4 +14,6 @@ jieba
rouge
rouge_chinese
flask
peft
peft
gradio==3.50
fire
