From 64271d2662a521cea919191b7b355af60009b49a Mon Sep 17 00:00:00 2001 From: Daniel Hesslow Date: Wed, 24 Nov 2021 16:39:50 +0100 Subject: [PATCH 1/2] Add code for running the Eval Harness in t5x --- bigscience/gins/eval_harness.gin | 63 ++++++ t5x/eval_harness.py | 349 +++++++++++++++++++++++++++++++ 2 files changed, 412 insertions(+) create mode 100644 bigscience/gins/eval_harness.gin create mode 100644 t5x/eval_harness.py diff --git a/bigscience/gins/eval_harness.gin b/bigscience/gins/eval_harness.gin new file mode 100644 index 000000000..4e85ba183 --- /dev/null +++ b/bigscience/gins/eval_harness.gin @@ -0,0 +1,63 @@ +# Defaults for eval_harness.py. +# +# invoke like: +# +# python3 ${T5X_DIR}/t5x/eval_harness.py \ +# --gin_file_="t5x/examples/t5/t5_1_1/small.gin"\ +# --gin_file_="t5x/bigscience/gins/eval_harness.gin" \ +# --gin.INFER_OUTPUT_DIR="'.'"\ +# --gin.DROPOUT_RATE=0.0 \ +# --gin.CHECKPOINT_PATH="'gs://t5-data/pretrained_models/t5.1.1.lm100k.small/model.ckpt-1100000'"\ +# --results_path /home/user/base_test.json + + +from __gin__ import dynamic_registration + +import __main__ as infer_script +from t5x import partitioning +from t5x import utils +from t5x import models + + +#include %MODEL_GIN + +# DEPRECATED: Import this module in your gin file. 
+MIXTURE_OR_TASK_MODULE = None + +infer_script.infer: + model = %MODEL # imported from separate gin file + output_dir = %INFER_OUTPUT_DIR + dataset_cfg = @utils.DatasetConfig() + partitioner = @partitioning.ModelBasedPjitPartitioner() + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + checkpoint_period = 100 + shard_id = 0 + num_shards = 1 + + +infer_script.create_task_from_tuples: + vocab = %VOCABULARY + +partitioning.ModelBasedPjitPartitioner: + num_partitions = 4 + model_parallel_submesh = (2,1,1,1) + +TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114} + +utils.DatasetConfig: + batch_size = 16 + use_cached = True + pack = True + use_custom_packing_ops = False + seed = 42 + shuffle = False + use_cached = True + split = 'infer' + module = None + mixture_or_task_name = None + task_feature_lengths = %TASK_FEATURE_LENGTHS + +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' + dtype = 'bfloat16' diff --git a/t5x/eval_harness.py b/t5x/eval_harness.py new file mode 100644 index 000000000..960394017 --- /dev/null +++ b/t5x/eval_harness.py @@ -0,0 +1,349 @@ +# Copyright 2021 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=line-too-long +# pyformat: disable +r"""This script runs inference on a T5X-compatible model. 
+ +""" +# pyformat: enable +# pylint:enable=line-too-long +from functools import partial +import concurrent.futures +import functools +import hashlib +import json +import os +import re +import shutil +import time +from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple +import task + +from absl import logging +import jax +import jax.numpy as jnp +import seqio +from t5x import models +from t5x import multihost_utils +from t5x import partitioning +from t5x import utils +from t5x.infer import create_task_from_tfexample_file +import tensorflow as tf +from tensorflow.io import gfile + +from lm_eval.base import LM +import numpy as np +from lm_eval import evaluator, tasks +from models import cross_entropy_with_logits +from flax.training import common_utils + + +# Automatically search for gin files relative to the T5X package. +_DEFAULT_GIN_SEARCH_PATHS = [ + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +] + +AUTOTUNE = tf.data.experimental.AUTOTUNE + +def create_task_from_tuples(data, vocab): + tfrecord_writer = tf.io.TFRecordWriter("data.tfrecord") + def _bytes_feature(value): + """Returns a bytes_list from a string / byte.""" + if isinstance(value, type(tf.constant(0))): + value = value.numpy() # BytesList won't unpack a string from an EagerTensor. 
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + for p in data: + input, target = p + input, target = input.encode('utf-8'), target.encode('utf-8') + example = tf.train.Example(features=tf.train.Features(feature={ + 'input': _bytes_feature(input), + 'target': _bytes_feature(target), + })) + tfrecord_writer.write(example.SerializeToString()) + + tfrecord_writer.close() + + features = {'inputs': seqio.Feature(vocabulary=vocab, add_eos = False), 'targets': seqio.Feature(vocabulary=vocab, add_eos = False)} + task_name = create_task_from_tfexample_file(['data.tfrecord'], 'tfrecord', 'input', 'target', features) + return task_name + +def infer(*, + mode: str, + model: models.BaseTransformerModel, + dataset_cfg: utils.DatasetConfig, + restore_checkpoint_cfg: utils.RestoreCheckpointConfig, + partitioner: partitioning.BasePartitioner, + output_dir: str, + checkpoint_period: int, + task_name : str, + shard_id: int = 0, + num_shards: int = 1, + run_xprof: bool = True, + merge_epoch_results: bool = True): + """Function to run the inference and return the results as is. Slightly simpler version than the one in infer.py + + Args: + mode: Either 'predict' to decode targets, 'score' to compute the log + likelihood of given targets, or 'predict_with_aux' for both. + model: The model object to use for inference. + dataset_cfg: Specification for the dataset to infer based on. + restore_checkpoint_cfg: Specification for the model parameter checkpoint to + load. + partitioner: Partitioner for model parameters and data across devices. + output_dir: Path to directory to write temporary files and final results. + checkpoint_period: The intermediate results and dataset iterator will be + checkpointed on each multiple of this number of batches to enable + continuation after a failure. + shard_id: Index of dataset shard for this instance to use if splitting the + work across multiple jobs. + num_shards: Total number of dataset shards to split dataset across. 
+ run_xprof: Whether to take an xprof snapshot during run. + merge_epoch_results: Whether to merge results of all epochs into a single + json file. + write_fn: Callable function used to serialized and write inferences out to + files. + """ + if mode not in ('predict', 'score', 'predict_with_aux', 'score_with_correct'): + raise ValueError( + "`mode` must be one of 'predict', 'score' or 'predict_with_aux'. " + f"Got '{mode}'") + + # Remove double-slashes in directory path to avoid inconsistencies. + output_dir = re.sub(r'(? 1: + raise app.UsageError('Too many command-line arguments.') + + if FLAGS.tfds_data_dir: + seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) + + # Create gin-configurable version of `infer`. + infer_using_gin = gin.configurable(infer) + + create_task_from_tuples_gin = gin.configurable(create_task_from_tuples) + gin_utils.parse_gin_flags( + # User-provided gin paths take precedence if relative paths conflict. + FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, + FLAGS.gin_file_, + FLAGS.gin_bindings) + + print(FLAGS.results_path) + + eval_task(FLAGS.results_path, create_task_from_tuples_gin, infer_using_gin) + + gin_utils.run(main) From 3caa0a0ed2b6b2f8da9c2649cbc2c25a0c45d15e Mon Sep 17 00:00:00 2001 From: Daniel Hesslow Date: Thu, 25 Nov 2021 15:19:48 +0100 Subject: [PATCH 2/2] Remove duplicated use_cached --- bigscience/gins/eval_harness.gin | 1 - 1 file changed, 1 deletion(-) diff --git a/bigscience/gins/eval_harness.gin b/bigscience/gins/eval_harness.gin index 4e85ba183..0bd5c42d2 100644 --- a/bigscience/gins/eval_harness.gin +++ b/bigscience/gins/eval_harness.gin @@ -51,7 +51,6 @@ utils.DatasetConfig: use_custom_packing_ops = False seed = 42 shuffle = False - use_cached = True split = 'infer' module = None mixture_or_task_name = None