Commit

fix prompt
mikecovlee committed Jan 5, 2024
1 parent b767497 commit b0bc516
Showing 3 changed files with 16 additions and 11 deletions.
3 changes: 1 addition & 2 deletions generate.py
@@ -33,8 +33,7 @@
 generate_paramas = mlora.GenerateConfig(adapter_name_=adapter_name)

 generate_paramas.prompt_template_ = args.template
-generate_paramas.prompts_ = [
-    generate_paramas.generate_prompt(args.instruction, args.input)]
+generate_paramas.prompts_ = [(args.instruction, args.input)]

 output = mlora.generate(model, tokenizer, [generate_paramas],
                         temperature=0.5, top_p=0.9, max_gen_len=128,
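
Note: with this change the caller no longer renders the prompt itself; prompts_ takes the raw (instruction, input) pair and GenerateConfig does the formatting. A minimal sketch of the new call-site style, assuming an adapter named "lora_0" and an Alpaca-style template file (both placeholders, not part of this commit):

import mlora

# Sketch only: the adapter name and template path are assumptions.
config = mlora.GenerateConfig(adapter_name_="lora_0")
config.prompt_template_ = "template/alpaca.json"  # assumed template file
config.prompts_ = [
    ("Summarize the passage.", "LoRA adds trainable low-rank matrices."),  # (instruction, input)
    "What is LoRA?",                                                       # plain instruction, no input
]

# The config is then passed to mlora.generate() exactly as in generate.py above:
# output = mlora.generate(model, tokenizer, [config],
#                         temperature=0.5, top_p=0.9, max_gen_len=128, ...)
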
3 changes: 1 addition & 2 deletions inference.py
@@ -97,8 +97,7 @@ def evaluate(
     if len(input) == 0:
         input = None

-    generation_config.prompts_ = [
-        generation_config.generate_prompt(instruction, input)]
+    generation_config.prompts_ = [(instruction, input)]

     generate_params = {
         "llm_model": model,
21 changes: 14 additions & 7 deletions mlora/generate.py
@@ -3,16 +3,16 @@
 from mlora.model import LLMModel, KVCache
 from mlora.utils import Prompter

+from typing import List, Union, Tuple
 from dataclasses import dataclass
-from typing import List
 import logging
 import torch


 @dataclass
 class GenerateConfig:
     adapter_name_: str = None
-    prompts_: List[str] = None
+    prompts_: List[Union[str, Tuple[str, str]]] = None
     prompt_template_: str = None
     # Do not set these manually
     batch_start_idx_: int = -1
@@ -21,17 +21,24 @@ class GenerateConfig:

     # Set prompt_template_ to enable the prompter
     def generate_prompt(self, instruction: str, input: str = None) -> str:
-        if input is None and self.prompt_template_ is None:
+        if self.prompt_template_ is None:
+            if input is not None:
+                logging.warn("Drop input when prompt template is not set.")
             return instruction

         if self.prompter_ is None:
-            if self.prompt_template_ is None:
-                logging.warn("Drop input when prompt template is not set.")
-                return instruction
             self.prompter_ = Prompter(self.prompt_template_)

         return self.prompter_.generate_prompt(instruction=instruction, input=input)

+    def get_prompts(self) -> List[str]:
+        prompts = []
+        for prompt in self.prompts_:
+            args = prompt if isinstance(prompt, Tuple) else (prompt, None)
+            prompts.append(self.generate_prompt(*args))
+
+        return prompts
+
     def get_response(self, output: str) -> str:
         if self.prompter_ is None:
             return output.strip()
@@ -87,7 +94,7 @@ def generate(llm_model: LLMModel,
     batch_data_config: List[LoraBatchDataConfig] = []
     for config in configs:
         tokens = [tokenizer.encode(prompt, True, False)
-                  for prompt in config.generate_prompt(instruction=config.prompts_)]
+                  for prompt in config.get_prompts()]
         config.batch_start_idx_ = len(raw_prompts)
         config.batch_end_idx_ = config.batch_start_idx_ + len(tokens)
         batch_data_config.append(LoraBatchDataConfig(
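
Note: get_prompts() accepts both plain strings and (instruction, input) tuples, and when prompt_template_ is unset it drops the input part with a warning instead of formatting it. A rough illustration of that behavior, assuming the mlora package is importable and using a made-up template path:

import mlora

# No template set: the tuple's input is dropped and a "Drop input ..." warning is logged.
bare = mlora.GenerateConfig(adapter_name_="lora_0")
bare.prompts_ = [("Explain LoRA.", "some extra context"), "Hello"]
print(bare.get_prompts())       # expected: ['Explain LoRA.', 'Hello']

# Template set: both forms are rendered through Prompter; the exact output
# depends on the template file (the path here is an assumption).
templated = mlora.GenerateConfig(adapter_name_="lora_0")
templated.prompt_template_ = "template/alpaca.json"
templated.prompts_ = [("Explain LoRA.", "some extra context")]
print(templated.get_prompts())  # expected: one fully formatted prompt string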

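For reference, the last hunk tokenizes each config's prompts via get_prompts() and records the half-open slice [batch_start_idx_, batch_end_idx_) that the config occupies in the shared batch. The same bookkeeping in isolation (plain Python sketch, not the mlora API):

from typing import List, Tuple

def assign_batch_ranges(prompt_groups: List[List[str]]) -> List[Tuple[int, int]]:
    """Mirrors the index bookkeeping in mlora/generate.py's generate()."""
    ranges = []
    batch: List[str] = []
    for prompts in prompt_groups:
        start = len(batch)            # config.batch_start_idx_
        end = start + len(prompts)    # config.batch_end_idx_
        ranges.append((start, end))
        batch.extend(prompts)         # this config's prompts join the shared batch
    return ranges

print(assign_batch_ranges([["p0", "p1"], ["p2"]]))   # -> [(0, 2), (2, 3)]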