diff --git a/configs.py b/configs.py index 620af40..cc4e4a2 100644 --- a/configs.py +++ b/configs.py @@ -92,3 +92,13 @@ ReformerLM.axial_pos_shape = (128, 512) ReformerLM.d_axial_pos_embs= (256, 768) """ + +test_config = """ +TimeBinCausalAttention.bin_length = 128 +TimeBinCausalAttention.n_bins = None +LSHCausalAttention.n_hashes = 8 +LSHCausalAttention.bucket_capacity_for_inference = 258 +ReformerLM.dropout = 0.0 +TimeBinCausalAttention.dropout = 0.0 +LSHCausalAttention.dropout = 0.0 +""" diff --git a/reformer.py b/reformer.py index 502d93d..4089d07 100644 --- a/reformer.py +++ b/reformer.py @@ -1,4 +1,3 @@ -import argparse import gin import glob import jax @@ -6,7 +5,7 @@ import sys import requests import trax - +from functools import partial from trax.supervised import inputs import numpy as onp import jax.numpy as np @@ -14,40 +13,29 @@ from configs import train_config -parser = argparse.ArgumentParser( - description='Tokenize a folder of text file(s)') - -parser.add_argument('--data_folder', type=str, default='sample_data', - help='Data folder with 1 or more tokenized files') -parser.add_argument('--model_folder', type=str, default='model', - help='Folder For saving and loading the model') -parser.add_argument('--steps_per_epoch', type=int, default=100) -parser.add_argument('--epochs', type=int, default=10) -parser.add_argument('--learning_rate', type=float, default=0.0001) -parser.add_argument('--multi_factor_schedule', - default=False, action='store_true') -parser.add_argument('--tpu', - default=False, action='store_true') - +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser( + description='Tokenize a folder of text file(s)') -args = parser.parse_args() + parser.add_argument('--data_folder', type=str, default='sample_data', + help='Data folder with 1 or more tokenized files') + parser.add_argument('--model_folder', type=str, default='model', + help='Folder For saving and loading the model') + parser.add_argument('--steps_per_epoch', type=int, default=100) + parser.add_argument('--epochs', type=int, default=10) + parser.add_argument('--learning_rate', type=float, default=0.0001) + parser.add_argument('--multi_factor_schedule', + default=False, action='store_true') + parser.add_argument('--tpu', + default=False, action='store_true') -if args.tpu: - if 'TPU_DRIVER_MODE' not in globals(): - url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206' - resp = requests.post(url) - TPU_DRIVER_MODE = 1 + args = parser.parse_args() - # The following is required to use TPU Driver as JAX's backend. - from jax.config import config - config.FLAGS.jax_xla_backend = "tpu_driver" - config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] - print(config.FLAGS.jax_backend_target) -def gen_inputs(n_devices): +def gen_inputs(n_devices, folder): max_length = int(65536 * 0.98) # always leave a little padding - folder = args.data_folder files = glob.glob(f'{folder}/*.npy') print(f'first start from {len(files)} files') while True: @@ -80,9 +68,9 @@ def gen_inputs(n_devices): yield (inputs, inputs, mask) -def gen_validation_inputs(n_devices): +def gen_validation_inputs(n_devices, folder): # different validation each time but consistent across the run - ids = next(gen_inputs(n_devices)) + ids = next(gen_inputs(n_devices, folder)) while True: yield ids @@ -96,7 +84,7 @@ def learning_rate(step): return FixedTrainingSchedule -def train(): +def train(args): gin.parse_config(train_config) schedule = create_fixed_training_schedule(args.learning_rate) if args.multi_factor_schedule: @@ -107,7 +95,7 @@ def train(): loss_fn=trax.layers.CrossEntropyLoss, optimizer=trax.optimizers.Adam, lr_schedule=schedule, - inputs=trax.supervised.inputs.Inputs(gen_inputs, gen_validation_inputs), + inputs=trax.supervised.inputs.Inputs(partial(gen_inputs, folder=args.data_folder), partial(gen_validation_inputs, folder=args.data_folder)), output_dir=output_dir, has_weights=True) @@ -115,6 +103,18 @@ def train(): print(f'epoch {i} starting') trainer.train_epoch(n_steps=args.steps_per_epoch, n_eval_steps=1) - sys.exit() +def main_train(args): + if args.tpu: + if 'TPU_DRIVER_MODE' not in globals(): + url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206' + resp = requests.post(url) + TPU_DRIVER_MODE = 1 + # The following is required to use TPU Driver as JAX's backend. + from jax.config import config + config.FLAGS.jax_xla_backend = "tpu_driver" + config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] + print(config.FLAGS.jax_backend_target) + train(args) + if __name__ == '__main__': - train() + main_train(args) diff --git a/sample.py b/sample.py index e4572f6..b89e812 100644 --- a/sample.py +++ b/sample.py @@ -12,54 +12,45 @@ from tokenizers import ByteLevelBPETokenizer -from configs import train_config - -parser = argparse.ArgumentParser( - description='Tokenize a folder of text file(s)') - -parser.add_argument('--model_folder', type=str, default='model', - help='Folder For saving and loading the model') -parser.add_argument('--tpu', - default=False, action='store_true') -parser.add_argument('--prompt', type=str, default='', - help='Prompt for beginning the sampling e.g. {"title": "Sampling"') -parser.add_argument('--prompt', type=int, default=512, - help='Maximum length of sample') -parser.add_argument('--temperature', type=float, default=1.0, - help='Sampling Temperature') -parser.add_argument('--top_k', type=int, default=0,) -parser.add_argument('--exp2', default=False, action='store_true', - help='Use exp2 instead of exp during sampling') -parser.add_argument('--meena_max', default=False, action='store_true', - help='pick the probabilities with highest max') -parser.add_argument('--meena_combine', default=False, action='store_true', - help='use all probabilities from all samples (8 on TPU) at once') - -args = parser.parse_args() - -if args.tpu: - if 'TPU_DRIVER_MODE' not in globals(): - url = 'http://' + os.environ['COLAB_TPU_ADDR'].split( - ':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206' - resp = requests.post(url) - TPU_DRIVER_MODE = 1 - - # The following is required to use TPU Driver as JAX's backend. - from jax.config import config - config.FLAGS.jax_xla_backend = "tpu_driver" - config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR'] - print(config.FLAGS.jax_backend_target) +from configs import train_config, test_config + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Sample from a given model') + + parser.add_argument('--model_folder', type=str, default='model', + help='Folder For saving and loading the model') + parser.add_argument('--tpu', + default=False, action='store_true') + parser.add_argument('--prompt', type=str, default='', + help='Prompt for beginning the sampling e.g. {"title": "Sampling"') + parser.add_argument('--length', type=int, default=512, + help='Maximum length of sample') + parser.add_argument('--temperature', type=float, default=1.0, + help='Sampling Temperature') + parser.add_argument('--top_k', type=int, default=0,) + parser.add_argument('--exp2', default=False, action='store_true', + help='Use exp2 instead of exp during sampling') + parser.add_argument('--meena_max', default=False, action='store_true', + help='pick the probabilities with highest max') + parser.add_argument('--meena_combine', default=False, action='store_true', + help='use all probabilities from all samples (8 on TPU) at once') + + args = parser.parse_args() + tokenizer = ByteLevelBPETokenizer( '256bytebpe-res-vocab.json', '256bytebpe-merges.txt') + def fake_data(n_devices): - data = onp.zeros((n_devices, 65536)) - while True: - yield (data, data, data) + data = onp.zeros((n_devices, 65536)) + while True: + yield (data, data, data) + def create_fixed_training_schedule(lr=0.001): - # Yes, it does look unneceserily nested for passing a single float + # Yes, it does look unneceserily nested for passing a single float def FixedTrainingSchedule(*args, **kwargs): def learning_rate(step): return {'learning_rate': np.asarray(lr, dtype=np.float32)} @@ -67,7 +58,8 @@ def learning_rate(step): return FixedTrainingSchedule -def sample(length=args.length, prompt=args.prompt, temperature=args.temperature, top_k=args.top_k, exp2=args.exp2, boost_top=args.boost_top, meena_max=args.meena_max, meena_combine=args.meena_combine): +def sample_single(length=2048, prompt=None, temperature=1.0, top_k=None, exp2=None, + boost_top=None, meena_max=False, meena_combine=False): """Sample from the ReformerLM model example top_k = 32 exp2 = True|False @@ -153,7 +145,7 @@ def sample(length=args.length, prompt=args.prompt, temperature=args.temperature, return all_samples -def sample(): +def sample(args): gin.parse_config(train_config) schedule = create_fixed_training_schedule() output_dir = os.path.expanduser(f'{args.model_folder}/') @@ -166,6 +158,7 @@ def sample(): fake_data, fake_data), output_dir=output_dir, has_weights=True) + gin.parse_config(test_config) model_infer = trax.models.ReformerLM(mode='predict') # Prepare a jitted copy of the model. @@ -179,10 +172,25 @@ def sample(): model_weights = trainer._opt_state[0][0] del trainer del model_infer - for _ in range(args.epochs): - sample_single(length=2048, prompt=None, temperature=1.0, - top_k=None, exp2=None, boost_top=None) + sample_single(length=args.length, prompt=args.prompt, temperature=args.temperature, + top_k=args.top_k, exp2=args.exp2, meena_combine=args.meena_combine, meena_max=args.meena_max) + + +def main_sample(args): + if args.tpu: + if 'TPU_DRIVER_MODE' not in globals(): + url = 'http://' + os.environ['COLAB_TPU_ADDR'].split( + ':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206' + resp = requests.post(url) + TPU_DRIVER_MODE = 1 + # The following is required to use TPU Driver as JAX's backend. + from jax.config import config + config.FLAGS.jax_xla_backend = "tpu_driver" + config.FLAGS.jax_backend_target = "grpc://" + \ + os.environ['COLAB_TPU_ADDR'] + print(config.FLAGS.jax_backend_target) + sample(args) if __name__ == '__main__': - sample() + main_sample(args)