Add fake data to sampling
Tenoke committed Mar 17, 2020
1 parent 8d92275 · commit f9311d2
Showing 3 changed files with 13 additions and 48 deletions.
reformer.py: 8 changes (5 additions, 3 deletions)
@@ -3,6 +3,7 @@
import glob
import jax
import os
import sys
import requests
import trax

@@ -16,7 +17,7 @@
parser = argparse.ArgumentParser(
description='Tokenize a folder of text file(s)')

parser.add_argument('--data_folder', type=str, default='tokenized_data',
parser.add_argument('--data_folder', type=str, default='sample_data',
help='Data folder with 1 or more tokenized files')
parser.add_argument('--model_folder', type=str, default='model',
help='Folder For saving and loading the model')
@@ -110,9 +111,10 @@ def train():
output_dir=output_dir,
has_weights=True)

for _ in range(args.epochs):
for i in range(args.epochs):
print(f'epoch {i} starting')
trainer.train_epoch(n_steps=args.steps_per_epoch, n_eval_steps=1)


sys.exit()
if __name__ == '__main__':
train()
sample.py: 52 changes (8 additions, 44 deletions)
@@ -9,15 +9,14 @@
from trax.supervised import inputs
import numpy as onp
import jax.numpy as np
from tokenizers import ByteLevelBPETokenizer


from configs import train_config

parser = argparse.ArgumentParser(
description='Tokenize a folder of text file(s)')

parser.add_argument('--data_folder', type=str, default='tokenized_data',
help='Data folder with 1 or more tokenized files')
parser.add_argument('--model_folder', type=str, default='model',
help='Folder For saving and loading the model')
parser.add_argument('--tpu',
@@ -51,48 +50,13 @@
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

tokenizer = ByteLevelBPETokenizer(
'256bytebpe-res-vocab.json', '256bytebpe-merges.txt')
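
For reference, the two lines above go away in this commit: sample.py no longer constructs the repo's byte-level BPE tokenizer once the Trainer is fed fake data. A hedged sketch of the usual encode/decode round trip with such a tokenizer; the encode/decode calls are the standard huggingface tokenizers API, not part of this diff:

    from tokenizers import ByteLevelBPETokenizer

    tokenizer = ByteLevelBPETokenizer(
        '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')
    ids = tokenizer.encode('some text').ids   # text -> list of token ids
    text = tokenizer.decode(ids)              # token ids -> text
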

def gen_inputs(n_devices):
max_length = int(65536 * 0.98) # always leave a little padding
folder = args.data_folder
files = glob.glob(f'{folder}/*.npy')
print(f'first start from {len(files)} files')
while True:
file = onp.random.choice(files, 1)[0]
data = onp.load(file, allow_pickle=True)
print(f'processing from {file}, {len(data)} examples in file')
max_picks = int((len(data) * 0.7) / n_devices)
indices = onp.arange(len(data))
picks = onp.random.choice(
indices, (max_picks, n_devices), replace=False)
for id_list in picks:
inputs = []
mask = []
for id_ in id_list:
IDS = data[id_]
if len(IDS) > max_length:
rand_start = onp.random.randint(0, len(IDS) - max_length)
IDS = IDS[rand_start:rand_start + max_length]

PAD_AMOUNT = 65536 - len(IDS) # same as axial_pos_shape
pad_start = onp.random.choice(PAD_AMOUNT)
inputs.append(onp.pad(IDS, (pad_start, PAD_AMOUNT - pad_start),
mode='constant'))
mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
(pad_start, PAD_AMOUNT - pad_start),
mode='constant'))
inputs = onp.stack(inputs)
mask = onp.stack(mask)
# for i in range(100):
yield (inputs, inputs, mask)
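
The gen_inputs pipeline being dropped from sample.py here pads every example out to the fixed 65536-token sequence length, splitting the padding at a random offset (some before the tokens, the rest after) and building a matching 0/1 weight mask. A minimal, self-contained sketch of that padding step on a toy example; variable names mirror the code above, and the 10-token IDS array is made up:

    import numpy as onp

    IDS = onp.arange(10)                       # stand-in for one tokenized example
    PAD_AMOUNT = 65536 - len(IDS)              # pad up to the fixed sequence length
    pad_start = onp.random.choice(PAD_AMOUNT)  # random split point for the padding
    inputs = onp.pad(IDS, (pad_start, PAD_AMOUNT - pad_start), mode='constant')
    mask = onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                   (pad_start, PAD_AMOUNT - pad_start), mode='constant')
    assert inputs.shape == (65536,) and mask.sum() == len(IDS)
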


def gen_validation_inputs(n_devices):
# different validation each time but consistent across the run
ids = next(gen_inputs(n_devices))
while True:
yield ids

def fake_data(n_devices):
data = onp.zeros((n_devices, 65536))
while True:
yield (data, data, data)
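
fake_data is the commit's titular change: for sampling, the Trainer apparently only needs to restore checkpointed weights, so a constant stream of zeros is enough to satisfy the (inputs, targets, weights) triple the trax input pipeline expects, without touching any real data. A quick smoke test of the generator's contract; this snippet is hypothetical and not part of the commit:

    import numpy as onp

    def fake_data(n_devices):
        data = onp.zeros((n_devices, 65536))
        while True:
            yield (data, data, data)

    x, y, w = next(fake_data(8))               # one batch for 8 devices
    assert x.shape == y.shape == w.shape == (8, 65536)
    assert not x.any()                         # all zeros
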

def create_fixed_training_schedule(lr=0.001):
# Yes, it does look unneceserily nested for passing a single float
@@ -199,7 +163,7 @@ def sample():
optimizer=trax.optimizers.Adam,
lr_schedule=schedule,
inputs=trax.supervised.inputs.Inputs(
gen_inputs, gen_validation_inputs),
fake_data, fake_data),
output_dir=output_dir,
has_weights=True)
model_infer = trax.models.ReformerLM(mode='predict')
sample_data/config: 1 change (0 additions, 1 deletion)

This file was deleted.
