
Commit

Merge pull request #18 from coreweave/rwang.padtoken052324
feat: add default padding token
rtalaricw authored May 31, 2024
2 parents bef6e0d + c4d661e commit b0e1976
Showing 6 changed files with 169 additions and 43 deletions.
15 changes: 15 additions & 0 deletions gpt_bpe.go
@@ -22,6 +22,7 @@ import (
const BPE_LRU_SZ = 65536
const RUNEBUF_SZ = 16384
const WORDCHAN_SZ = 4096
const defaultPadTokenString = "[PAD]"

type Token uint16
type Tokens []Token
@@ -366,6 +367,20 @@ func NewEncoder(vocabId string) (*GPTEncoder, error) {
tokenizerSpecialConfig.AddBosToken = true
tokenizerSpecialConfig.AddEosToken = true
}

// Add in default pad token if not already set
padTokenNotFound := (tokenizerSpecialConfig.PadToken == "" && hfConfig.PadTokenStr == nil)
if padTokenNotFound {
// Inject the default pad token into the encoder at uint16 max;
// throw an error if the vocab is already larger than uint16 max
if len(encoderTokens) >= math.MaxUint16 {
log.Fatalf("Vocab size is larger than uint16 max, default pad token cannot be added. " +
"Please specify a pad token in the vocab file.")
}
encoderTokens[defaultPadTokenString] = math.MaxUint16
tokenizerSpecialConfig.PadToken = defaultPadTokenString
hfConfig.PadTokenStr = &tokenizerSpecialConfig.PadToken
}
encoder := &GPTEncoder{
encoderTokens,
tokensEncoder,
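For context, a short usage sketch of the new fallback (illustrative only; the import path and vocab id below are assumptions, not part of this change — the repository's actual module path and tokenizer ids may differ):

package main

import (
    "fmt"
    "log"

    gpt_bpe "github.com/wbrown/gpt_bpe" // assumed module path
)

func main() {
    // Load a tokenizer whose vocab does not define a pad token
    // (e.g. a Llama-style vocab). With this change the encoder
    // falls back to the default "[PAD]" token at id 65535 (uint16 max).
    encoder, err := gpt_bpe.NewEncoder("llama-tokenizer") // assumed vocab id
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(encoder.PadToken)         // 65535
    fmt.Println(encoder.Encoder["[PAD]"]) // 65535
}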
42 changes: 42 additions & 0 deletions gpt_bpe_test.go
@@ -1196,6 +1196,48 @@ func TestModelDownloadLlama(t *testing.T) {
fmt.Println("All Exists - Looks good.")
}

func TestGPT2DefaultPadding(t *testing.T) {
// GPT-2 defines a padding token; we test that it is resolved correctly
// and corresponds to <|padding|> in the vocab
assert.Equal(t, gpt2Encoder.PadToken, Token(50257))
assert.Equal(t, gpt2Encoder.Encoder["<|padding|>"], Token(50257))
}

func TestPilePadding(t *testing.T) {
// Pile defines a padding token; we test that it is resolved correctly
// and corresponds to <|padding|> in the vocab
assert.Equal(t, pileEncoder.PadToken, Token(1))
assert.Equal(t, pileEncoder.Encoder["<|padding|>"], Token(1))
}

func TestClipPadding(t *testing.T) {
// CLIP defines a padding token; we test that it is resolved correctly
// and corresponds to <|endoftext|> in the vocab
assert.Equal(t, clipEncoder.PadToken, Token(49407))
assert.Equal(t, clipEncoder.Encoder["<|endoftext|>"], Token(49407))
}

func TestNerdstashPadding(t *testing.T) {
// Nerdstash defines a padding token; we test that it is resolved correctly
// and corresponds to <|pad|> in the vocab
assert.Equal(t, nerdstashV2Encoder.PadToken, Token(0))
assert.Equal(t, nerdstashV2Encoder.Encoder["<|pad|>"], Token(0))
}

func TestLlamaPadding(t *testing.T) {
// Llama doesn't define a padding token; we test that it properly defaults
// to [PAD] at id 65535
assert.Equal(t, llama2Encoder.PadToken, Token(65535))
assert.Equal(t, llama2Encoder.Encoder["[PAD]"], Token(65535))
}

func TestMistralPadding(t *testing.T) {
// Mistral doesn't define a padding token; we test that it properly defaults
// to [PAD] at id 65535
assert.Equal(t, mistralEncoder.PadToken, Token(65535))
assert.Equal(t, mistralEncoder.Encoder["[PAD]"], Token(65535))
}

func TestModelDownloadFairseq(t *testing.T) {
// Koboldai's fairseq models are stored in a different format
// it has merges and vocab but no tokenizer.json
3 changes: 1 addition & 2 deletions resources/data/llama-tokenizer/special_tokens_map.json
@@ -13,12 +13,11 @@
"rstrip": false,
"single_word": false
},
"pad_token": "[PAD]",
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}
}
80 changes: 71 additions & 9 deletions resources/resolver.go
@@ -13,6 +13,7 @@ import (
"path"
"regexp"
"strconv"
"strings"
"time"

"github.com/dustin/go-humanize"
@@ -800,11 +801,32 @@ func ResolveConfig(vocabId string, token string) (config *HFConfig,
// Given a set of resources, resolve the HuggingFace configuration.
// Used to be able to resolve both embedded and local resources.
func ResolveHFFromResources(resources *Resources, hfConfig *HFConfig) (*HFConfig, error) {
//use interfaces to unmarsal the config file and tokenizer config file
// Resolve config and tokenizer config from resources
// config.json and tokenizer_config.json
hfConfig, err := resolveConfigAndTokenizerConfig(resources, hfConfig)
if err != nil {
return nil, err
}

// Resolve special tokens and special tokens config from resources
// special_tokens_map.json and specials.txt
hfConfig, err = resolveSpecialsAndSpecialTokens(resources, hfConfig)
if err != nil {
return nil, err
}
return hfConfig, nil
}

// resolveConfigAndTokenizerConfig
// Resolve config and tokenizer config from resources.
// Used to be able to resolve both embedded and local resources.
// Continuation of ResolveHFFromResources.
func resolveConfigAndTokenizerConfig(resources *Resources, hfConfig *HFConfig) (*HFConfig, error) {
// Use interfaces to unmarshal the config file and tokenizer config file
var config interface{}
var tokenizerConfig interface{}
//if exists, unmarshal config.json and tokenizer_config.json
//use getfile to get the file, then unmarshal it
// If exists, unmarshal config.json and tokenizer_config.json, else
// use GetFile to get the file, then unmarshal it
if _, err := resources.GetFile("config.json"); err == nil {
if err := json.Unmarshal(*((*resources)["config.json"]).Data, &config); err != nil {
fmt.Errorf("Error unmarshalling config.json: %s", err)
@@ -824,12 +846,13 @@ func ResolveHFFromResources(resources *Resources, hfConfig *HFConfig) (*HFConfig

}

//check if bos_token is in string, this is the old format pythia has. If not, try to unmarshal to the tokenizerSpecials
// Check if bos_token is a plain string; this is the old format Pythia uses.
// If not, try to unmarshal to the tokenizerSpecials format
// that Llama 2 uses, else try the Mistral format
if config != nil || tokenizerConfig != nil {
hasReadConfig := false
if config != nil {
//using interfaces, first check if bos_token is in string format
// Using interfaces, first check if bos_token is in string format
if bosToken, ok := config.(map[string]interface{})["bos_token"].(string); ok {
hfConfig.BosTokenStr = &bosToken
if eosToken, ok := config.(map[string]interface{})["eos_token"].(string); ok {
@@ -842,7 +865,7 @@ func ResolveHFFromResources(resources *Resources, hfConfig *HFConfig) (*HFConfig
}
}
if tokenizerConfig != nil && !hasReadConfig {
//using interfaces, first check if bos_token is in string format
// Using interfaces, first check if bos_token is in string format
if bosToken, ok := tokenizerConfig.(map[string]interface{})["bos_token"].(string); ok {
hfConfig.BosTokenStr = &bosToken
if eosToken, ok := tokenizerConfig.(map[string]interface{})["eos_token"].(string); ok {
Expand All @@ -854,7 +877,7 @@ func ResolveHFFromResources(resources *Resources, hfConfig *HFConfig) (*HFConfig
hasReadConfig = true

}
//if not, assume llama2 format and try to unmarshal
// If not, assume llama2 format and try to unmarshal
if !hasReadConfig {
cfg := tokenizerConfig.(map[string]interface{})
if bosToken, ok := cfg["bos_token"].(map[string]interface{}); ok {
@@ -871,7 +894,7 @@ func ResolveHFFromResources(resources *Resources, hfConfig *HFConfig) (*HFConfig
hfConfig.PadTokenStr = &padToken
}
}
//if that doesn't work, assume mistral format
// If that doesn't work, assume mistral format
if !hasReadConfig {
if bosToken, ok := tokenizerConfig.(map[string]interface{})["bos_token"].(string); ok {
hfConfig.BosTokenStr = &bosToken
@@ -889,6 +912,46 @@ func ResolveHFFromResources(resources *Resources, hfConfig *HFConfig) (*HFConfig
return hfConfig, nil
}

// resolveSpecialsAndSpecialTokens
// Resolve special tokens and special tokens config from resources.
// Used to be able to resolve both embedded and local resources.
// Continuation of ResolveHFFromResources.
func resolveSpecialsAndSpecialTokens(resources *Resources, hfConfig *HFConfig) (*HFConfig, error) {
// Get specials config from resources
// We can only generate specials.json if we have special_tokens_map
specialsJson, ok := (*resources)["special_tokens_map.json"]
if ok {
specialTokens := make(map[string]interface{}, 0)
if specialErr := json.Unmarshal(*specialsJson.Data,
&specialTokens); specialErr != nil {
return nil, specialErr
}

// Try to get pad token from specials if not already set
if hfConfig.PadTokenStr == nil {
if padToken, ok := specialTokens["pad_token"].(string); ok {
hfConfig.PadTokenStr = &padToken
}
}
}

// Otherwise, fall back to specials.txt
specialsTxt, ok := (*resources)["specials.txt"]
if ok {
// Treat specials.txt as an array of strings and try to match
specials := strings.Split(string(*specialsTxt.Data), "\n")
if hfConfig.PadTokenStr == nil {
for _, special := range specials {
if strings.Contains(strings.ToLower(special), "pad") {
hfConfig.PadTokenStr = &special
break
}
}
}
}
return hfConfig, nil
}

// ResolveVocabId
// Resolves a vocabulary id to a set of resources, from embedded,
// local filesystem, or remote.
@@ -902,7 +965,6 @@ func ResolveVocabId(vocabId string, token string) (*HFConfig, *Resources, error)
ModelId: &vocabId,
BosTokenStr: &bosText,
EosTokenStr: &endOfText,
PadTokenStr: &endOfText,
}
resources := make(Resources, 0)

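A minimal standalone sketch of the specials.txt fallback in resolveSpecialsAndSpecialTokens above, using hypothetical file contents; the heuristic simply picks the first special token whose lowercased form contains "pad":

package main

import (
    "fmt"
    "strings"
)

func main() {
    // Hypothetical specials.txt contents: one special token per line.
    data := "<|endoftext|>\n<|pad|>\n<|mask|>"

    // Same heuristic as resolveSpecialsAndSpecialTokens: take the first
    // special token whose lowercased form contains "pad".
    var padTokenStr *string
    for _, special := range strings.Split(data, "\n") {
        special := special // copy before taking its address
        if strings.Contains(strings.ToLower(special), "pad") {
            padTokenStr = &special
            break
        }
    }
    if padTokenStr != nil {
        fmt.Println(*padTokenStr) // prints: <|pad|>
    }
}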
71 changes: 39 additions & 32 deletions resources/resource_data_js.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions resources/resources.go
@@ -25,6 +25,7 @@ import (
//go:embed data/clip-tokenizer/unitrim.json
//go:embed data/clip-tokenizer/specials.txt
//go:embed data/clip-tokenizer/special_config.json
//go:embed data/clip-tokenizer/special_tokens_map.json
//go:embed data/nerdstash_v1-tokenizer/encoder.json
//go:embed data/nerdstash_v1-tokenizer/merges.json
//go:embed data/nerdstash_v1-tokenizer/specials.txt
