From 6fb3245f589a29734e4136155f14b07409810ff3 Mon Sep 17 00:00:00 2001 From: Zachary Nado Date: Mon, 9 Aug 2021 08:37:31 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 389640780 --- .travis.yml | 4 +- baselines/jft/batchensemble.py | 19 +++++++- baselines/jft/batchensemble_utils.py | 1 + baselines/jft/deterministic.py | 10 ++++ baselines/jft/deterministic_test.py | 1 - baselines/jft/experiments/common_fewshot.py | 48 +++++++++++++++++++ .../jft/experiments/jft300m_vit_base16.py | 2 +- baselines/jft/heteroscedastic.py | 10 ++++ baselines/jft/sngp.py | 10 ++++ baselines/jft/sngp_test.py | 1 - uncertainty_baselines/__init__.py | 10 +--- 11 files changed, 102 insertions(+), 14 deletions(-) create mode 100644 baselines/jft/experiments/common_fewshot.py diff --git a/.travis.yml b/.travis.yml index 77d212bb3..fcb9aec55 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ stages: jobs: include: - stage: lint - python: "3.6" + python: "3.7" script: - set -v # print commands as they're executed - set -e # fail and exit on any command erroring @@ -16,7 +16,7 @@ jobs: - pylint --jobs=2 --rcfile=pylintrc *.py - pylint --jobs=2 --rcfile=pylintrc */ python: - - "3.6" + - "3.7" install: - set -v # print commands as they're executed - set -e # fail and exit on any command erroring diff --git a/baselines/jft/batchensemble.py b/baselines/jft/batchensemble.py index c397bba24..daf3add14 100644 --- a/baselines/jft/batchensemble.py +++ b/baselines/jft/batchensemble.py @@ -38,6 +38,23 @@ import batchensemble_utils # local file import # TODO(dusenberrymw): Open-source remaining imports. +u = None +ensemble = None +default_input_pipeline = None +jft_latest_pipeline = None +metric_writers = None +partitioning = None +train = None +experts_utils = None +xprof = None +core = None +metrics = None +ema = None +pp_builder = None +config_flags = None +xm = None +xm_api = None +BIG_VISION_DIR = None config_flags.DEFINE_config_file( @@ -58,7 +75,7 @@ def restore_model_and_put_to_devices( config: ml_collections.ConfigDict, output_dir: str, - partition_specs: Sequence[PartitionSpec], + partition_specs: Sequence[partitioning.PartitionSpec], model: flax.nn.Module, optimizer: flax.optim.Optimizer, train_iter: Iterable[Any], diff --git a/baselines/jft/batchensemble_utils.py b/baselines/jft/batchensemble_utils.py index 487b45188..3c80e1bc5 100644 --- a/baselines/jft/batchensemble_utils.py +++ b/baselines/jft/batchensemble_utils.py @@ -27,6 +27,7 @@ import jax.numpy as jnp # TODO(dusenberrymw): Open-source remaining imports. +core = None EvaluationOutput = Tuple[jnp.ndarray, ...] diff --git a/baselines/jft/deterministic.py b/baselines/jft/deterministic.py index 4c7ecb808..661b19de9 100644 --- a/baselines/jft/deterministic.py +++ b/baselines/jft/deterministic.py @@ -34,6 +34,13 @@ from tensorflow.io import gfile import uncertainty_baselines as ub +fewshot = None +input_pipeline = None +resformer = None +u = None +pp_builder = None +xm = None +xm_api = None # TODO(dusenberrymw): Open-source remaining imports. @@ -77,11 +84,14 @@ def main(argv): # tf.data pipeline not being deterministic even if we would set TF seed. rng = jax.random.PRNGKey(config.get('seed', 0)) + xm_xp = None + xm_wu = None def write_note(note): if jax.host_id() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') + fillin = lambda *_: None # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' diff --git a/baselines/jft/deterministic_test.py b/baselines/jft/deterministic_test.py index 62207216d..14419ee9b 100644 --- a/baselines/jft/deterministic_test.py +++ b/baselines/jft/deterministic_test.py @@ -16,7 +16,6 @@ """Tests for the deterministic ViT on JFT-300M model script.""" import os import pathlib -import shutil import tempfile from absl import flags diff --git a/baselines/jft/experiments/common_fewshot.py b/baselines/jft/experiments/common_fewshot.py new file mode 100644 index 000000000..e397bcc05 --- /dev/null +++ b/baselines/jft/experiments/common_fewshot.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# Copyright 2021 The Uncertainty Baselines Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Most common few-shot eval configuration.""" + +import ml_collections + + +def get_fewshot(batch_size=None, target_resolution=224, resize_resolution=256, + runlocal=False): + """Returns a standard-ish fewshot eval configuration.""" + config = ml_collections.ConfigDict() + if batch_size: + config.batch_size = batch_size + config.representation_layer = 'pre_logits' + config.log_steps = 25_000 + config.datasets = { # pylint: disable=g-long-ternary + 'birds': ('caltech_birds2011', 'train', 'test'), + 'caltech': ('caltech101', 'train', 'test'), + 'cars': ('cars196:2.1.0', 'train', 'test'), + 'cifar100': ('cifar100', 'train', 'test'), + 'col_hist': ('colorectal_histology', 'train[:2000]', 'train[2000:]'), + 'dtd': ('dtd', 'train', 'test'), + 'imagenet': ('imagenet2012_subset/10pct', 'train', 'validation'), + 'pets': ('oxford_iiit_pet', 'train', 'test'), + 'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'), + } if not runlocal else { + 'pets': ('oxford_iiit_pet', 'train', 'test'), + } + config.pp_train = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)' + config.pp_eval = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)' + config.shots = [1, 5, 10, 25] + config.l2_regs = [2.0 ** i for i in range(-10, 20)] + config.walk_first = ('imagenet', 10) if not runlocal else ('pets', 10) + + return config diff --git a/baselines/jft/experiments/jft300m_vit_base16.py b/baselines/jft/experiments/jft300m_vit_base16.py index a944a2603..108071743 100644 --- a/baselines/jft/experiments/jft300m_vit_base16.py +++ b/baselines/jft/experiments/jft300m_vit_base16.py @@ -20,7 +20,7 @@ # pylint: enable=line-too-long import ml_collections -# TODO(dusenberrymw): Open-source remaining imports. +import get_fewshot # local file import def get_config(): diff --git a/baselines/jft/heteroscedastic.py b/baselines/jft/heteroscedastic.py index f6b29a7c1..0b62fb510 100644 --- a/baselines/jft/heteroscedastic.py +++ b/baselines/jft/heteroscedastic.py @@ -35,6 +35,13 @@ import uncertainty_baselines as ub # TODO(dusenberrymw): Open-source remaining imports. +fewshot = None +input_pipeline = None +resformer = None +u = None +pp_builder = None +xm = None +xm_api = None ml_collections.config_flags.DEFINE_config_file( @@ -77,11 +84,14 @@ def main(argv): # tf.data pipeline not being deterministic even if we would set TF seed. rng = jax.random.PRNGKey(config.get('seed', 0)) + xm_xp = None + xm_wu = None def write_note(note): if jax.host_id() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') + fillin = lambda *_: None # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' diff --git a/baselines/jft/sngp.py b/baselines/jft/sngp.py index 6636daea6..32a588cc8 100644 --- a/baselines/jft/sngp.py +++ b/baselines/jft/sngp.py @@ -36,6 +36,13 @@ import uncertainty_baselines as ub # TODO(dusenberrymw): Open-source remaining imports. +fewshot = None +input_pipeline = None +resformer = None +u = None +pp_builder = None +xm = None +xm_api = None ml_collections.config_flags.DEFINE_config_file( @@ -143,11 +150,14 @@ def main(argv): # tf.data pipeline not being deterministic even if we would set TF seed. rng = jax.random.PRNGKey(config.get('seed', 0)) + xm_xp = None + xm_wu = None def write_note(note): if jax.host_id() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') + fillin = lambda *_: None # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' diff --git a/baselines/jft/sngp_test.py b/baselines/jft/sngp_test.py index 45eebe3a1..016c05683 100644 --- a/baselines/jft/sngp_test.py +++ b/baselines/jft/sngp_test.py @@ -16,7 +16,6 @@ """Tests for the ViT-SNGP on JFT-300M model script.""" import os import pathlib -import shutil import tempfile from absl import flags diff --git a/uncertainty_baselines/__init__.py b/uncertainty_baselines/__init__.py index 64d767ba3..042366dc6 100644 --- a/uncertainty_baselines/__init__.py +++ b/uncertainty_baselines/__init__.py @@ -48,11 +48,5 @@ def _lazy_import(name): return imported -for module_name in _IMPORTS: - try: - _lazy_import(module_name) - except ModuleNotFoundError: - logging.error( - 'Skipped importing top level uncertainty_baselines module %s due to ' - 'ModuleNotFoundError:', module_name, exc_info=True) - +# Lazily load any top level modules when accessed. Requires Python 3.7. +__getattr__ = _lazy_import