This code is written in Python. Dependencies include:
- spacy (python -m spacy download en_core_web_sm, python -m spacy download de_core_news_sm)
- torchtext==0.4 (or 0.6)
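A typical environment setup, assuming pip and an already-installed PyTorch (the torch version is not pinned in this README), might look like:

pip install spacy torchtext==0.4
python -m spacy download en_core_web_sm
python -m spacy download de_core_news_sm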
We conduct experiments on binary classification datasets (IMDB and SST-2) and multi-class classification datasets (SST-5 and YELP-5).
preprocess.py handles the preprocessing step before training the model.
Example command lines:
python preprocess.py -data_task [MT / CF] -data_dir [wmt16 / imdb / yelp5 / sst2 / sst5] -data_ext csv -data_pkl [pickleName.pickle]
Arguments are as follows:
- data_task: MT is for machine translation (De→En) and CF is for classification (default: CF)
- data_dir: directory of the dataset
- data_ext: extension of the dataset (default: csv)
- data_pkl: file name of the preprocessed data (pickle file)
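For example, to preprocess the IMDB data for classification (the pickle file name below is an illustrative choice, not one prescribed by the repo):

python preprocess.py -data_task CF -data_dir imdb -data_ext csv -data_pkl imdb.pickle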
main.py is for model training and inference.
Example command lines:
python main.py -gpu 1 -option [BASE / LR / CT] -task [TRAIN / TEST] -data_task [MT / CF] -data_pkl [pickleName.pickle] -model_save [modelName.pt] -pred_save [predictionName.txt]
Arguments are as follows:
- gpu: GPU number
- option: BASE is for the vanilla Transformer, LR is for low-rank attention (Linformer), and CT is for TopAttn (our proposed method) (default: CT)
- task: TRAIN is for training and TEST is for inference
- data_task: MT is for machine translation and CF is for classification (default: CF)
- data_pkl: file name of the preprocessed data
- model_save: file name of the best model
- pred_save: file name of the prediction result (used for the machine translation task)
Additional Arguments are as follows:
- batch_size: batch size (default: 16)
- num_epoch: # of epochs (default: 8)
- learning_rate: learning rate (default: 1e-4)
- num_warmup: # of steps for warmup (default: 4000)
- hidden_dim: hidden dimension (default: 512)
- n_layer: # of encoder and decoder layer (default: 6)
- n_head: # of heads (for multi-head attention) (default: 8)
- ff_dim: dimension of feed-forward neural network (default: 2048)
- dropout: dropout ratio (default: 0.1)
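Any of these can be overridden on the command line; for instance, training with a larger batch size and a smaller learning rate (the values and file names here are illustrative only):

python main.py -gpu 0 -option CT -task TRAIN -data_task CF -data_pkl imdb.pickle -model_save topattn_imdb.pt -batch_size 32 -learning_rate 5e-5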
We report the following experiments:
1) Performance comparison of different token pruning ratios
2) Training memory with different token pruning ratios
3) Comparison with the vanilla Transformer on various datasets