From 9742ffcd0dd1ccc96de0cb29e0195e1946473c24 Mon Sep 17 00:00:00 2001 From: Svilen Todorov Date: Wed, 18 Mar 2020 17:05:58 +0100 Subject: [PATCH] Add minor sampling fixes --- sample.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sample.py b/sample.py index b89e812..893cb8d 100644 --- a/sample.py +++ b/sample.py @@ -58,7 +58,7 @@ def learning_rate(step): return FixedTrainingSchedule -def sample_single(length=2048, prompt=None, temperature=1.0, top_k=None, exp2=None, +def sample_single(model_weights, infer_state, jit_model_infer, length=2048, prompt=None, temperature=1.0, top_k=None, exp2=None, boost_top=None, meena_max=False, meena_combine=False): """Sample from the ReformerLM model example top_k = 32 @@ -137,7 +137,7 @@ def sample_single(length=2048, prompt=None, temperature=1.0, top_k=None, exp2=No cur_inputs = np.array(cur_samples[:, None, None]) all_samples = onp.stack(all_samples, -1) for ids in all_samples: - print(tokenizer.decode(ids.tolist())) + print(tokenizer.decode(ids.tolist()).replace('\\n', '\n')) print('_____________________') if meena_combine or meena_max: # all samples are the same for those options @@ -172,8 +172,9 @@ def sample(args): model_weights = trainer._opt_state[0][0] del trainer del model_infer - sample_single(length=args.length, prompt=args.prompt, temperature=args.temperature, - top_k=args.top_k, exp2=args.exp2, meena_combine=args.meena_combine, meena_max=args.meena_max) + sample_single(model_weights=model_weights, infer_state=infer_state, jit_model_infer=jit_model_infer, length=args.length, prompt=args.prompt, temperature=args.temperature, + top_k=args.top_k, exp2=args.exp2, meena_combine=args.meena_combine, + meena_max=args.meena_max) def main_sample(args):