
Add code for running the Eval Harness in t5x #10

Merged: 2 commits merged on Dec 6, 2021
62 changes: 62 additions & 0 deletions bigscience/gins/eval_harness.gin
@@ -0,0 +1,62 @@
# Defaults for eval_harness.py.
#
# invoke like:
#
# python3 ${T5X_DIR}/t5x/eval_harness.py \
#   --gin_file="t5x/examples/t5/t5_1_1/small.gin" \
#   --gin_file="t5x/bigscience/gins/eval_harness.gin" \
# --gin.INFER_OUTPUT_DIR="'.'"\
# --gin.DROPOUT_RATE=0.0 \
# --gin.CHECKPOINT_PATH="'gs://t5-data/pretrained_models/t5.1.1.lm100k.small/model.ckpt-1100000'"\
# --results_path /home/user/base_test.json


from __gin__ import dynamic_registration

import __main__ as infer_script
from t5x import partitioning
from t5x import utils
from t5x import models


#include %MODEL_GIN

# DEPRECATED: Import this module in your gin file.
MIXTURE_OR_TASK_MODULE = None

infer_script.infer:
  model = %MODEL  # imported from separate gin file
  output_dir = %INFER_OUTPUT_DIR
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.ModelBasedPjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  checkpoint_period = 100
  shard_id = 0
  num_shards = 1


infer_script.create_task_from_tuples:
  vocab = %VOCABULARY

partitioning.ModelBasedPjitPartitioner:
  num_partitions = 4
  model_parallel_submesh = (2, 1, 1, 1)

TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
Member:
I'm confused by this at inference: how do you make a sample fit inside this?

Collaborator Author:
Not quite sure what you mean, but fit as in fit in memory?

In that case, I didn't play with it too much, but since we're not storing grads and the batch size is small, everything seems to work out fine even for the XXL. I just reduced it since we can't partition the small one 4 times.

Collaborator Author:

If you mean the length of the features, I should probably find a way to make sure it's never truncated. The tasks I've looked at in the EH are quite short, though, so it didn't seem to be an issue, but I should probably add an assert. It will be more of an issue if we look at few-shot instead of zero-shot.

Member:

Ah okay, I see, it automatically pads to those sequence lengths, right? Concerning the truncation problem ... that's a good problem. We tried tracking the length of each task in this google sheet (shared internally), and it seems to be okay-ish to truncate (most samples will fit); RACE might be problematic though.
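
For reference, a minimal sketch of the kind of length assert discussed above, assuming a seqio SentencePieceVocabulary and (input, target) string pairs; the function name and vocab path are illustrative, not part of this PR:

```python
# Sketch: encode every (input, target) pair with the model vocabulary and fail
# loudly if anything would be truncated to TASK_FEATURE_LENGTHS.
import seqio

TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}

# Example vocabulary; in the gin file the real one is bound via %VOCABULARY.
vocab = seqio.SentencePieceVocabulary(
    "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model")

def assert_examples_fit(examples):
    """examples: iterable of (input_text, target_text) string pairs."""
    for inp, tgt in examples:
        n_in = len(vocab.encode(inp))
        n_tgt = len(vocab.encode(tgt))
        assert n_in <= TASK_FEATURE_LENGTHS["inputs"], (
            f"input has {n_in} tokens, limit is {TASK_FEATURE_LENGTHS['inputs']}")
        assert n_tgt <= TASK_FEATURE_LENGTHS["targets"], (
            f"target has {n_tgt} tokens, limit is {TASK_FEATURE_LENGTHS['targets']}")
```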


utils.DatasetConfig:
  batch_size = 16
  use_cached = True
  pack = True
  use_custom_packing_ops = False
  seed = 42
  shuffle = False
  split = 'infer'
  module = None
  mixture_or_task_name = None
  task_feature_lengths = %TASK_FEATURE_LENGTHS

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'bfloat16'
Member:

I'm saving them in float32; I don't know if it matters when you load a float32 checkpoint in bfloat16. I have some earlier checkpoints, so if you could run inference on them that'd be awesome!

Collaborator Author (@DanielHesslow, Nov 25, 2021):

Good point, but I don't think it should be an issue: since the training is in bfloat16, inference should work as well. I'll check and see if it makes any difference though.

Sure, send me a path and I can test it.
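
If it helps with that check, here is a rough sketch of what dtype='bfloat16' amounts to for a float32 checkpoint, plus one way to measure how much a forward pass drifts. This is not the t5x restore path; apply_fn, params_f32 and batch are hypothetical stand-ins for the real model objects:

```python
# Sketch: cast a float32 parameter pytree to bfloat16 and compare logits.
import jax
import jax.numpy as jnp

def cast_to_bfloat16(params):
    """Cast every float32 leaf of a parameter pytree down to bfloat16."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x,
        params)

def max_logit_drift(apply_fn, params_f32, batch):
    """Largest absolute logit difference between float32 and bfloat16 params."""
    logits_f32 = apply_fn(params_f32, batch)
    logits_bf16 = apply_fn(cast_to_bfloat16(params_f32), batch)
    return jnp.max(jnp.abs(logits_f32 - logits_bf16.astype(jnp.float32)))
```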
