
Commit
refactor the logging function
yezhengmao1 authored and mikecovlee committed Jan 22, 2024
1 parent babab38 commit 39c62de
Showing 3 changed files with 45 additions and 36 deletions.
67 changes: 39 additions & 28 deletions mlora.py
@@ -20,52 +20,45 @@
 import torch
 import mlora
 import random
-import datetime
+import logging
 import argparse
 from typing import Dict, Tuple, List
 
 # Command Line Arguments
 parser = argparse.ArgumentParser(description='m-LoRA main program')
 parser.add_argument('--base_model', type=str,
                     help='Path to or name of base model')
+parser.add_argument('--tokenizer', type=str,
+                    help='Path to or name of tokenizer')
 parser.add_argument('--model_type', type=str, default="llama",
                     help='The model type, support: llama, chatglm')
+parser.add_argument('--device', type=str, default='cuda:0',
+                    help='Specify which GPU to be used, default is cuda:0')
+
+parser.add_argument('--load_8bit', action="store_true",
+                    help='Load model in 8bit mode')
+parser.add_argument('--load_4bit', action="store_true",
+                    help='Load model in 4bit mode')
+
 parser.add_argument('--inference', action="store_true",
                     help='The inference mode (just for test)')
 
 parser.add_argument('--load_lora', action="store_true",
                     help="Load lora from file instead of init randomly")
 parser.add_argument('--disable_lora', action="store_true",
                     help="Disable the lora modules")
-parser.add_argument('--tokenizer', type=str,
-                    help='Path to or name of tokenizer')
-parser.add_argument('--load_8bit', action="store_true",
-                    help='Load model in 8bit mode')
-parser.add_argument('--load_4bit', action="store_true",
-                    help='Load model in 4bit mode')
-parser.add_argument('--device', type=str, default='cuda:0',
-                    help='Specify which GPU to be used, default is cuda:0')
 
 parser.add_argument('--config', type=str,
                     help='Path to finetune configuration')
 parser.add_argument('--seed', type=int, default=42,
                     help='Random seed in integer, default is 42')
-parser.add_argument('--log', type=bool, default=True,
-                    help='Turn on or off log, default is true')
-
-args = parser.parse_args()
-
-
-def log(msg: str):
-    if args.log:
-        print('[%s] m-LoRA: %s' %
-              (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))
+parser.add_argument('--log_level', type=str, default="INFO",
+                    help="Set the log level.")
+parser.add_argument('--log_file', type=str,
+                    help="Save log to specific file.")
 
-if torch.cuda.is_available():
-    log('NVIDIA CUDA initialized successfully.')
-    log('Total %i GPU(s) detected.' % torch.cuda.device_count())
-else:
-    print('m-LoRA requires NVIDIA CUDA computing capacity. Please check your PyTorch installation.')
-    exit(-1)
+args = parser.parse_args()
 
 
 if args.base_model is None:
@@ -93,14 +86,12 @@ def load_base_model(config: Dict[str, any]) -> Tuple[mlora.Tokenizer, mlora.LLMModel]:
             path=args.base_model,
             device=args.device,
-            bits=(8 if args.load_8bit else (4 if args.load_4bit else None)),
-            log_fn=log
+            bits=(8 if args.load_8bit else (4 if args.load_4bit else None))
         )
     elif args.model_type == "chatglm":
         model = mlora.ChatGLMModel.from_pretrained(
             path=args.base_model,
             device=args.device,
-            bits=(8 if args.load_8bit else (4 if args.load_4bit else None)),
-            log_fn=log
+            bits=(8 if args.load_8bit else (4 if args.load_4bit else None))
        )
     else:
         raise ValueError(f"unknown model type {args.model_type}")
@@ -275,8 +266,28 @@ def inference(config: Dict[str, any],
 
 # Main Function
 if __name__ == "__main__":
+    # set the random seed
     setup_seed(args.seed)
 
+    # set the logger
+    log_handlers = [logging.StreamHandler()]
+    if args.log_file is not None:
+        log_handlers.append(logging.FileHandler(args.log_file))
+
+    logging.basicConfig(format="[%(asctime)s] m-LoRA: %(message)s",
+                        level=args.log_level,
+                        handlers=log_handlers,
+                        force=True)
+
+    # check the environment
+    if torch.cuda.is_available():
+        logging.info('NVIDIA CUDA initialized successfully.')
+        logging.info('Total %i GPU(s) detected.' % torch.cuda.device_count())
+    else:
+        logging.error(
+            'm-LoRA requires NVIDIA CUDA computing capacity. Please check your PyTorch installation.')
+        exit(1)
+
     with open(args.config, 'r', encoding='utf8') as fp:
         config = json.load(fp)
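The net effect in mlora.py: the hand-rolled log() helper is replaced by the standard logging module, with the format string, level, and an optional file sink driven by the new --log_level and --log_file flags. For reference, the configuration can be exercised in isolation; a minimal sketch, assuming only the Python standard library (the 'mlora.log' filename stands in for an actual --log_file value):

import logging

# Mirror the handler setup from mlora.py: always log to the console,
# and optionally tee the same records into a file.
log_handlers = [logging.StreamHandler()]
log_handlers.append(logging.FileHandler('mlora.log'))  # as if --log_file were set

# force=True (Python 3.8+) discards any previously installed handlers,
# so this configuration takes effect even if an imported library
# already called basicConfig().
logging.basicConfig(format="[%(asctime)s] m-LoRA: %(message)s",
                    level="INFO",  # as if --log_level INFO were set
                    handlers=log_handlers,
                    force=True)

logging.info('Total %i GPU(s) detected.' % 2)
# prints: [2024-01-22 10:30:00,123] m-LoRA: Total 2 GPU(s) detected.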
7 changes: 3 additions & 4 deletions mlora/model_chatglm.py
@@ -4,6 +4,7 @@
 from mlora.model import apply_rotary_emb_to_one, repeat_kv, precompute_mask, precompute_rope_angle
 from mlora.LoraLiner import Linear
 
+import logging
 import torch
 import torch.nn.functional as F
 import xformers.ops
@@ -181,12 +182,10 @@ def from_pretrained(path: str,
                         fp16: bool = True,
                         bf16: bool = True,
                         double_quant: bool = True,
-                        quant_type: str = 'nf4',
-                        log_fn=None):
+                        quant_type: str = 'nf4'):
         # now only support the qlora - 4bit
         if bits in [4, 8]:
-            if log_fn is not None:
-                log_fn('Loading model with quantization, bits = %i' % bits)
+            logging.info('Loading model with quantization, bits = %i' % bits)
             from transformers import BitsAndBytesConfig
             compute_dtype = (torch.float16 if fp16 else (
                 torch.bfloat16 if bf16 else torch.float32))
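For context, the quantized branch hands these keyword arguments to bitsandbytes via transformers. The exact config object is outside this diff, so the following is only a sketch of the 4-bit case with an assumed mapping from from_pretrained()'s defaults:

import torch
from transformers import BitsAndBytesConfig

compute_dtype = torch.float16  # fp16=True takes precedence over bf16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                    # bits == 4
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,       # double_quant=True
    bnb_4bit_quant_type='nf4',            # quant_type='nf4'
)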
7 changes: 3 additions & 4 deletions mlora/model_llama.py
@@ -4,6 +4,7 @@
 from mlora.model import LLMModel, RMSNorm
 from mlora.LoraLiner import Linear
 
+import logging
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -239,11 +240,9 @@ def from_pretrained(path: str,
                         fp16: bool = True,
                         bf16: bool = True,
                         double_quant: bool = True,
-                        quant_type: str = 'nf4',
-                        log_fn=None) -> LLMModel:
+                        quant_type: str = 'nf4') -> LLMModel:
         if bits in [4, 8]:
-            if log_fn is not None:
-                log_fn('Loading model with quantization, bits = %i' % bits)
+            logging.info('Loading model with quantization, bits = %i' % bits)
             from transformers import BitsAndBytesConfig
             compute_dtype = (torch.float16 if fp16 else (
                 torch.bfloat16 if bf16 else torch.float32))
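After this change, callers configure logging once at startup and simply drop the log_fn argument; both loaders report quantization through the root logger themselves. A hypothetical call site, assuming the Llama loader is exposed as mlora.LlamaModel by analogy with mlora.ChatGLMModel above (the checkpoint path is a placeholder):

import logging
import mlora

logging.basicConfig(level="INFO", force=True)

# No log_fn callback anymore; the loader logs on its own.
model = mlora.LlamaModel.from_pretrained(
    path='/path/to/llama-7b',  # placeholder path
    device='cuda:0',
    bits=8,  # emits 'Loading model with quantization, bits = 8' via logging.info
)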
