minor fixes
Guitaricet committed Aug 13, 2023
1 parent 0e7e9fe commit 3784fb4
Showing 2 changed files with 32 additions and 16 deletions.
run_glue.py: 29 changes (19 additions, 10 deletions)
@@ -46,6 +46,8 @@
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 
+from peft_pretraining.modeling_llama import LlamaForSequenceClassification
+
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.31.0")
@@ -168,6 +170,7 @@ class ModelArguments:
"""

model_name_or_path: str = field(
default=None,
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
@@ -216,6 +219,8 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    training_args.save_strategy = "no"
+
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     send_example_telemetry("run_glue", model_args, data_args)
@@ -370,15 +375,19 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
-    model = AutoModelForSequenceClassification.from_pretrained(
-        model_args.model_name_or_path,
-        from_tf=bool(".ckpt" in model_args.model_name_or_path),
-        config=config,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
-    )
+
+    if model_args.model_name_or_path is None:
+        model = LlamaForSequenceClassification(config)
+    else:
+        model = LlamaForSequenceClassification.from_pretrained(
+            model_args.model_name_or_path,
+            from_tf=bool(".ckpt" in model_args.model_name_or_path),
+            config=config,
+            cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+            ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
+        )
 
     # Preprocessing the raw_datasets
     if data_args.task_name is not None:
@@ -535,7 +544,7 @@ def compute_metrics(p: EvalPrediction):
         )
         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
 
-        trainer.save_model()  # Saves the tokenizer too for easy upload
+        # trainer.save_model()  # Saves the tokenizer too for easy upload
 
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
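The net effect of the run_glue.py change is that the GLUE script now builds the repository's LlamaForSequenceClassification either from the config alone (random initialization, when --model_name_or_path is omitted, which the new default=None permits) or from a saved checkpoint. Below is a minimal sketch of the same selection pattern, written against the stock transformers Llama classes rather than the repository's peft_pretraining.modeling_llama module; build_classifier and the config values are illustrative assumptions, not code from the repo.

from transformers import LlamaConfig, LlamaForSequenceClassification

def build_classifier(model_name_or_path=None, num_labels=2):
    # Illustrative small config; the actual script derives its config from the checkpoint/name.
    config = LlamaConfig(hidden_size=256, num_hidden_layers=4,
                         num_attention_heads=4, num_labels=num_labels)
    if model_name_or_path is None:
        # No checkpoint given: randomly initialized weights, e.g. as a lower-bound baseline.
        return LlamaForSequenceClassification(config)
    # Checkpoint given: load pretrained weights; the classification head is still newly initialized.
    return LlamaForSequenceClassification.from_pretrained(model_name_or_path, num_labels=num_labels)

model = build_classifier()  # random init when no path is supplied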
torchrun_main.py: 19 changes (13 additions, 6 deletions)
@@ -531,7 +531,11 @@ def main(args):
     if args.resume_from:
         logger.info(f"Loading model from {args.resume_from}")
         checkpoint_path = os.path.join(args.resume_from, "pytorch_model.bin")
-        model.wrapped_model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True)
+        if isinstance(model, ReLoRaModel):
+            model.wrapped_model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True)
+        else:
+            model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=True)
+
         logger.info(f"Model successfully loaded (strict=True policy)")
 
         logger.info(f"Loading training state like global_step, update_step, and tokens_seen from {args.warmed_up_model}")
@@ -791,10 +795,6 @@ def main(args):
                 save_dir=current_model_directory,
             )
 
-        # restart model after we modify the learning rate, so on the next step after the relora frequency
-        can_reset = args.resume_from is not None \
-            or (args.relora is not None and local_step * args.gradient_accumulation > args.relora)
-
         # ##############################
         # EVALUATION
         if update_step % args.eval_every == 0:
@@ -813,6 +813,13 @@

         # ##############################
         # MERGE AND REINIT
+
+        # restart model after we modify the learning rate, so on the next step after the relora frequency
+        can_reset = args.relora is not None and (
+            args.resume_from is not None
+            or local_step * args.gradient_accumulation > args.relora
+        )
+
         if can_reset and update_step % args.relora == 1:
             logger.info(f"Performing lora reset at update step {update_step}. Current lr is {optimizer.param_groups[0]['lr']}")
             n_lora_restarts += 1
@@ -899,7 +906,7 @@ def main(args):
             training_state_checkpoint=training_state_checkpoint,
             run_config=run_config,
             distributed_type=args.distributed_type,
-            save_dir=args.save_dir,
+            save_dir=current_model_directory,
         )
 
     # Final evaluation
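The torchrun_main.py changes make the checkpoint load work for both wrapped (ReLoRaModel) and plain models, and they move the reset-eligibility check after the evaluation block while guarding it so it can only be true when ReLoRA is actually enabled. With the old expression, args.resume_from alone made can_reset true even when args.relora was unset, and the subsequent update_step % args.relora test would then fail. Below is a minimal sketch of the new guard; the variable names follow the diff, but wrapping them in a standalone function is purely for illustration.

def can_reset_lora(relora, resume_from, local_step, gradient_accumulation):
    # No ReLoRA frequency configured: merge-and-reinit never applies.
    if relora is None:
        return False
    # Eligible once we resumed from a checkpoint or accumulated enough local steps.
    return resume_from is not None or local_step * gradient_accumulation > relora

# With relora=5000 and no resume, eligibility starts once
# local_step * gradient_accumulation exceeds 5000; the reset itself
# still fires only when update_step % relora == 1.
assert can_reset_lora(None, "checkpoints/run1", 10_000, 1) is False
assert can_reset_lora(5000, None, 4_000, 1) is False
assert can_reset_lora(5000, None, 6_000, 1) is True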
