
Commit fe741f4

refactor
yanbing-j committed Oct 8, 2024
1 parent f7f8298 commit fe741f4
Showing 1 changed file with 13 additions and 9 deletions.
generate.py (13 additions, 9 deletions)
@@ -216,6 +216,14 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
         tokens = [tokenizer.bos_id()] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
+def _convert_weight(model):
+    from quantize import WeightOnlyInt4Linear
+    for fqn, mod in model.named_modules():
+        if isinstance(mod, WeightOnlyInt4Linear):
+            weight = mod.weight.data
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+            mod.weight = weight_int4pack
+
 def _load_model(checkpoint_path, device, precision, use_tp):
     use_cuda = 'cuda' in device
     with torch.device('meta'):
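The new helper consolidates the int4 repacking loop that previously lived inline in _load_model. A minimal usage sketch, not part of the commit, assuming a quantized model whose WeightOnlyInt4Linear modules still hold unpacked weights:

    # Repack int4 weights in place once the model sits on its final device;
    # per the commit's own comment, the packed layout depends on that device.
    model = model.to(device="cuda")
    _convert_weight(model)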
@@ -240,19 +248,15 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
 
+    model = model.to(device=device, dtype=precision)
+    # int4 packed weight needs to be converted after model loading to the specific device
+    if "int4" in str(checkpoint_path):
+        _convert_weight(model)
+
     if use_tp:
         from tp import apply_tp
         print("Applying tensor parallel to model ...")
         apply_tp(model)
 
-    model = model.to(device=device, dtype=precision)
-    if "int4" in str(checkpoint_path):
-        from quantize import WeightOnlyInt4Linear
-        for fqn, mod in model.named_modules():
-            if isinstance(mod, WeightOnlyInt4Linear):
-                weight = mod.weight.data
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
-                mod.weight = weight_int4pack
-
     return model.eval()
 
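Net effect: _load_model now moves the model to the target device and repacks int4 weights immediately after loading, before tensor parallelism is applied rather than after it. A hedged usage sketch follows; the checkpoint path is illustrative, and the conversion branch is triggered by "int4" appearing in the path:

    from pathlib import Path
    import torch

    # Hypothetical int4 checkpoint path; any path containing "int4" takes the
    # packed-weight conversion branch inside _load_model.
    model = _load_model(Path("checkpoints/model_int4.pth"), "cuda", torch.bfloat16, use_tp=False)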

