Move basic Colab code out of notebooks
Showing 9 changed files with 310 additions and 0 deletions.
@@ -0,0 +1 @@
__pycache__/
@@ -0,0 +1 @@
#version: 0.2 - Trained by `huggingface/tokenizers`
@@ -0,0 +1 @@
{"!": 256, "\"": 257, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "q": 80, "r": 81, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "\u00a1": 94, "\u00a2": 95, "\u00a3": 96, "\u00a4": 97, "\u00a5": 98, "\u00a6": 99, "\u00a7": 100, "\u00a8": 101, "\u00a9": 102, "\u00aa": 103, "\u00ab": 104, "\u00ac": 105, "\u00ae": 106, "\u00af": 107, "\u00b0": 108, "\u00b1": 109, "\u00b2": 110, "\u00b3": 111, "\u00b4": 112, "\u00b5": 113, "\u00b6": 114, "\u00b7": 115, "\u00b8": 116, "\u00b9": 117, "\u00ba": 118, "\u00bb": 119, "\u00bc": 120, "\u00bd": 121, "\u00be": 122, "\u00bf": 123, "\u00c0": 124, "\u00c1": 125, "\u00c2": 126, "\u00c3": 127, "\u00c4": 128, "\u00c5": 129, "\u00c6": 130, "\u00c7": 131, "\u00c8": 132, "\u00c9": 133, "\u00ca": 134, "\u00cb": 135, "\u00cc": 136, "\u00cd": 137, "\u00ce": 138, "\u00cf": 139, "\u00d0": 140, "\u00d1": 141, "\u00d2": 142, "\u00d3": 143, "\u00d4": 144, "\u00d5": 145, "\u00d6": 146, "\u00d7": 147, "\u00d8": 148, "\u00d9": 149, "\u00da": 150, "\u00db": 151, "\u00dc": 152, "\u00dd": 153, "\u00de": 154, "\u00df": 155, "\u00e0": 156, "\u00e1": 157, "\u00e2": 158, "\u00e3": 159, "\u00e4": 160, "\u00e5": 161, "\u00e6": 162, "\u00e7": 163, "\u00e8": 164, "\u00e9": 165, "\u00ea": 166, "\u00eb": 167, "\u00ec": 168, "\u00ed": 169, "\u00ee": 170, "\u00ef": 171, "\u00f0": 172, "\u00f1": 173, "\u00f2": 174, "\u00f3": 175, "\u00f4": 176, "\u00f5": 177, "\u00f6": 178, "\u00f7": 179, "\u00f8": 180, "\u00f9": 181, "\u00fa": 182, "\u00fb": 183, "\u00fc": 184, "\u00fd": 185, "\u00fe": 186, "\u00ff": 187, "\u0100": 188, "\u0101": 189, "\u0102": 190, "\u0103": 191, "\u0104": 192, "\u0105": 193, "\u0106": 194, "\u0107": 195, "\u0108": 196, "\u0109": 197, "\u010a": 198, "\u010b": 199, "\u010c": 200, "\u010d": 201, "\u010e": 202, "\u010f": 203, "\u0110": 204, "\u0111": 205, "\u0112": 206, "\u0113": 207, "\u0114": 208, "\u0115": 209, "\u0116": 210, "\u0117": 211, "\u0118": 212, "\u0119": 213, "\u011a": 214, "\u011b": 215, "\u011c": 216, "\u011d": 217, "\u011e": 218, "\u011f": 219, "\u0120": 220, "\u0121": 221, "\u0122": 222, "\u0123": 223, "\u0124": 224, "\u0125": 225, "\u0126": 226, "\u0127": 227, "\u0128": 228, "\u0129": 229, "\u012a": 230, "\u012b": 231, "\u012c": 232, "\u012d": 233, "\u012e": 234, "\u012f": 235, "\u0130": 236, "\u0131": 237, "\u0132": 238, "\u0133": 239, "\u0134": 240, "\u0135": 241, "\u0136": 242, "\u0137": 243, "\u0138": 244, "\u0139": 245, "\u013a": 246, "\u013b": 247, "\u013c": 248, "\u013d": 249, "\u013e": 250, "\u013f": 251, "\u0140": 252, "\u0141": 253, "\u0142": 254, "\u0143": 255} |
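The merges and vocab files above are the byte-level BPE model that the tokenizer script further down loads by name ('256bytebpe-res-vocab.json' and '256bytebpe-merges.txt'). A minimal sketch of round-tripping a string through them with the huggingface tokenizers library (not part of the commit):

from tokenizers import ByteLevelBPETokenizer
tokenizer = ByteLevelBPETokenizer(
    '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')
ids = tokenizer.encode('Reformer handles very long sequences.').ids
print(ids)                    # every id is below 258, matching vocab_size = 258 in the gin config
print(tokenizer.decode(ids))  # recovers the original text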
@@ -1,2 +1,5 @@
# trax-reformer
Training, fine-tuning, sampling, and using Google's ReformerLM models

`%pip install --upgrade -r requirements-colab.txt`
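A hedged sketch of how the pieces in this commit appear to fit together in a Colab cell. The script file names are guesses for illustration (the commit page does not show them), the folder names are just the argparse defaults, and 'mydata' is a hypothetical prefix for the raw text files:

# '.' as the input folder matches the file[1:] path handling in the tokenizer script
!python tokenize_data.py --input_folder . --files_start_with mydata --output_folder tokenized_data
!python train.py --data_folder tokenized_data --model_folder model --steps_per_epoch 100 --epochs 10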
@@ -0,0 +1,108 @@
train_config = """
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# attn_type = @TimeBinCausalAttention
attn_type = [
  @TimeBinCausalAttention,
  @TimeBinCausalAttention,
  @LSHCausalAttention,
  @TimeBinCausalAttention,
  ]
share_qk = False  # LSHCausalAttention ignores this flag and always shares q & k
attn_kv = 128
n_layers = 12
dropout = 0.2

# MemoryEfficientCausalAttention: full attention
# (no hparams to vary between experiments)

# TimeBinCausalAttention: attend to nearby items
TimeBinCausalAttention.n_bins = 512

# LSHCausalAttention: locality-sensitive hashing (LSH) attention
LSHCausalAttention.n_bins = 256
LSHCausalAttention.n_buckets = 512  # Always 2 * n_bins
LSHCausalAttention.n_hashes = 2
LSHCausalAttention.drop_for_hash_rate = 0.0

# Parameters for MultifactorSchedule:
# ==============================================================================
# 0.03125 ~= 1024^-0.5 = d_model^-0.5
MultifactorSchedule.constant = 0.03125
MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay'
MultifactorSchedule.warmup_steps = 2000

# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate = 0.0
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-9

# Parameters for MemoryEfficientCausalAttention:
# ==============================================================================
MemoryEfficientCausalAttention.dropout = 0.0
MemoryEfficientCausalAttention.loop_stride = 256
MemoryEfficientCausalAttention.share_qk = %share_qk

# Parameters for TimeBinCausalAttention:
# ==============================================================================
TimeBinCausalAttention.dropout = 0.2
# TimeBinCausalAttention.n_bins: see top
TimeBinCausalAttention.share_qk = %share_qk

# Parameters for LSHCausalAttention:
# ==============================================================================
LSHCausalAttention.allow_duplicate_attention = False
LSHCausalAttention.attend_across_buckets = True
LSHCausalAttention.rehash_each_round = True
# LSHCausalAttention.n_bins: see top
# LSHCausalAttention.n_buckets: see top
# LSHCausalAttention.n_hashes: see top
LSHCausalAttention.one_rng = False
LSHCausalAttention.hard_k = 0
LSHCausalAttention.dropout = 0.2
# LSHCausalAttention.drop_for_hash_rate: see top

# Parameters for TransformerLM:
# ==============================================================================
TransformerLM.attention_type = %attn_type
TransformerLM.d_attention_key = %attn_kv
TransformerLM.d_attention_value = %attn_kv
TransformerLM.d_model = 1024
TransformerLM.d_ff = 2048
TransformerLM.dropout = %dropout
TransformerLM.max_len = 65536
TransformerLM.mode = 'train'
TransformerLM.n_heads = 8
TransformerLM.n_layers = %n_layers
TransformerLM.share_qk = %share_qk
TransformerLM.vocab_size = 258  # Includes pad token and unused EOS token

# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 1024
ReformerLM.d_ff = 2048
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = 65536
ReformerLM.mode = 'train'
ReformerLM.n_heads = 8
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 258  # Includes pad token and unused EOS token
ReformerLM.share_qk = %share_qk
ReformerLM.axial_pos_shape = (128, 512)
ReformerLM.d_axial_pos_embs = (256, 768)
"""
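Two constraints are worth calling out in the config above: the axial position embedding pairs axial_pos_shape with max_len (their product must cover the sequence length) and d_axial_pos_embs with d_model (their sum must equal it). A quick arithmetic check, plain Python and not part of the commit:

axial_pos_shape = (128, 512)
d_axial_pos_embs = (256, 768)
assert 128 * 512 == 65536   # product equals ReformerLM.max_len
assert 256 + 768 == 1024    # sum equals ReformerLM.d_model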
@@ -0,0 +1,98 @@
import argparse
import gin
import os
import jax
import trax
from trax.supervised import inputs

import numpy as onp
import jax.numpy as np
from scipy.special import softmax

import glob
import json
from tokenizers import ByteLevelBPETokenizer

# Importing start_tpu configures the TPU driver backend as a side effect.
from start_tpu import config
from config import train_config

parser = argparse.ArgumentParser(
    description='Train a ReformerLM model on tokenized data')

parser.add_argument('--data_folder', type=str, default='tokenized_data',
                    help='Data folder with 1 or more tokenized files')
parser.add_argument('--model_folder', type=str, default='model',
                    help='Folder for saving and loading the model')
parser.add_argument('--steps_per_epoch', type=int, default=100)
parser.add_argument('--epochs', type=int, default=10)
args = parser.parse_args()

# Apply the gin bindings from config.py so ReformerLM picks up the hyperparameters.
gin.parse_config(train_config)


def gen_inputs(n_devices):
    max_length = int(65536 * 0.98)  # always leave a little padding
    folder = args.data_folder
    files = glob.glob(f'{folder}/*.npy')
    print(f'starting from {len(files)} files')
    while True:
        file = onp.random.choice(files, 1)[0]
        data = onp.load(file, allow_pickle=True)
        print(f'processing from {file}, {len(data)} examples in file')
        max_picks = int((len(data) * 0.7) / n_devices)
        indices = onp.arange(len(data))
        picks = onp.random.choice(
            indices, (max_picks, n_devices), replace=False)
        for id_list in picks:
            inputs = []
            mask = []
            for id_ in id_list:
                IDS = data[id_]
                if len(IDS) > max_length:
                    rand_start = onp.random.randint(0, len(IDS) - max_length)
                    IDS = IDS[rand_start:rand_start + max_length]

                PAD_AMOUNT = 65536 - len(IDS)  # same as axial_pos_shape
                pad_start = onp.random.choice(PAD_AMOUNT)
                inputs.append(onp.pad(IDS, (pad_start, PAD_AMOUNT - pad_start),
                                      mode='constant'))
                mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                                    (pad_start, PAD_AMOUNT - pad_start),
                                    mode='constant'))
            inputs = onp.stack(inputs)
            mask = onp.stack(mask)
            yield (inputs, inputs, mask)


def gen_validation_inputs(n_devices):
    # different validation each time but consistent across the run
    ids = next(gen_inputs(n_devices))
    while True:
        yield ids


def create_fixed_training_schedule(lr=0.0001):
    # Yes, it does look unnecessarily nested for passing a single float
    def FixedTrainingSchedule(*args, **kwargs):
        def learning_rate(step):
            return {'learning_rate': np.asarray(lr, dtype=np.float32)}
        return learning_rate
    return FixedTrainingSchedule


def train():
    output_dir = os.path.expanduser(f'{args.model_folder}/')
    trainer = trax.supervised.Trainer(
        model=trax.models.ReformerLM,
        loss_fn=trax.layers.CrossEntropyLoss,
        optimizer=trax.optimizers.Adam,
        lr_schedule=create_fixed_training_schedule(),
        # lr_schedule=trax.lr.MultifactorSchedule,
        inputs=trax.supervised.inputs.Inputs(gen_inputs, gen_validation_inputs),
        output_dir=output_dir,
        has_weights=True)

    for _ in range(args.epochs):
        trainer.train_epoch(n_steps=args.steps_per_epoch, n_eval_steps=1)


if __name__ == '__main__':
    train()
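Before a long run it is cheap to pull one batch from gen_inputs and check the shapes; a sketch, assuming it runs in the same module as gen_inputs with at least one .npy shard in the data folder:

batch, targets, mask = next(gen_inputs(n_devices=1))
print(batch.shape, mask.shape)   # expect (1, 65536) for both: one padded sequence per device
assert batch.shape[-1] == 65536  # matches ReformerLM.max_len and the axial_pos_shape product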
@@ -0,0 +1,4 @@
jax
jaxlib
git+https://github.com/google/[email protected]
tokenizers
@@ -0,0 +1,11 @@
import requests
import os

if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
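After this module runs, JAX should be talking to the Colab TPU rather than falling back to CPU; a one-line check, assuming the same Colab runtime:

import jax
print(jax.devices())  # expect a list of TpuDevice entries, not a single CpuDevice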
@@ -0,0 +1,83 @@
import argparse
import glob
import os
import json
import numpy as onp
from tokenizers import ByteLevelBPETokenizer

parser = argparse.ArgumentParser(
    description='Tokenize a folder of text file(s)')

parser.add_argument('--input_folder', type=str, required=True,
                    help='Input folder with 1 or more text files')
parser.add_argument('--output_folder', type=str, default='tokenized_data')
parser.add_argument('--files_start_with', type=str, default='',
                    help='Process only files starting with this string')
parser.add_argument('--remove_input', default=False, action='store_true',
                    help='Delete input file after tokenizing')

args = parser.parse_args()

input_files = glob.glob(f'{args.input_folder}/{args.files_start_with}*')
input_files = [x for x in input_files if os.path.isfile(x)]

print(input_files)

tokenizer = ByteLevelBPETokenizer(
    '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')


def encode_data(file, max_per_n=10000):
    folder = args.output_folder
    with open(file, 'r') as f:
        ids_n = 0
        largest_id = 0
        i = 0
        id_list = []
        for line in f:
            i += 1
            IDS = tokenizer.encode(line).ids
            IDS = onp.asarray(IDS, dtype=onp.int32)
            ids_n += len(IDS)
            largest_id = max(len(IDS), largest_id)
            print(largest_id)
            id_list.append(IDS)
            if i > 0 and i % max_per_n == 0:
                # save every max_per_n lines
                # file[1:] drops the first character of the path (works for inputs like './name')
                onp.save(f'{folder}/{file[1:]}-{i}', id_list)
                print(f'{i} processed lines')
                id_list = []
                print(dict(ids_n=ids_n, largest_id=largest_id, name=folder))
        if len(id_list) > 16:
            # we skip if there's too little for batching
            onp.save(f'{folder}/{file[1:]}-{i}', id_list)
    with open(f'{folder}/{file[1:]}-config', 'w') as out:
        json.dump(dict(ids_n=ids_n, largest_id=largest_id, name=folder), out)


if __name__ == '__main__':
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)

    for i, file in enumerate(input_files):
        print(file)
        encode_data(file)
        if args.remove_input:
            os.remove(file)

    # combine configs - not used, just for sanity checking
    files = glob.glob(f'{args.output_folder}/*-config')
    config = {'ids_n': 0, 'largest_id': 0, 'name': args.output_folder}
    for file in files:
        with open(file) as json_file:
            temp_conf = json.load(json_file)
            config['ids_n'] += temp_conf['ids_n']
            config['largest_id'] = max(
                temp_conf['largest_id'], config['largest_id'])

    with open(f'{args.output_folder}/config', 'w') as out:
        json.dump(config, out)
    print(config)
    print('Finished tokenizing data')
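To spot-check a shard after tokenizing, it can be loaded back and one example decoded with the same tokenizer; a sketch in which 'tokenized_data/example.txt-10000.npy' is a hypothetical shard name standing in for whatever the script actually printed:

import numpy as onp
from tokenizers import ByteLevelBPETokenizer
tokenizer = ByteLevelBPETokenizer(
    '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')
shard = onp.load('tokenized_data/example.txt-10000.npy', allow_pickle=True)  # hypothetical shard name
print(len(shard), 'examples in shard')
print(tokenizer.decode(shard[0].tolist())[:200])  # first 200 characters of example 0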