Move basic Colab code out of notebooks
Tenoke committed Feb 18, 2020
1 parent a1fc25f commit 5a79f55
Showing 9 changed files with 310 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
__pycache__/
1 change: 1 addition & 0 deletions 256bytebpe-merges.txt
@@ -0,0 +1 @@
#version: 0.2 - Trained by `huggingface/tokenizers`
1 change: 1 addition & 0 deletions 256bytebpe-res-vocab.json
@@ -0,0 +1 @@
{"!": 256, "\"": 257, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "q": 80, "r": 81, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "\u00a1": 94, "\u00a2": 95, "\u00a3": 96, "\u00a4": 97, "\u00a5": 98, "\u00a6": 99, "\u00a7": 100, "\u00a8": 101, "\u00a9": 102, "\u00aa": 103, "\u00ab": 104, "\u00ac": 105, "\u00ae": 106, "\u00af": 107, "\u00b0": 108, "\u00b1": 109, "\u00b2": 110, "\u00b3": 111, "\u00b4": 112, "\u00b5": 113, "\u00b6": 114, "\u00b7": 115, "\u00b8": 116, "\u00b9": 117, "\u00ba": 118, "\u00bb": 119, "\u00bc": 120, "\u00bd": 121, "\u00be": 122, "\u00bf": 123, "\u00c0": 124, "\u00c1": 125, "\u00c2": 126, "\u00c3": 127, "\u00c4": 128, "\u00c5": 129, "\u00c6": 130, "\u00c7": 131, "\u00c8": 132, "\u00c9": 133, "\u00ca": 134, "\u00cb": 135, "\u00cc": 136, "\u00cd": 137, "\u00ce": 138, "\u00cf": 139, "\u00d0": 140, "\u00d1": 141, "\u00d2": 142, "\u00d3": 143, "\u00d4": 144, "\u00d5": 145, "\u00d6": 146, "\u00d7": 147, "\u00d8": 148, "\u00d9": 149, "\u00da": 150, "\u00db": 151, "\u00dc": 152, "\u00dd": 153, "\u00de": 154, "\u00df": 155, "\u00e0": 156, "\u00e1": 157, "\u00e2": 158, "\u00e3": 159, "\u00e4": 160, "\u00e5": 161, "\u00e6": 162, "\u00e7": 163, "\u00e8": 164, "\u00e9": 165, "\u00ea": 166, "\u00eb": 167, "\u00ec": 168, "\u00ed": 169, "\u00ee": 170, "\u00ef": 171, "\u00f0": 172, "\u00f1": 173, "\u00f2": 174, "\u00f3": 175, "\u00f4": 176, "\u00f5": 177, "\u00f6": 178, "\u00f7": 179, "\u00f8": 180, "\u00f9": 181, "\u00fa": 182, "\u00fb": 183, "\u00fc": 184, "\u00fd": 185, "\u00fe": 186, "\u00ff": 187, "\u0100": 188, "\u0101": 189, "\u0102": 190, "\u0103": 191, "\u0104": 192, "\u0105": 193, "\u0106": 194, "\u0107": 195, "\u0108": 196, "\u0109": 197, "\u010a": 198, "\u010b": 199, "\u010c": 200, "\u010d": 201, "\u010e": 202, "\u010f": 203, "\u0110": 204, "\u0111": 205, "\u0112": 206, "\u0113": 207, "\u0114": 208, "\u0115": 209, "\u0116": 210, "\u0117": 211, "\u0118": 212, "\u0119": 213, "\u011a": 214, "\u011b": 215, "\u011c": 216, "\u011d": 217, "\u011e": 218, "\u011f": 219, "\u0120": 220, "\u0121": 221, "\u0122": 222, "\u0123": 223, "\u0124": 224, "\u0125": 225, "\u0126": 226, "\u0127": 227, "\u0128": 228, "\u0129": 229, "\u012a": 230, "\u012b": 231, "\u012c": 232, "\u012d": 233, "\u012e": 234, "\u012f": 235, "\u0130": 236, "\u0131": 237, "\u0132": 238, "\u0133": 239, "\u0134": 240, "\u0135": 241, "\u0136": 242, "\u0137": 243, "\u0138": 244, "\u0139": 245, "\u013a": 246, "\u013b": 247, "\u013c": 248, "\u013d": 249, "\u013e": 250, "\u013f": 251, "\u0140": 252, "\u0141": 253, "\u0142": 254, "\u0143": 255}
3 changes: 3 additions & 0 deletions README.md
@@ -1,2 +1,5 @@
# trax-reformer
Training, Fine-tuning, Sampling and using Google's ReformerLM models


`%pip install --upgrade -r requirements-colab.txt`
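For reference, a typical end-to-end Colab session with the scripts added in this commit might look like the sketch below; the folder name `my_text_data` is a placeholder, and the flags match the argparse options in tokenize_data.py and reformer.py:

%pip install --upgrade -r requirements-colab.txt
!python tokenize_data.py --input_folder my_text_data --output_folder tokenized_data
!python reformer.py --data_folder tokenized_data --model_folder model --steps_per_epoch 100 --epochs 10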
108 changes: 108 additions & 0 deletions configs.py
@@ -0,0 +1,108 @@
train_config = """
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# attn_type = @TimeBinCausalAttention
attn_type = [
@TimeBinCausalAttention,
@TimeBinCausalAttention,
@LSHCausalAttention,
@TimeBinCausalAttention,
]
share_qk = False # LSHCausalAttention ignores this flag and always shares q & k
attn_kv = 128
n_layers = 12
dropout = 0.2
# MemoryEfficientCausalAttention: full attention
# (no hparams to vary between experiments)
# TimeBinCausalAttention: attend to nearby items
TimeBinCausalAttention.n_bins = 512
# LSHCausalAttention: locality-sensitive hashing (LSH) attention
LSHCausalAttention.n_bins = 256
LSHCausalAttention.n_buckets = 512 # Always 2 * n_bins
LSHCausalAttention.n_hashes = 2
LSHCausalAttention.drop_for_hash_rate = 0.0
# Parameters for MultifactorSchedule:
# ==============================================================================
# 0.03125 ~= 1024^-0.5 = d_model^-0.5
MultifactorSchedule.constant = 0.03125
MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay'
MultifactorSchedule.warmup_steps = 2000
# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate=0.0
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-9
# Parameters for MemoryEfficientCausalAttention:
# ==============================================================================
MemoryEfficientCausalAttention.dropout = 0.0
MemoryEfficientCausalAttention.loop_stride = 256
MemoryEfficientCausalAttention.share_qk = %share_qk
# Parameters for TimeBinCausalAttention:
# ==============================================================================
TimeBinCausalAttention.dropout = 0.2
# TimeBinCausalAttention.n_bins: see top
TimeBinCausalAttention.share_qk = %share_qk
# Parameters for LSHCausalAttention:
# ==============================================================================
LSHCausalAttention.allow_duplicate_attention = False
LSHCausalAttention.attend_across_buckets = True
LSHCausalAttention.rehash_each_round = True
# LSHCausalAttention.n_bins: see top
# LSHCausalAttention.n_buckets: see top
# LSHCausalAttention.n_hashes: see top
LSHCausalAttention.one_rng = False
LSHCausalAttention.hard_k = 0
LSHCausalAttention.dropout = 0.2
# LSHCausalAttention.drop_for_hash_rate: see top
# Parameters for TransformerLM:
# ==============================================================================
TransformerLM.attention_type = %attn_type
TransformerLM.d_attention_key = %attn_kv
TransformerLM.d_attention_value = %attn_kv
TransformerLM.d_model = 1024
TransformerLM.d_ff = 2048
TransformerLM.dropout = %dropout
TransformerLM.max_len = 65536
TransformerLM.mode = 'train'
TransformerLM.n_heads = 8
TransformerLM.n_layers = %n_layers
TransformerLM.share_qk = %share_qk
TransformerLM.vocab_size = 258 # Includes pad token and unused EOS token
# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 1024
ReformerLM.d_ff = 2048
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = 65536
ReformerLM.mode = 'train'
ReformerLM.n_heads = 8
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 258 # Includes pad token and unused EOS token
ReformerLM.share_qk = %share_qk
ReformerLM.axial_pos_shape = (128, 512)
ReformerLM.d_axial_pos_embs = (256, 768)
"""
98 changes: 98 additions & 0 deletions reformer.py
@@ -0,0 +1,98 @@
import argparse
import gin
import os
import jax
import trax
from trax.supervised import inputs

import numpy as onp
import jax.numpy as np
from scipy.special import softmax


import glob
import json
from tokenizers import ByteLevelBPETokenizer

from start_tpu import config
from configs import train_config

parser = argparse.ArgumentParser(
    description='Train a ReformerLM model on a folder of tokenized file(s)')

parser.add_argument('--data_folder', type=str, default='tokenized_data',
                    help='Data folder with 1 or more tokenized files')
parser.add_argument('--model_folder', type=str, default='model',
                    help='Folder for saving and loading the model')
parser.add_argument('--steps_per_epoch', type=int, default=100)
parser.add_argument('--epochs', type=int, default=10)
args = parser.parse_args()


def gen_inputs(n_devices):
    max_length = int(65536 * 0.98)  # always leave a little padding
    folder = args.data_folder
    files = glob.glob(f'{folder}/*.npy')
    print(f'first start from {len(files)} files')
    while True:
        file = onp.random.choice(files, 1)[0]
        data = onp.load(file, allow_pickle=True)
        print(f'processing from {file}, {len(data)} examples in file')
        max_picks = int((len(data) * 0.7) / n_devices)
        indices = onp.arange(len(data))
        picks = onp.random.choice(
            indices, (max_picks, n_devices), replace=False)
        for id_list in picks:
            inputs = []
            mask = []
            for id_ in id_list:
                IDS = data[id_]
                if len(IDS) > max_length:
                    rand_start = onp.random.randint(0, len(IDS) - max_length)
                    IDS = IDS[rand_start:rand_start + max_length]

                PAD_AMOUNT = 65536 - len(IDS)  # same as axial_pos_shape
                pad_start = onp.random.choice(PAD_AMOUNT)
                inputs.append(onp.pad(IDS, (pad_start, PAD_AMOUNT - pad_start),
                                      mode='constant'))
                mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                                    (pad_start, PAD_AMOUNT - pad_start),
                                    mode='constant'))
            inputs = onp.stack(inputs)
            mask = onp.stack(mask)
            # for i in range(100):
            yield (inputs, inputs, mask)


def gen_validation_inputs(n_devices):
    # draw one random validation batch at startup and reuse it for every eval
    ids = next(gen_inputs(n_devices))
    while True:
        yield ids


def create_fixed_training_schedule(lr=0.0001):
    # Yes, it does look unnecessarily nested just to pass a single float.
    def FixedTrainingSchedule(*args, **kwargs):
        def learning_rate(step):
            return {'learning_rate': np.asarray(lr, dtype=np.float32)}
        return learning_rate
    return FixedTrainingSchedule


def train():
    output_dir = os.path.expanduser(f'{args.model_folder}/')
    gin.parse_config(train_config)  # apply the model/optimizer bindings imported from configs.py
    trainer = trax.supervised.Trainer(
        model=trax.models.ReformerLM,
        loss_fn=trax.layers.CrossEntropyLoss,
        optimizer=trax.optimizers.Adam,
        lr_schedule=create_fixed_training_schedule(),
        # lr_schedule=trax.lr.MultifactorSchedule,
        inputs=trax.supervised.inputs.Inputs(gen_inputs, gen_validation_inputs),
        output_dir=output_dir,
        has_weights=True)

    for _ in range(args.epochs):
        trainer.train_epoch(n_steps=args.steps_per_epoch, n_eval_steps=1)


if __name__ == '__main__':
    train()
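As an illustrative sanity check (assuming the script's argparse has run and at least one tokenized .npy file exists under --data_folder), the training stream yields (inputs, targets, mask) triples padded to the fixed 65536-token context:

# e.g. appended temporarily below train(), or run after executing reformer.py as a script
inputs, targets, mask = next(gen_inputs(n_devices=1))
print(inputs.shape, targets.shape, mask.shape)  # each (1, 65536); inputs/targets are int32 ids, mask is float32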
4 changes: 4 additions & 0 deletions requirements-colab.txt
@@ -0,0 +1,4 @@
jax
jaxlib
git+https://github.com/google/trax
tokenizers
11 changes: 11 additions & 0 deletions start_tpu.py
@@ -0,0 +1,11 @@
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
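A quick, illustrative way to confirm the backend switch took effect in the same Colab session:

import start_tpu  # noqa: F401  (sets JAX's TPU driver flags as above)
import jax

print(jax.devices())  # should list TPU devices rather than a lone CPU device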
83 changes: 83 additions & 0 deletions tokenize_data.py
@@ -0,0 +1,83 @@
import argparse
import glob
import os
import json
import numpy as onp
from tokenizers import ByteLevelBPETokenizer

parser = argparse.ArgumentParser(
    description='Tokenize a folder of text file(s)')

parser.add_argument('--input_folder', type=str, required=True,
                    help='Input folder with 1 or more text files')
parser.add_argument('--output_folder', type=str, default='tokenized_data')
parser.add_argument('--files_start_with', type=str, default='',
                    help='Process only files starting with this string')
parser.add_argument('--remove_input', default=False, action='store_true',
                    help='Delete input file after tokenizing')

args = parser.parse_args()

input_files = glob.glob(f'{args.input_folder}/{args.files_start_with}*')
input_files = [x for x in input_files if os.path.isfile(x)]

print(input_files)

tokenizer = ByteLevelBPETokenizer(
    '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')


def encode_data(file, max_per_n=10000):
    folder = args.output_folder
    fname = os.path.basename(file)  # keep outputs flat inside the output folder
    with open(file, 'r') as f:
        # print(f.read())
        ids_n = 0
        largest_id = 0
        i = 0
        id_list = []
        for line in f:
            i += 1
            IDS = tokenizer.encode(line).ids
            IDS = onp.asarray(IDS, dtype=onp.int32)
            ids_n += len(IDS)
            largest_id = max(len(IDS), largest_id)
            print(largest_id)
            id_list.append(IDS)
            # print(id_list)
            if i > 0 and i % max_per_n == 0:
                # save every max_per_n lines
                onp.save(f'{folder}/{fname}-{i}', id_list)
                print(f'{i} processed lines')
                id_list = []
        print(dict(ids_n=ids_n, largest_id=largest_id, name=folder))
        if len(id_list) > 16:
            # only save the remainder if there's enough left for batching
            onp.save(f'{folder}/{fname}-{i}', id_list)
        with open(f'{folder}/{fname}-config', 'w') as out:
            json.dump(dict(ids_n=ids_n, largest_id=largest_id, name=folder), out)


if __name__ == '__main__':
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)

    for i, file in enumerate(input_files):
        print(file)
        encode_data(file)
        if args.remove_input:
            os.remove(file)

    # combine configs - not used for training, just for sanity checking
    files = glob.glob(f'{args.output_folder}/*-config')
    config = {'ids_n': 0, 'largest_id': 0, 'name': args.output_folder}
    for file in files:
        with open(file) as json_file:
            temp_conf = json.load(json_file)
        config['ids_n'] += temp_conf['ids_n']
        config['largest_id'] = max(
            temp_conf['largest_id'], config['largest_id'])

    with open(f'{args.output_folder}/config', 'w') as out:
        json.dump(config, out)
    print(config)
    print('Finished tokenizing data')
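For a minimal, illustrative round-trip check of the bundled byte-level BPE files (any short string will do):

from tokenizers import ByteLevelBPETokenizer

tokenizer = ByteLevelBPETokenizer(
    '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')
ids = tokenizer.encode('Hello world').ids
print(ids)                    # byte-level token ids (vocab_size is set to 258 in configs.py)
print(tokenizer.decode(ids))  # should reproduce 'Hello world'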
