train and serve, rough
daanelson committed Mar 16, 2023
1 parent c6432bf commit 61db2d7
Showing 2 changed files with 71 additions and 33 deletions.
predict.py: 48 additions & 33 deletions
@@ -1,42 +1,57 @@
+from typing import List, Optional
 from cog import BasePredictor, Input
 import os
+from transformers import LLaMAForCausalLM, LLaMATokenizer
+import torch
+
+CACHE_DIR = 'alpaca_out'
 
 class Predictor(BasePredictor):
+    def setup(self):
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.model = LLaMAForCausalLM.from_pretrained("weights/llama-7b", cache_dir=CACHE_DIR, local_files_only=True)
+        self.model.to(self.device)
+        self.tokenizer = LLaMATokenizer.from_pretrained("weights/tokenizer", cache_dir=CACHE_DIR, local_files_only=True)
+
     def predict(
         self,
-        model_path: str = Input(description="path to model"),
-        tokenizer_path: str = Input(description="path to tokenizer"),
-        data_path: str = Input(description="path to data", default='alpaca_data.json'),
-        output_path: str = Input(description="path to save model", default='alpaca_out')
-    ) -> int:
-        if not output_path.startswith('/src'):
-            output_path = os.path.join('src', output_path)
-        if not os.path.exists(output_path):
-            os.makedirs(output_path)
+        prompt: str = Input(description="Prompt to send to LLaMA."),
+        n: int = Input(description="Number of output sequences to generate", default=1, ge=1, le=5),
+        max_length: int = Input(
+            description="Maximum number of tokens to generate. A word is generally 2-3 tokens",
+            ge=1,
+            default=50
+        ),
+        temperature: float = Input(
+            description="Adjusts randomness of outputs; greater than 1 is random, 0 is deterministic, and 0.75 is a good starting value.",
+            ge=0.01,
+            le=5,
+            default=0.75,
+        ),
+        top_p: float = Input(
+            description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens",
+            ge=0.01,
+            le=1.0,
+            default=1.0
+        ),
+        repetition_penalty: float = Input(
+            description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
+            ge=0.01,
+            le=5,
+            default=1.0
+        )
+    ) -> List[str]:
+        input = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
 
-        command = f'''torchrun --nproc_per_node=4 --master_port=9292 train.py \
-            --model_name_or_path {model_path} \
-            --tokenizer_name_or_path {tokenizer_path} \
-            --data_path {data_path} \
-            --bf16 True \
-            --output_dir {output_path} \
-            --num_train_epochs 1 \
-            --per_device_train_batch_size 4 \
-            --per_device_eval_batch_size 4 \
-            --gradient_accumulation_steps 1 \
-            --evaluation_strategy "no" \
-            --save_strategy "steps" \
-            --save_steps 2000 \
-            --learning_rate 2e-5 \
-            --weight_decay 0. \
-            --warmup_ratio 0.03 \
-            --lr_scheduler_type "cosine" \
-            --logging_steps 1 \
-            --fsdp "full_shard auto_wrap" \
-            --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
-            --tf32 True '''
-        res = os.system(command)
-        return res
+        outputs = self.model.generate(
+            input,
+            num_return_sequences=n,
+            max_length=max_length,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty
+        )
+        out = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return out
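With this change, predict() serves generations instead of launching a training run. As a quick local sanity check (a sketch, not part of the commit), the updated predictor can be exercised directly from Python; it assumes the weights/llama-7b and weights/tokenizer directories that setup() expects are present, and setup() falls back to CPU when no GPU is available.

from predict import Predictor

# Load the model and tokenizer once; setup() picks 'cuda' when available.
predictor = Predictor()
predictor.setup()

# Pass every argument explicitly so cog's Input defaults are not relied on.
completions = predictor.predict(
    prompt="Tell me something interesting about alpacas.",
    n=1,
    max_length=50,
    temperature=0.75,
    top_p=1.0,
    repetition_penalty=1.0,
)
print(completions[0])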

train_model.sh: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+torchrun --nproc_per_node=4 --master_port=9292 train.py \
+    --model_name_or_path /src/weights/llama-7b \
+    --tokenizer_name_or_path /src/weights/tokenizer \
+    --data_path ./alpaca_data.json \
+    --bf16 True \
+    --output_dir alpaca_out \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 2000 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --fsdp "full_shard auto_wrap" \
+    --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
+    --tf32 True

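With --nproc_per_node=4, a per-device batch of 4, and --gradient_accumulation_steps 8, the effective batch size is 4 × 4 × 8 = 128 sequences per optimizer step. The script reads alpaca_data.json, which in the standard Alpaca release is a JSON list of records with "instruction", "input", and "output" fields; a minimal pre-flight check (a sketch, assuming that layout) could be:

import json

# Sanity-check the fine-tuning data before launching train_model.sh.
# Assumes the standard Alpaca layout: a JSON list of records with
# "instruction", "input", and "output" keys.
with open("alpaca_data.json") as f:
    examples = json.load(f)

print(f"{len(examples)} training examples")
required = {"instruction", "input", "output"}
bad = [i for i, ex in enumerate(examples) if not required <= ex.keys()]
print(f"{len(bad)} records missing required keys")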