diff --git a/reformer.py b/reformer.py
index cdbcd03..502d93d 100644
--- a/reformer.py
+++ b/reformer.py
@@ -3,6 +3,7 @@
 import glob
 import jax
 import os
+import sys
 import requests
 import trax
 
@@ -16,7 +17,7 @@
 parser = argparse.ArgumentParser(
     description='Tokenize a folder of text file(s)')
 
-parser.add_argument('--data_folder', type=str, default='tokenized_data',
+parser.add_argument('--data_folder', type=str, default='sample_data',
                     help='Data folder with 1 or more tokenized files')
 parser.add_argument('--model_folder', type=str, default='model',
                     help='Folder For saving and loading the model')
@@ -110,9 +111,10 @@ def train():
         output_dir=output_dir,
         has_weights=True)
 
-    for _ in range(args.epochs):
+    for i in range(args.epochs):
+        print(f'epoch {i} starting')
         trainer.train_epoch(n_steps=args.steps_per_epoch, n_eval_steps=1)
-
+    sys.exit()
 
 if __name__ == '__main__':
     train()
diff --git a/sample.py b/sample.py
index cadeb2d..e4572f6 100644
--- a/sample.py
+++ b/sample.py
@@ -9,6 +9,7 @@
 from trax.supervised import inputs
 import numpy as onp
 import jax.numpy as np
+from tokenizers import ByteLevelBPETokenizer
 
 from configs import train_config
 
@@ -16,8 +17,6 @@
 parser = argparse.ArgumentParser(
     description='Tokenize a folder of text file(s)')
 
-parser.add_argument('--data_folder', type=str, default='tokenized_data',
-                    help='Data folder with 1 or more tokenized files')
 parser.add_argument('--model_folder', type=str, default='model',
                     help='Folder For saving and loading the model')
 parser.add_argument('--tpu',
@@ -51,48 +50,13 @@
 config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
 print(config.FLAGS.jax_backend_target)
 
+tokenizer = ByteLevelBPETokenizer(
+    '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')
 
-def gen_inputs(n_devices):
-    max_length = int(65536 * 0.98)  # always leave a little padding
-    folder = args.data_folder
-    files = glob.glob(f'{folder}/*.npy')
-    print(f'first start from {len(files)} files')
-    while True:
-        file = onp.random.choice(files, 1)[0]
-        data = onp.load(file, allow_pickle=True)
-        print(f'processing from {file}, {len(data)} examples in file')
-        max_picks = int((len(data) * 0.7) / n_devices)
-        indices = onp.arange(len(data))
-        picks = onp.random.choice(
-            indices, (max_picks, n_devices), replace=False)
-        for id_list in picks:
-            inputs = []
-            mask = []
-            for id_ in id_list:
-                IDS = data[id_]
-                if len(IDS) > max_length:
-                    rand_start = onp.random.randint(0, len(IDS) - max_length)
-                    IDS = IDS[rand_start:rand_start + max_length]
-
-                PAD_AMOUNT = 65536 - len(IDS)  # same as axial_pos_shape
-                pad_start = onp.random.choice(PAD_AMOUNT)
-                inputs.append(onp.pad(IDS, (pad_start, PAD_AMOUNT - pad_start),
-                                      mode='constant'))
-                mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
-                                    (pad_start, PAD_AMOUNT - pad_start),
-                                    mode='constant'))
-            inputs = onp.stack(inputs)
-            mask = onp.stack(mask)
-            # for i in range(100):
-            yield (inputs, inputs, mask)
-
-
-def gen_validation_inputs(n_devices):
-    # different validation each time but consistent across the run
-    ids = next(gen_inputs(n_devices))
-    while True:
-        yield ids
-
+def fake_data(n_devices):
+    data = onp.zeros((n_devices, 65536))
+    while True:
+        yield (data, data, data)
 
 def create_fixed_training_schedule(lr=0.001):
     # Yes, it does look unneceserily nested for passing a single float
@@ -199,7 +163,7 @@ def sample():
         optimizer=trax.optimizers.Adam,
         lr_schedule=schedule,
         inputs=trax.supervised.inputs.Inputs(
-            gen_inputs, gen_validation_inputs),
+            fake_data, fake_data),
         output_dir=output_dir,
         has_weights=True)
     model_infer = trax.models.ReformerLM(mode='predict')
diff --git a/sample_data/config b/sample_data/config
deleted file mode 100644
index 4307755..0000000
--- a/sample_data/config
+++ /dev/null
@@ -1 +0,0 @@
-{"ids_n": 19813635120, "largest_id": 404635, "name": "data/bytebpe256-res"}
\ No newline at end of file
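
Note (not part of the patch): the ByteLevelBPETokenizer added to sample.py comes from the Hugging Face tokenizers package and is loaded from the vocab/merges files named in the diff. A minimal sketch of the encode/decode round trip it enables, with the prompt text invented purely for illustration, would look like:

    from tokenizers import ByteLevelBPETokenizer

    # Same files as in sample.py; they must exist in the working directory.
    tokenizer = ByteLevelBPETokenizer(
        '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')

    # Encode a prompt to token ids, then decode ids back to text.
    encoding = tokenizer.encode('a sample prompt for the Reformer')
    print(encoding.ids)                    # list of ints, vocab-dependent
    print(tokenizer.decode(encoding.ids))

The fake_data generator that replaces gen_inputs only has to satisfy the (inputs, targets, weights) batch contract passed to trax.supervised.inputs.Inputs here, with each array shaped (n_devices, 65536), so all-zero arrays are enough to rebuild the Trainer state before switching the model to predict mode.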