diff --git a/reformer.py b/reformer.py index a9d088b..cdbcd03 100644 --- a/reformer.py +++ b/reformer.py @@ -83,7 +83,7 @@ def gen_validation_inputs(n_devices): # different validation each time but consistent across the run ids = next(gen_inputs(n_devices)) while True: - return ids + yield ids def create_fixed_training_schedule(lr): diff --git a/requirements-colab.txt b/requirements-colab.txt index 63a2536..188b626 100644 --- a/requirements-colab.txt +++ b/requirements-colab.txt @@ -1,4 +1,5 @@ jax jaxlib git+https://github.com/google/trax.git@v1.2.2 -tokenizers \ No newline at end of file +tokenizers +gin \ No newline at end of file