Skip to content

Commit

Permalink
Merge branch 'master' into jomayeri/zero-inf-gds
Browse files Browse the repository at this point in the history
  • Loading branch information
jomayeri authored Sep 10, 2024
2 parents f7fd25e + a256c04 commit 028d940
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
10 changes: 7 additions & 3 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ def __len__(self):
def __getitem__(self, idx):
if self.train_phase == 1:
return {
"input_ids": self.chosen_dataset[idx]["input_ids"],
"attention_mask": self.chosen_dataset[idx]["attention_mask"],
"labels": self.chosen_dataset[idx]["input_ids"]
"input_ids":
self.chosen_dataset[idx]["input_ids"],
"attention_mask":
self.chosen_dataset[idx]["attention_mask"],
"labels":
torch.where(self.chosen_dataset[idx]["attention_mask"].bool(),
self.chosen_dataset[idx]["input_ids"], -100)
}
elif self.train_phase == 2:
return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def main():
args.seed,
tokenizer,
args.max_seq_len,
end_of_conversation_token=tokenizer.eos_token,
sft_only_data_path=args.sft_only_data_path)
# DataLoaders creation:
if args.local_rank == -1:
Expand Down

0 comments on commit 028d940

Please sign in to comment.