Skip to content

Commit

Permalink
Add minor sampling fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tenoke committed Mar 18, 2020
1 parent f097b2d commit 9742ffc
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def learning_rate(step):
return FixedTrainingSchedule


def sample_single(length=2048, prompt=None, temperature=1.0, top_k=None, exp2=None,
def sample_single(model_weights, infer_state, jit_model_infer, length=2048, prompt=None, temperature=1.0, top_k=None, exp2=None,
boost_top=None, meena_max=False, meena_combine=False):
"""Sample from the ReformerLM model
example top_k = 32
Expand Down Expand Up @@ -137,7 +137,7 @@ def sample_single(length=2048, prompt=None, temperature=1.0, top_k=None, exp2=No
cur_inputs = np.array(cur_samples[:, None, None])
all_samples = onp.stack(all_samples, -1)
for ids in all_samples:
print(tokenizer.decode(ids.tolist()))
print(tokenizer.decode(ids.tolist()).replace('\\n', '\n'))
print('_____________________')
if meena_combine or meena_max:
# all samples are the same for those options
Expand Down Expand Up @@ -172,8 +172,9 @@ def sample(args):
model_weights = trainer._opt_state[0][0]
del trainer
del model_infer
sample_single(length=args.length, prompt=args.prompt, temperature=args.temperature,
top_k=args.top_k, exp2=args.exp2, meena_combine=args.meena_combine, meena_max=args.meena_max)
sample_single(model_weights=model_weights, infer_state=infer_state, jit_model_infer=jit_model_infer, length=args.length, prompt=args.prompt, temperature=args.temperature,
top_k=args.top_k, exp2=args.exp2, meena_combine=args.meena_combine,
meena_max=args.meena_max)


def main_sample(args):
Expand Down

0 comments on commit 9742ffc

Please sign in to comment.