Commit

fix arguments

mikecovlee committed Jan 5, 2024
1 parent b0bc516 commit ef283cb
Showing 2 changed files with 52 additions and 61 deletions.
80 changes: 37 additions & 43 deletions generate.py
@@ -1,45 +1,39 @@
import argparse
import mlora
import fire

# Command Line Arguments
parser = argparse.ArgumentParser(description='m-LoRA LLM generator')
parser.add_argument('--base_model', type=str, required=True,
                    help='Path to or name of base model')
parser.add_argument('--instruction', type=str, required=True,
                    help='Instruction of the prompt')
parser.add_argument('--input', type=str,
                    help='Input of the prompt')
parser.add_argument('--template', type=str,
                    help='Prompt template')
parser.add_argument('--load_8bit', action="store_true",
                    help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
                    help='Load model in 4bit mode')
parser.add_argument('--lora_weights', type=str,
                    help='Path to or name of LoRA weights')
parser.add_argument('--device', type=str, default='cuda:0',
                    help='Specify which GPU to be used, default is cuda:0')
args = parser.parse_args()

model = mlora.LlamaModel.from_pretrained(args.base_model, device=args.device,
                                         bits=(8 if args.load_8bit else (4 if args.load_4bit else None)))
tokenizer = mlora.Tokenizer(args.base_model, device=args.device)

if args.lora_weights:
    adapter_name = model.load_adapter_weight(args.lora_weights)
    generate_paramas = model.get_generate_paramas()[adapter_name]
else:
    adapter_name = args.base_model
    generate_paramas = mlora.GenerateConfig(adapter_name_=adapter_name)

generate_paramas.prompt_template_ = args.template
generate_paramas.prompts_ = [(args.instruction, args.input)]

output = mlora.generate(model, tokenizer, [generate_paramas],
                        temperature=0.5, top_p=0.9, max_gen_len=128,
                        device=args.device)

for prompt in output[adapter_name]:
    print(f"\n{'='*10}\n")
    print(prompt)
    print(f"\n{'='*10}\n")
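The commit replaces the hand-rolled argparse block above with a main() function dispatched through fire (new version below): fire.Fire(fn) inspects fn's signature and derives the command line from it, so parameters without defaults become required flags and keyword defaults become optional ones. A minimal standalone sketch of that pattern, using a hypothetical greet() function that is not part of the repository:

import fire

def greet(name: str, excited: bool = False):
    # fire derives the CLI from the signature:
    #   python demo.py --name Alice --excited
    print(f"Hello, {name}{'!' if excited else '.'}")

if __name__ == "__main__":
    fire.Fire(greet)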

def main(base_model: str,
         instruction: str,
         input: str = None,
         template: str = None,
         lora_weights: str = None,
         load_8bit: bool = False,
         load_4bit: bool = False,
         device: str = "cuda:0"):

    model = mlora.LlamaModel.from_pretrained(base_model, device=device,
                                             bits=(8 if load_8bit else (4 if load_4bit else None)))
    tokenizer = mlora.Tokenizer(base_model, device=device)

    if lora_weights:
        adapter_name = model.load_adapter_weight(lora_weights)
        generate_paramas = model.get_generate_paramas()[adapter_name]
    else:
        adapter_name = base_model
        generate_paramas = mlora.GenerateConfig(adapter_name_=adapter_name)

    generate_paramas.prompt_template_ = template
    generate_paramas.prompts_ = [(instruction, input)]

    output = mlora.generate(model, tokenizer, [generate_paramas],
                            temperature=0.5, top_p=0.9, max_gen_len=128,
                            device=device)

    for prompt in output[adapter_name]:
        print(f"\n{'='*10}\n")
        print(prompt)
        print(f"\n{'='*10}\n")


if __name__ == "__main__":
    fire.Fire(main)
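After the refactor, the generator can be driven from the shell or programmatically; note that bits=(8 if load_8bit else (4 if load_4bit else None)) gives --load_8bit precedence when both quantization flags are set. A hypothetical invocation, with placeholder model and adapter paths that are not part of the commit:

# Equivalent CLI, since fire maps each parameter of main() onto a flag:
#   python generate.py --base_model /path/to/llama-2-7b \
#       --instruction "Explain LoRA briefly." \
#       --lora_weights /path/to/adapter --load_4bit
import generate

generate.main(base_model="/path/to/llama-2-7b",    # placeholder path
              instruction="Explain LoRA briefly.",
              lora_weights="/path/to/adapter",     # placeholder adapter dir
              load_4bit=True)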
33 changes: 15 additions & 18 deletions inference.py
@@ -60,30 +60,26 @@ def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True


def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "",
    target_device: str = "cuda:0",
    prompt_template: str = None,
    server_name: str = "0.0.0.0",
    share_gradio: bool = False,
):
    assert (
        base_model
    ), "Please specify a --base_model"

    model = mlora.LlamaModel.from_pretrained(
        base_model, device=target_device, bits=8 if load_8bit else None)
    tokenizer = mlora.Tokenizer(base_model, device=target_device)
def main(base_model: str,
         template: str = None,
         lora_weights: str = "",
         load_8bit: bool = False,
         load_4bit: bool = False,
         device: str = "cuda:0",
         server_name: str = "0.0.0.0",
         share_gradio: bool = False):

    model = mlora.LlamaModel.from_pretrained(base_model, device=device,
                                             bits=(8 if load_8bit else (4 if load_4bit else None)))
    tokenizer = mlora.Tokenizer(base_model, device=device)

    if lora_weights:
        model.load_adapter_weight(lora_weights, "m-LoRA")
        generation_config = model.get_generate_paramas()["m-LoRA"]
    else:
        generation_config = mlora.GenerateConfig(adapter_name_="m-LoRA")

    generation_config.prompt_template_ = prompt_template
    generation_config.prompt_template_ = template

    def evaluate(
        instruction,
@@ -105,7 +101,8 @@ def evaluate(
"configs": [generation_config],
"temperature": temperature,
"top_p": top_p,
"max_gen_len": max_new_tokens
"max_gen_len": max_new_tokens,
"device": device
}

if stream_output:
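After this change the web UI follows the same calling convention as generate.py. Assuming inference.py is likewise dispatched through fire.Fire(main) (the tail of the file is truncated above), a hypothetical launch with placeholder paths looks like:

# Equivalent CLI (placeholder paths):
#   python inference.py --base_model /path/to/llama-2-7b \
#       --lora_weights /path/to/adapter --load_8bit --device cuda:0
import inference

inference.main(base_model="/path/to/llama-2-7b",  # placeholder path
               lora_weights="/path/to/adapter",   # placeholder adapter dir
               load_8bit=True,
               device="cuda:0",
               server_name="0.0.0.0",  # listen on all interfaces
               share_gradio=False)     # keep the gradio link local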
