diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/256bytebpe-merges.txt b/256bytebpe-merges.txt new file mode 100644 index 0000000..0809d44 --- /dev/null +++ b/256bytebpe-merges.txt @@ -0,0 +1 @@ +#version: 0.2 - Trained by `huggingface/tokenizers` diff --git a/256bytebpe-res-vocab.json b/256bytebpe-res-vocab.json new file mode 100644 index 0000000..be7c6fb --- /dev/null +++ b/256bytebpe-res-vocab.json @@ -0,0 +1 @@ +{"!": 256, "\"": 257, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "q": 80, "r": 81, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "\u00a1": 94, "\u00a2": 95, "\u00a3": 96, "\u00a4": 97, "\u00a5": 98, "\u00a6": 99, "\u00a7": 100, "\u00a8": 101, "\u00a9": 102, "\u00aa": 103, "\u00ab": 104, "\u00ac": 105, "\u00ae": 106, "\u00af": 107, "\u00b0": 108, "\u00b1": 109, "\u00b2": 110, "\u00b3": 111, "\u00b4": 112, "\u00b5": 113, "\u00b6": 114, "\u00b7": 115, "\u00b8": 116, "\u00b9": 117, "\u00ba": 118, "\u00bb": 119, "\u00bc": 120, "\u00bd": 121, "\u00be": 122, "\u00bf": 123, "\u00c0": 124, "\u00c1": 125, "\u00c2": 126, "\u00c3": 127, "\u00c4": 128, "\u00c5": 129, "\u00c6": 130, "\u00c7": 131, "\u00c8": 132, "\u00c9": 133, "\u00ca": 134, "\u00cb": 135, "\u00cc": 136, "\u00cd": 137, "\u00ce": 138, "\u00cf": 139, "\u00d0": 140, "\u00d1": 141, "\u00d2": 142, "\u00d3": 143, "\u00d4": 144, "\u00d5": 145, "\u00d6": 146, "\u00d7": 147, "\u00d8": 148, "\u00d9": 149, "\u00da": 150, "\u00db": 151, "\u00dc": 152, "\u00dd": 153, "\u00de": 154, "\u00df": 155, "\u00e0": 156, "\u00e1": 157, "\u00e2": 158, "\u00e3": 159, "\u00e4": 160, "\u00e5": 161, "\u00e6": 162, "\u00e7": 163, "\u00e8": 164, "\u00e9": 165, "\u00ea": 166, "\u00eb": 167, "\u00ec": 168, "\u00ed": 169, "\u00ee": 170, "\u00ef": 171, "\u00f0": 172, "\u00f1": 173, "\u00f2": 174, "\u00f3": 175, "\u00f4": 176, "\u00f5": 177, "\u00f6": 178, "\u00f7": 179, "\u00f8": 180, "\u00f9": 181, "\u00fa": 182, "\u00fb": 183, "\u00fc": 184, "\u00fd": 185, "\u00fe": 186, "\u00ff": 187, "\u0100": 188, "\u0101": 189, "\u0102": 190, "\u0103": 191, "\u0104": 192, "\u0105": 193, "\u0106": 194, "\u0107": 195, "\u0108": 196, "\u0109": 197, "\u010a": 198, "\u010b": 199, "\u010c": 200, "\u010d": 201, "\u010e": 202, "\u010f": 203, "\u0110": 204, "\u0111": 205, "\u0112": 206, "\u0113": 207, "\u0114": 208, "\u0115": 209, "\u0116": 210, "\u0117": 211, "\u0118": 212, "\u0119": 213, "\u011a": 214, "\u011b": 215, "\u011c": 216, "\u011d": 217, "\u011e": 218, "\u011f": 219, "\u0120": 220, "\u0121": 221, "\u0122": 222, "\u0123": 223, "\u0124": 224, "\u0125": 225, "\u0126": 226, "\u0127": 227, "\u0128": 228, "\u0129": 229, "\u012a": 230, "\u012b": 231, "\u012c": 232, "\u012d": 233, "\u012e": 234, "\u012f": 235, "\u0130": 236, "\u0131": 237, "\u0132": 238, "\u0133": 239, "\u0134": 240, "\u0135": 241, "\u0136": 242, "\u0137": 243, "\u0138": 244, "\u0139": 245, "\u013a": 246, "\u013b": 247, "\u013c": 248, "\u013d": 249, "\u013e": 250, "\u013f": 251, "\u0140": 252, "\u0141": 253, "\u0142": 254, "\u0143": 255} \ No newline at end of file diff --git a/README.md b/README.md index 317ac29..c475917 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ # trax-reformer Training, Fine-tuning, Sampling and using Google's ReformerLM models + + +`%pip install --upgrade requirements-colab.txt' \ No newline at end of file diff --git a/configs.py b/configs.py new file mode 100644 index 0000000..1d8a51d --- /dev/null +++ b/configs.py @@ -0,0 +1,108 @@ +train_config = """ +import trax.layers +import trax.models +import trax.optimizers +import trax.supervised.inputs +import trax.supervised.trainer_lib + +# Parameters that will vary between experiments: +# ============================================================================== +train.model = @trax.models.ReformerLM +# attn_type = @TimeBinCausalAttention +attn_type = [ + @TimeBinCausalAttention, + @TimeBinCausalAttention, + @LSHCausalAttention, + @TimeBinCausalAttention, +] +share_qk = False # LSHCausalAttention ignores this flag and always shares q & k +attn_kv = 128 +n_layers = 12 +dropout = 0.2 + +# MemoryEfficientCausalAttention: full attention +# (no hparams to vary between experiments) + +# TimeBinCausalAttention: attend to nearby items +TimeBinCausalAttention.n_bins = 512 + +# LSHCausalAttention: locality-sensitive hashing (LSH) attention +LSHCausalAttention.n_bins = 256 +LSHCausalAttention.n_buckets = 512 # Always 2 * n_bins +LSHCausalAttention.n_hashes = 2 +LSHCausalAttention.drop_for_hash_rate = 0.0 + + +# Parameters for MultifactorSchedule: +# ============================================================================== +# 0.03125 ~= 1024^-0.5 = d_model^-0.5 +MultifactorSchedule.constant = 0.03125 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 2000 + +# Parameters for Adam: +# ============================================================================== +Adam.weight_decay_rate=0.0 +Adam.b1 = 0.9 +Adam.b2 = 0.98 +Adam.eps = 1e-9 + + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 256 +MemoryEfficientCausalAttention.share_qk = %share_qk + +# Parameters for TimeBinCausalAttention: +# ============================================================================== +TimeBinCausalAttention.dropout = 0.2 +# TimeBinCausalAttention.n_bins: see top +TimeBinCausalAttention.share_qk = %share_qk + +# Parameters for LSHCausalAttention: +# ============================================================================== +LSHCausalAttention.allow_duplicate_attention = False +LSHCausalAttention.attend_across_buckets = True +LSHCausalAttention.rehash_each_round = True +# LSHCausalAttention.n_bins: see top +# LSHCausalAttention.n_buckets: see top +# LSHCausalAttention.n_hashes: see top +LSHCausalAttention.one_rng = False +LSHCausalAttention.hard_k = 0 +LSHCausalAttention.dropout = 0.2 +# LSHCausalAttention.drop_for_hash_rate: see top + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = %attn_type +TransformerLM.d_attention_key = %attn_kv +TransformerLM.d_attention_value = %attn_kv +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = %dropout +TransformerLM.max_len = 65536 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = %n_layers +TransformerLM.share_qk = %share_qk +TransformerLM.vocab_size = 258 # Includes pad token and unused EOS token + +# Parameters for ReformerLM: +# ============================================================================== +ReformerLM.attention_type = %attn_type +ReformerLM.d_attention_key = %attn_kv +ReformerLM.d_attention_value = %attn_kv +ReformerLM.d_model = 1024 +ReformerLM.d_ff = 2048 +ReformerLM.dropout = %dropout +ReformerLM.ff_activation = @trax.layers.Relu +ReformerLM.max_len = 65536 +ReformerLM.mode = 'train' +ReformerLM.n_heads = 8 +ReformerLM.n_layers = %n_layers +ReformerLM.vocab_size = 258 # Includes pad token and unused EOS token +ReformerLM.share_qk = %share_qk +ReformerLM.axial_pos_shape = (128, 512) +ReformerLM.d_axial_pos_embs= (256, 768) +""" diff --git a/reformer.py b/reformer.py new file mode 100644 index 0000000..08af046 --- /dev/null +++ b/reformer.py @@ -0,0 +1,98 @@ +import argparse +import gin +import os +import jax +import trax +from trax.supervised import inputs + +import numpy as onp +import jax.numpy as np +from scipy.special import softmax + + +import glob +import json +from tokenizers import ByteLevelBPETokenizer + +from start_tpu import config +from config import train_config + +parser = argparse.ArgumentParser( + description='Tokenize a folder of text file(s)') + +parser.add_argument('--data_folder', type=str, default='tokenized_data', + help='Data folder with 1 or more tokenized files') +parser.add_argument('--model_folder', type=str, default='model', + help='Folder For saving and loading the model') +parser.add_argument('--steps_per_epoch', type=int, default=100) +parser.add_argument('--epochs', type=int, default=10) +args = parser.parse_args() + + +def gen_inputs(n_devices): + max_length = int(65536 * 0.98) # always leave a little padding + folder = args.data_folder + files = glob.glob(f'{folder}/*.npy') + print(f'first start from {len(files)} files') + while True: + file = onp.random.choice(files, 1)[0] + data = onp.load(file, allow_pickle=True) + print(f'processing from {file}, {len(data)} examples in file') + max_picks = int((len(data) * 0.7) / n_devices) + indices = onp.arange(len(data)) + picks = onp.random.choice( + indices, (max_picks, n_devices), replace=False) + for id_list in picks: + inputs = [] + mask = [] + for id_ in id_list: + IDS = data[id_] + if len(IDS) > max_length: + rand_start = onp.random.randint(0, len(IDS) - max_length) + IDS = IDS[rand_start:rand_start + max_length] + + PAD_AMOUNT = 65536 - len(IDS) # same as axial_pos_shape + pad_start = onp.random.choice(PAD_AMOUNT) + inputs.append(onp.pad(IDS, (pad_start, PAD_AMOUNT - pad_start), + mode='constant')) + mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32), + (pad_start, PAD_AMOUNT - pad_start), + mode='constant')) + inputs = onp.stack(inputs) + mask = onp.stack(mask) + # for i in range(100): + yield (inputs, inputs, mask) + + +def gen_validation_inputs(n_devices): + # different validation each time but consistent across the run + ids = next(gen_inputs(n_devices)) + while True: + return ids + + +def create_fixed_training_schedule(lr=0.0001): + # Yes, it does look unneceserily nested for passing a single float + def FixedTrainingSchedule(*args, **kwargs): + def learning_rate(step): + return {'learning_rate': np.asarray(lr, dtype=np.float32)} + return learning_rate + + +def train(): + output_dir = os.path.expanduser(f'{args.model_folder}/') + trainer = trax.supervised.Trainer( + model=trax.models.ReformerLM, + loss_fn=trax.layers.CrossEntropyLoss, + optimizer=trax.optimizers.Adam, + lr_schedule=FixedTrainingSchedule, + # lr_schedule=trax.lr.MultifactorSchedule, + inputs=trax.supervised.inputs.Inputs(gen_inputs, gen_inputs2), + output_dir=output_dir, + has_weights=True) + + for _ in range(args.epochs): + trainer.train_epoch(n_steps=args.steps_per_epoch, n_eval_steps=1) + +if __name__ == '__main__': + train() diff --git a/requirements-colab.txt b/requirements-colab.txt new file mode 100644 index 0000000..63a2536 --- /dev/null +++ b/requirements-colab.txt @@ -0,0 +1,4 @@ +jax +jaxlib +git+https://github.com/google/trax.git@v1.2.2 +tokenizers \ No newline at end of file diff --git a/start_tpu.py b/start_tpu.py new file mode 100644 index 0000000..b474262 --- /dev/null +++ b/start_tpu.py @@ -0,0 +1,11 @@ +import requests +import os +if 'TPU_DRIVER_MODE' not in globals(): + url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206' + resp = requests.post(url) + TPU_DRIVER_MODE = 1 + +# The following is required to use TPU Driver as JAX's backend. +from jax.config import config +config.FLAGS.jax_xla_backend = "tpu_driver" +config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] diff --git a/tokenize_data.py b/tokenize_data.py new file mode 100644 index 0000000..b113e85 --- /dev/null +++ b/tokenize_data.py @@ -0,0 +1,83 @@ +import argparse +import glob +import os +import json +import numpy as onp +from tokenizers import ByteLevelBPETokenizer + +parser = argparse.ArgumentParser( + description='Tokenize a folder of text file(s)') + +parser.add_argument('--input_folder', type=str, required=True, + help='Input folder with 1 or more text files') +parser.add_argument('--output_folder', type=str, default='tokenized_data') +parser.add_argument('--files_start_with', type=str, default='', + help='Process only files starting with this string') +parser.add_argument('--remove_input', default=False, action='store_true', + help='Delete input file after tokenizing') + +args = parser.parse_args() + +input_files = glob.glob(f'{args.input_folder}/{args.files_start_with}*') +input_files = [x for x in input_files if os.path.isfile(x)] + +print(input_files) + +tokenizer = ByteLevelBPETokenizer( + '256bytebpe-res-vocab.json', '256bytebpe-merges.txt') + + +def encode_data(file, max_per_n=10000): + folder = args.output_folder + with open(file, 'r') as f: + # print(f.read()) + ids_n = 0 + largest_id = 0 + i = 0 + id_list = [] + for line in f: + i += 1 + IDS = tokenizer.encode(line).ids + IDS = onp.asarray(IDS, dtype=onp.int32) + ids_n += len(IDS) + largest_id = max(len(IDS), largest_id) + print(largest_id) + id_list.append(IDS) + # print(id_list) + if i > 0 and i % max_per_n == 0: + # save every max_per_n lines + onp.save(f'{folder}/{file[1:]}-{i}', id_list) + print(f'{i} processed lines') + id_list = [] + print(dict(ids_n=ids_n, largest_id=largest_id, name=folder)) + if len(id_list) > 16: + # we skip if there's too litle for batching + onp.save(f'{folder}/{file[1:]}-{i}', id_list) + with open(f'{folder}/{file[1:]}-config', 'w') as out: + json.dump(dict(ids_n=ids_n, largest_id=largest_id, name=folder), out) + + +if __name__ == '__main__': + if not os.path.exists(args.output_folder): + os.makedirs(args.output_folder) + + for i, file in enumerate(input_files): + print(file) + encode_data(file) + if args.remove_input: + os.remove(file) + + # combine configs - not used, just for sanity checking + files = glob.glob(f'{args.output_folder}/*-config') + config = {'ids_n': 0, 'largest_id': 0, 'name': args.output_folder} + for file in files: + with open(file) as json_file: + temp_conf = json.load(json_file) + config['ids_n'] += temp_conf['ids_n'] + config['largest_id'] = max( + temp_conf['largest_id'], config['largest_id']) + + with open(f'config', 'w') as out: + json.dump(f'{args.output_folder}/config', out) + print(config) + print('Finished tokenizing data')