diff --git a/hfppl/llms.py b/hfppl/llms.py index 5b56064..05ba845 100644 --- a/hfppl/llms.py +++ b/hfppl/llms.py @@ -37,7 +37,7 @@ def __init__(self, lm): self.MAX_TOKEN_LENGTH = self.precompute_token_length_masks(lm) - def precompute_token_length_masks(self, lm) -> Dict[int, Set[int]]: + def precompute_token_length_masks(self, lm): """Precompute masks for tokens of different lengths. Each mask is a set of token ids that are of the given length or shorter."""