Skip to content

Commit

Permalink
tokenizer add mask_from function
Browse files Browse the repository at this point in the history
  • Loading branch information
yezhengmao1 committed Jan 21, 2024
1 parent dc6c701 commit 6fc92a6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
5 changes: 1 addition & 4 deletions mlora/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,14 @@ def get_train_data(self) -> MultiLoraBatchData:
lora_config = ilora_conf
pad_side = lora_config.get("expand_side", "right")
assert pad_side == "right" or pad_side == "left"
mask = [False] * len(tokens)
# pad the tokens to align
while len(tokens) < batch_seq_len:
if pad_side == "right":
tokens.append(self.tokenizer_.pad_id_)
mask += [True]
else:
tokens.insert(0, self.tokenizer_.pad_id_)
mask = [True] + mask
batch_tokens.append(tokens)
additional_mask.append(mask)
additional_mask.append(self.tokenizer_.mask_from(tokens))

lora_batch_data_config.append(LoraBatchDataConfig(adapter_name_=adapter,
batch_start_idx_=adapter_start_idx,
Expand Down
12 changes: 9 additions & 3 deletions mlora/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from transformers import AutoTokenizer
from typing import List
from mlora.modelargs import Tokens, Masks

Tokens = List[int]
from transformers import AutoTokenizer


class Tokenizer:
Expand All @@ -27,3 +26,10 @@ def encode(self, data: str, bos: bool, eos: bool) -> Tokens:

def decode(self, data: Tokens) -> str:
    """Decode a sequence of token ids back into a string.

    Delegates to the wrapped HuggingFace tokenizer's ``decode``.
    NOTE(review): attribute is ``self.tokenizer`` here, while sibling
    code uses the trailing-underscore convention (``pad_id_``,
    ``tokenizer_`` in dispatcher.py) — confirm which attribute name
    ``__init__`` actually sets.
    """
    return self.tokenizer.decode(data)

# get the mask from tokens
# example: tokens is [2, 3, pad, pad, 4, 5]
# output is [False, False, True, True, False, False]
def mask_from(self, tokens: Tokens) -> Masks:
mask_tokens = [self.pad_id_]
return [tok in mask_tokens for tok in tokens]

0 comments on commit 6fc92a6

Please sign in to comment.