Skip to content

Commit

Permalink
Add fixes for Colab
Browse files Browse the repository at this point in the history
  • Loading branch information
Tenoke committed Mar 17, 2020
1 parent f9311d2 commit f097b2d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 83 deletions.
10 changes: 10 additions & 0 deletions configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,13 @@
ReformerLM.axial_pos_shape = (128, 512)
ReformerLM.d_axial_pos_embs= (256, 768)
"""

# Gin bindings applied for inference: sample.py parses this string right
# before it builds ReformerLM(mode='predict'). Besides the attention
# settings, it sets dropout to 0.0 for ReformerLM, TimeBinCausalAttention
# and LSHCausalAttention.
test_config = """
TimeBinCausalAttention.bin_length = 128
TimeBinCausalAttention.n_bins = None
LSHCausalAttention.n_hashes = 8
LSHCausalAttention.bucket_capacity_for_inference = 258
ReformerLM.dropout = 0.0
TimeBinCausalAttention.dropout = 0.0
LSHCausalAttention.dropout = 0.0
"""
72 changes: 36 additions & 36 deletions reformer.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,41 @@
import argparse
import gin
import glob
import jax
import os
import sys
import requests
import trax

from functools import partial
from trax.supervised import inputs
import numpy as onp
import jax.numpy as np


from configs import train_config

parser = argparse.ArgumentParser(
description='Tokenize a folder of text file(s)')

parser.add_argument('--data_folder', type=str, default='sample_data',
help='Data folder with 1 or more tokenized files')
parser.add_argument('--model_folder', type=str, default='model',
help='Folder For saving and loading the model')
parser.add_argument('--steps_per_epoch', type=int, default=100)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=0.0001)
parser.add_argument('--multi_factor_schedule',
default=False, action='store_true')
parser.add_argument('--tpu',
default=False, action='store_true')

if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(
description='Tokenize a folder of text file(s)')

args = parser.parse_args()
parser.add_argument('--data_folder', type=str, default='sample_data',
help='Data folder with 1 or more tokenized files')
parser.add_argument('--model_folder', type=str, default='model',
help='Folder For saving and loading the model')
parser.add_argument('--steps_per_epoch', type=int, default=100)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=0.0001)
parser.add_argument('--multi_factor_schedule',
default=False, action='store_true')
parser.add_argument('--tpu',
default=False, action='store_true')

if args.tpu:
if 'TPU_DRIVER_MODE' not in globals():
url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
resp = requests.post(url)
TPU_DRIVER_MODE = 1
args = parser.parse_args()

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)


def gen_inputs(n_devices):
def gen_inputs(n_devices, folder):
max_length = int(65536 * 0.98) # always leave a little padding
folder = args.data_folder
files = glob.glob(f'{folder}/*.npy')
print(f'first start from {len(files)} files')
while True:
Expand Down Expand Up @@ -80,9 +68,9 @@ def gen_inputs(n_devices):
yield (inputs, inputs, mask)


def gen_validation_inputs(n_devices, folder):
    """Yield one fixed validation batch forever.

    A single batch is drawn from ``gen_inputs(n_devices, folder)`` the first
    time this generator is advanced; that same object is then yielded on
    every iteration. Validation data can therefore differ between runs but
    stays consistent within a run.

    Args:
        n_devices: forwarded unchanged to ``gen_inputs``.
        folder: data folder forwarded to ``gen_inputs`` (which globs it for
            ``*.npy`` files).

    Yields:
        The batch produced by the first ``next()`` on ``gen_inputs``.
    """
    # gen_inputs requires the folder explicitly; it no longer reads a
    # module-level ``args``.
    batch = next(gen_inputs(n_devices, folder))
    while True:
        yield batch

Expand All @@ -96,7 +84,7 @@ def learning_rate(step):
return FixedTrainingSchedule


def train():
def train(args):
gin.parse_config(train_config)
schedule = create_fixed_training_schedule(args.learning_rate)
if args.multi_factor_schedule:
Expand All @@ -107,14 +95,26 @@ def train():
loss_fn=trax.layers.CrossEntropyLoss,
optimizer=trax.optimizers.Adam,
lr_schedule=schedule,
inputs=trax.supervised.inputs.Inputs(gen_inputs, gen_validation_inputs),
inputs=trax.supervised.inputs.Inputs(partial(gen_inputs, folder=args.data_folder), partial(gen_validation_inputs, folder=args.data_folder)),
output_dir=output_dir,
has_weights=True)

for i in range(args.epochs):
print(f'epoch {i} starting')
trainer.train_epoch(n_steps=args.steps_per_epoch, n_eval_steps=1)

sys.exit()
def main_train(args):
    """Optionally point JAX at the Colab TPU driver, then run training.

    Args:
        args: parsed command-line namespace. ``args.tpu`` selects the TPU
            setup path; the whole namespace is passed on to ``train``.

    Raises:
        KeyError: if ``args.tpu`` is set but ``COLAB_TPU_ADDR`` is not in the
            environment.
    """
    # Without this declaration the assignment below would create a
    # function-local name, the globals() check would never see it, and the
    # version request would be re-sent on every call (e.g. when a Colab cell
    # is re-run).
    global TPU_DRIVER_MODE

    if args.tpu:
        if 'TPU_DRIVER_MODE' not in globals():
            # One-time request to the Colab TPU host for this driver version.
            url = ('http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0]
                   + ':8475/requestversion/tpu_driver0.1-dev20191206')
            requests.post(url)
            TPU_DRIVER_MODE = 1
        # The following is required to use TPU Driver as JAX's backend.
        from jax.config import config
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
        print(config.FLAGS.jax_backend_target)
    train(args)

if __name__ == '__main__':
train()
main_train(args)
102 changes: 55 additions & 47 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,54 @@
from tokenizers import ByteLevelBPETokenizer


from configs import train_config

parser = argparse.ArgumentParser(
description='Tokenize a folder of text file(s)')

parser.add_argument('--model_folder', type=str, default='model',
help='Folder For saving and loading the model')
parser.add_argument('--tpu',
default=False, action='store_true')
parser.add_argument('--prompt', type=str, default='',
help='Prompt for beginning the sampling e.g. {"title": "Sampling"')
parser.add_argument('--prompt', type=int, default=512,
help='Maximum length of sample')
parser.add_argument('--temperature', type=float, default=1.0,
help='Sampling Temperature')
parser.add_argument('--top_k', type=int, default=0,)
parser.add_argument('--exp2', default=False, action='store_true',
help='Use exp2 instead of exp during sampling')
parser.add_argument('--meena_max', default=False, action='store_true',
help='pick the probabilities with highest max')
parser.add_argument('--meena_combine', default=False, action='store_true',
help='use all probabilities from all samples (8 on TPU) at once')

args = parser.parse_args()

if args.tpu:
if 'TPU_DRIVER_MODE' not in globals():
url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(
':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
resp = requests.post(url)
TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
from configs import train_config, test_config

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Sample from a given model')

parser.add_argument('--model_folder', type=str, default='model',
help='Folder For saving and loading the model')
parser.add_argument('--tpu',
default=False, action='store_true')
parser.add_argument('--prompt', type=str, default='',
help='Prompt for beginning the sampling e.g. {"title": "Sampling"')
parser.add_argument('--length', type=int, default=512,
help='Maximum length of sample')
parser.add_argument('--temperature', type=float, default=1.0,
help='Sampling Temperature')
parser.add_argument('--top_k', type=int, default=0,)
parser.add_argument('--exp2', default=False, action='store_true',
help='Use exp2 instead of exp during sampling')
parser.add_argument('--meena_max', default=False, action='store_true',
help='pick the probabilities with highest max')
parser.add_argument('--meena_combine', default=False, action='store_true',
help='use all probabilities from all samples (8 on TPU) at once')

args = parser.parse_args()


tokenizer = ByteLevelBPETokenizer(
'256bytebpe-res-vocab.json', '256bytebpe-merges.txt')


def fake_data(n_devices):
    """Yield an endless stream of all-zero placeholder batches.

    Args:
        n_devices: size of the leading axis of the generated array.

    Yields:
        A 3-tuple whose three entries are the same ``(n_devices, 65536)``
        zero array. The array is allocated once and reused for every batch.
    """
    data = onp.zeros((n_devices, 65536))
    while True:
        yield (data, data, data)


def create_fixed_training_schedule(lr=0.001):
    """Build a schedule factory that always reports the constant rate ``lr``.

    Args:
        lr: the learning rate to report at every step.

    Returns:
        A callable that accepts and ignores any arguments and returns a
        ``learning_rate(step)`` function. That function ignores ``step`` and
        returns ``{'learning_rate': <lr as a float32 array>}``.
    """
    # Yes, it does look unnecessarily nested for passing a single float: the
    # two inner layers exist only so the result has the factory shape the
    # callers hand over as ``lr_schedule``.
    def FixedTrainingSchedule(*args, **kwargs):
        def learning_rate(step):
            return {'learning_rate': np.asarray(lr, dtype=np.float32)}
        return learning_rate
    return FixedTrainingSchedule


def sample(length=args.length, prompt=args.prompt, temperature=args.temperature, top_k=args.top_k, exp2=args.exp2, boost_top=args.boost_top, meena_max=args.meena_max, meena_combine=args.meena_combine):
def sample_single(length=2048, prompt=None, temperature=1.0, top_k=None, exp2=None,
boost_top=None, meena_max=False, meena_combine=False):
"""Sample from the ReformerLM model
example top_k = 32
exp2 = True|False
Expand Down Expand Up @@ -153,7 +145,7 @@ def sample(length=args.length, prompt=args.prompt, temperature=args.temperature,
return all_samples


def sample():
def sample(args):
gin.parse_config(train_config)
schedule = create_fixed_training_schedule()
output_dir = os.path.expanduser(f'{args.model_folder}/')
Expand All @@ -166,6 +158,7 @@ def sample():
fake_data, fake_data),
output_dir=output_dir,
has_weights=True)
gin.parse_config(test_config)
model_infer = trax.models.ReformerLM(mode='predict')

