Skip to content

Commit

Permalink
removing debug mode
Browse files Browse the repository at this point in the history
  • Loading branch information
SwayamInSync committed Jun 22, 2024
1 parent ee03a7a commit ec9e64e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "tunex"
-version = "0.1.1"
+version = "0.1.2"
description = "A powerful CLI toolkit offering one shot solution for LLM in running, finetuning and instruction tuning"
readme = "README.md"
license = { text = "Apache-2.0" }
Expand Down
9 changes: 4 additions & 5 deletions tunex/utils/convert_from_hf_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def gpt2_checkpointing(state_dict: Dict[str, torch.Tensor], hf_weights) -> None:
"h.{}.mlp.c_proj.bias": "transformer.h.{}.mlp.c_proj.bias",
}

-transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+transposed = ['attn.c_attn.weight', 'attn.c_proj.weight',
+'mlp.c_fc.weight', 'mlp.c_proj.weight']

for name, param in hf_weights.items():
-print(name)
if "h." in name:
from_name, number = get_layer_pos(name, 1)
to_name = weight_map[from_name].format(number)
Expand All @@ -66,7 +66,8 @@ def convert_and_save_hf_checkpoint(checkpoint_dir: Path, model_name: str) -> Non
copy_fn = partial(gpt2_checkpointing)

if copy_fn is None:
-raise ValueError(f"No conversion function corresponding to {model_name} is found")
+raise ValueError(
+f"No conversion function corresponding to {model_name} is found")

state_dict = {}

Expand All @@ -90,5 +91,3 @@ def convert_and_save_hf_checkpoint(checkpoint_dir: Path, model_name: str) -> Non

# Save the model state dictionary
torch.save(state_dict, (checkpoint_dir / "tunex_model.pth").as_posix())


0 comments on commit ec9e64e

Please sign in to comment.