# Prepare a jitted copy of the model.
Expand All @@ -179,10 +172,25 @@ def sample():
model_weights = trainer._opt_state[0][0]
del trainer
del model_infer
for _ in range(args.epochs):
sample_single(length=2048, prompt=None, temperature=1.0,
top_k=None, exp2=None, boost_top=None)
sample_single(length=args.length, prompt=args.prompt, temperature=args.temperature,
top_k=args.top_k, exp2=args.exp2, meena_combine=args.meena_combine, meena_max=args.meena_max)


def main_sample(args):
    """Optionally point JAX at the Colab TPU driver, then run sampling.

    Args:
        args: parsed command-line namespace. ``args.tpu`` selects the TPU
            setup path; the whole namespace is passed on to ``sample``.

    Raises:
        KeyError: if ``args.tpu`` is set but ``COLAB_TPU_ADDR`` is not in the
            environment.
    """
    # Without this declaration the assignment below would create a
    # function-local name, the globals() check would never see it, and the
    # version request would be re-sent on every call (e.g. when a Colab cell
    # is re-run).
    global TPU_DRIVER_MODE

    if args.tpu:
        if 'TPU_DRIVER_MODE' not in globals():
            # One-time request to the Colab TPU host for this driver version.
            url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(
                ':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
            requests.post(url)
            TPU_DRIVER_MODE = 1
        # The following is required to use TPU Driver as JAX's backend.
        from jax.config import config
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + \
            os.environ['COLAB_TPU_ADDR']
        print(config.FLAGS.jax_backend_target)
    sample(args)


if __name__ == '__main__':
sample()
main_sample(args)

0 comments on commit f097b2d

Please sign in to comment.