diff --git a/baselines/imagenet/batchensemble.py b/baselines/imagenet/batchensemble.py
index bc8d213b..14d35b8c 100644
--- a/baselines/imagenet/batchensemble.py
+++ b/baselines/imagenet/batchensemble.py
@@ -25,7 +25,8 @@
 import edward2 as ed
 import batchensemble_model  # local file import
 import utils  # local file import
-import tensorflow as tf
+from edward2.google.rank1_pert.ensemble_keras import utils as be_utils
+import tensorflow.compat.v2 as tf
 
 flags.DEFINE_integer('ensemble_size', 4, 'Size of ensemble.')
 flags.DEFINE_integer('per_core_batch_size', 128, 'Batch size per TPU core/GPU.')
@@ -39,16 +40,18 @@
                    'fast weights lr multiplier.')
 flags.DEFINE_string('data_dir', None, 'Path to training and testing data.')
 flags.mark_flag_as_required('data_dir')
-flags.DEFINE_string('output_dir', '/tmp/imagenet',
-                    'The directory where the model weights and '
-                    'training/evaluation summaries are stored.')
+flags.DEFINE_string(
+    'output_dir', '/tmp/imagenet', 'The directory where the model weights and '
+    'training/evaluation summaries are stored.')
 flags.DEFINE_integer('train_epochs', 135, 'Number of training epochs.')
-flags.DEFINE_integer('corruptions_interval', 135,
-                     'Number of epochs between evaluating on the corrupted '
-                     'test data. Use -1 to never evaluate.')
-flags.DEFINE_integer('checkpoint_interval', 27,
-                     'Number of epochs between saving checkpoints. Use -1 to '
-                     'never save checkpoints.')
+flags.DEFINE_integer(
+    'corruptions_interval', 135,
+    'Number of epochs between evaluating on the corrupted '
+    'test data. Use -1 to never evaluate.')
+flags.DEFINE_integer(
+    'checkpoint_interval', 27,
+    'Number of epochs between saving checkpoints. Use -1 to '
+    'never save checkpoints.')
 flags.DEFINE_string('alexnet_errors_path', None,
                     'Path to AlexNet corruption errors file.')
 flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE computation.')
@@ -60,6 +63,22 @@
 flags.DEFINE_integer('num_cores', 32, 'Number of TPU cores or number of GPUs.')
 flags.DEFINE_string('tpu', None,
                     'Name of the TPU. Only used if use_gpu is False.')
+flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
+                    '[cosine, dpp_logdet].')
+flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant.')
+flags.DEFINE_bool('use_output_similarity', False,
+                  'If true, compute similarity on the ensemble outputs.')
+flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
+                  ['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
+                  'Diversity coefficient scheduler.')
+flags.DEFINE_float('annealing_epochs', 200,
+                   'Number of epochs over which to linearly anneal.')
+flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
+flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
+flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
+flags.DEFINE_integer('diversity_start_epoch', 100,
+                     'Diversity loss starting epoch.')
+
 FLAGS = flags.FLAGS
 
 # Number of images in ImageNet-1k train dataset.
@@ -68,7 +87,7 @@
 IMAGENET_VALIDATION_IMAGES = 50000
 NUM_CLASSES = 1000
 
-_LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+_LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]
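A note on the schedule constants: each _LR_SCHEDULE entry is a (multiplier, start_epoch) pair layered on top of the linearly scaled base rate (base_lr = FLAGS.base_learning_rate * batch_size / 256, computed in main below). Here is a minimal sketch of the piecewise-constant decay the tuples encode; lr_at_epoch is a hypothetical helper for illustration only, and the repo's utils.LearningRateSchedule additionally applies per-step logic such as the early ramp-up that the (1.0, 5) entry suggests:

    def lr_at_epoch(epoch, base_lr, schedule=None):
      # Each (multiplier, start_epoch) entry takes effect once `epoch`
      # reaches start_epoch; later entries override earlier ones.
      lr = base_lr
      for multiplier, start_epoch in (schedule or _LR_SCHEDULE):
        if epoch >= start_epoch:
          lr = base_lr * multiplier
      return lr

    # E.g. epochs 30-59 train at 0.1 * base_lr, epochs 60-79 at
    # 0.01 * base_lr, and epochs 80 onward at 0.001 * base_lr.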
@@ -147,22 +166,52 @@ def main(argv):
     logging.info('Model number of weights: %s', model.count_params())
     # Scale learning rate and decay epochs by vanilla settings.
     base_lr = FLAGS.base_learning_rate * batch_size / 256
-    learning_rate = utils.LearningRateSchedule(steps_per_epoch,
-                                               base_lr,
-                                               FLAGS.train_epochs,
-                                               _LR_SCHEDULE)
-    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
-                                        momentum=0.9,
-                                        nesterov=True)
+    learning_rate = utils.LearningRateSchedule(steps_per_epoch, base_lr,
+                                               FLAGS.train_epochs, _LR_SCHEDULE)
+    optimizer = tf.keras.optimizers.SGD(
+        learning_rate=learning_rate, momentum=0.9, nesterov=True)
+
+    if FLAGS.diversity_scheduler == 'ExponentialDecay':
+      diversity_schedule = be_utils.ExponentialDecay(
+          initial_coeff=FLAGS.diversity_coeff,
+          start_epoch=FLAGS.diversity_start_epoch,
+          decay_epoch=FLAGS.diversity_decay_epoch,
+          steps_per_epoch=steps_per_epoch,
+          decay_rate=FLAGS.diversity_decay_rate,
+          staircase=True)
+    elif FLAGS.diversity_scheduler == 'LinearAnnealing':
+      diversity_schedule = be_utils.LinearAnnealing(
+          initial_coeff=FLAGS.diversity_coeff,
+          annealing_epochs=FLAGS.annealing_epochs,
+          steps_per_epoch=steps_per_epoch)
+    else:
+      diversity_schedule = lambda x: FLAGS.diversity_coeff
+
     metrics = {
-        'train/negative_log_likelihood': tf.keras.metrics.Mean(),
-        'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
-        'train/loss': tf.keras.metrics.Mean(),
-        'train/ece': ed.metrics.ExpectedCalibrationError(
-            num_bins=FLAGS.num_bins),
-        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
-        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
-        'test/ece': ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
+        'train/similarity_loss':
+            tf.keras.metrics.Mean(),
+        'train/weights_similarity':
+            tf.keras.metrics.Mean(),
+        'train/outputs_similarity':
+            tf.keras.metrics.Mean(),
+        'train/negative_log_likelihood':
+            tf.keras.metrics.Mean(),
+        'train/accuracy':
+            tf.keras.metrics.SparseCategoricalAccuracy(),
+        'train/loss':
+            tf.keras.metrics.Mean(),
+        'train/ece':
+            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
+        'test/negative_log_likelihood':
+            tf.keras.metrics.Mean(),
+        'test/accuracy':
+            tf.keras.metrics.SparseCategoricalAccuracy(),
+        'test/ece':
+            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
+        'test/weights_similarity':
+            tf.keras.metrics.Mean(),
+        'test/outputs_similarity':
+            tf.keras.metrics.Mean()
     }
     if FLAGS.corruptions_interval > 0:
       corrupt_metrics = {}
@@ -208,6 +258,7 @@ def main(argv):
   @tf.function
   def train_step(iterator):
     """Training StepFn."""
+
     def step_fn(inputs):
       """Per-Replica StepFn."""
       images, labels = inputs
@@ -225,10 +276,19 @@ def step_fn(inputs):
         diversity_results = ed.metrics.average_pairwise_diversity(
             per_probs, FLAGS.ensemble_size)
 
+        similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
+            FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
+            FLAGS.similarity_metric, FLAGS.dpp_kernel,
+            model.trainable_variables, FLAGS.use_output_similarity, per_probs)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric,
+            FLAGS.dpp_kernel)
+        outputs_similarity = be_utils.outputs_similarity(
+            per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)
+
         negative_log_likelihood = tf.reduce_mean(
-            tf.keras.losses.sparse_categorical_crossentropy(labels,
-                                                            logits,
-                                                            from_logits=True))
+            tf.keras.losses.sparse_categorical_crossentropy(
+                labels, logits, from_logits=True))
         filtered_variables = []
         for var in model.trainable_variables:
           # Apply l2 on the slow weights and bias terms. This excludes BN
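The scheduler selected above only controls how the diversity coefficient evolves with the global step; be_utils.ExponentialDecay and be_utils.LinearAnnealing live in edward2's internal ensemble_keras utils and are not shown in this diff. The sketch below is a plausible reading of the ExponentialDecay constructor arguments, not the actual implementation:

    def exponential_decay_coeff(step, initial_coeff, start_epoch, decay_epoch,
                                steps_per_epoch, decay_rate):
      # Hypothetical staircase schedule: hold the coefficient fixed until
      # start_epoch, then multiply it by decay_rate once per decay_epoch
      # epochs. Whether the real class holds or zeroes the coefficient
      # before start_epoch is an assumption this diff does not settle.
      start_step = start_epoch * steps_per_epoch
      if step < start_step:
        return initial_coeff
      periods = (step - start_step) // int(decay_epoch * steps_per_epoch)
      return initial_coeff * decay_rate**periods

With the flag defaults (diversity_start_epoch=100, diversity_decay_epoch=4, diversity_decay_rate=0.97), such a schedule would leave the coefficient untouched for 100 epochs and then shrink it by 3% every 4 epochs.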
@@ -239,7 +300,8 @@ def step_fn(inputs):
         l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
             tf.concat(filtered_variables, axis=0))
-        loss = negative_log_likelihood + l2_loss
+        loss = (negative_log_likelihood + l2_loss +
+                similarity_coeff * similarity_loss)
 
         # Scale the loss given the TPUStrategy will reduce sum all gradients.
         scaled_loss = loss / strategy.num_replicas_in_sync
@@ -252,14 +313,18 @@ def step_fn(inputs):
           # Apply different learning rate on the fast weights. This excludes BN
           # and slow weights, but pay caution to the naming scheme.
           if ('batch_norm' not in var.name and 'kernel' not in var.name):
-            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier,
-                                   var))
+            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier, var))
           else:
             grads_and_vars.append((grad, var))
         optimizer.apply_gradients(grads_and_vars)
       else:
         optimizer.apply_gradients(zip(grads, model.trainable_variables))
 
+      metrics['train/similarity_loss'].update_state(similarity_coeff *
+                                                    similarity_loss)
+      metrics['train/weights_similarity'].update_state(weights_similarity)
+      metrics['train/outputs_similarity'].update_state(outputs_similarity)
+
       metrics['train/ece'].update_state(labels, probs)
       metrics['train/loss'].update_state(loss)
       metrics['train/negative_log_likelihood'].update_state(
@@ -273,6 +338,7 @@ def step_fn(inputs):
   @tf.function
   def test_step(iterator, dataset_name):
     """Evaluation StepFn."""
+
     def step_fn(inputs):
       """Per-Replica StepFn."""
       images, labels = inputs
@@ -287,6 +353,8 @@ def step_fn(inputs):
             probs, tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
         diversity_results = ed.metrics.average_pairwise_diversity(
             per_probs_tensor, FLAGS.ensemble_size)
+        outputs_similarity = be_utils.outputs_similarity(
+            per_probs_tensor, FLAGS.similarity_metric, FLAGS.dpp_kernel)
         for k, v in diversity_results.items():
           test_diversity['test/' + k].update_state(v)
@@ -310,6 +378,11 @@ def step_fn(inputs):
             negative_log_likelihood)
         metrics['test/accuracy'].update_state(labels, probs)
         metrics['test/ece'].update_state(labels, probs)
+        weights_similarity = be_utils.fast_weights_similarity(
+            model.trainable_variables, FLAGS.similarity_metric,
+            FLAGS.dpp_kernel)
+        metrics['test/weights_similarity'].update_state(weights_similarity)
+        metrics['test/outputs_similarity'].update_state(outputs_similarity)
       else:
         corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
             negative_log_likelihood)
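For intuition about the new quantities being logged: with similarity_metric='cosine', a weights similarity of the kind be_utils.fast_weights_similarity returns could be the mean pairwise cosine similarity between ensemble members' fast weights. The snippet below is an illustrative stand-in (the real implementation is not part of this diff); penalizing such a term via similarity_coeff * similarity_loss pushes the members' fast weights apart:

    import tensorflow.compat.v2 as tf

    def mean_pairwise_cosine_similarity(fast_weights):
      # fast_weights: a [ensemble_size, dim] tensor, one row per member
      # (e.g. a BatchEnsemble rank-1 alpha or gamma vector).
      normed = tf.math.l2_normalize(fast_weights, axis=-1)
      sims = tf.matmul(normed, normed, transpose_b=True)  # [k, k] cosines
      k = tf.shape(fast_weights)[0]
      mask = 1.0 - tf.eye(k)  # zero out self-similarity on the diagonal
      denom = tf.cast(k * (k - 1), sims.dtype)
      return tf.reduce_sum(sims * mask) / denom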
@@ -334,12 +407,8 @@ def step_fn(inputs):
       eta_seconds = (max_steps - current_step) / steps_per_sec
       message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                  'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
-                     current_step / max_steps,
-                     epoch + 1,
-                     FLAGS.train_epochs,
-                     steps_per_sec,
-                     eta_seconds / 60,
-                     time_elapsed / 60))
+                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
+                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
       if step % 20 == 0:
         logging.info(message)
@@ -352,8 +421,7 @@ def step_fn(inputs):
       logging.info('Testing on dataset %s', dataset_name)
       for step in range(steps_per_eval):
         if step % 20 == 0:
-          logging.info('Starting to run eval step %s of epoch: %s', step,
-                       epoch)
+          logging.info('Starting to run eval step %s of epoch: %s', step, epoch)
         test_step(test_iterator, dataset_name)
       logging.info('Done with testing on %s', dataset_name)
@@ -371,15 +439,16 @@ def step_fn(inputs):
                  metrics['test/negative_log_likelihood'].result(),
                  metrics['test/accuracy'].result() * 100)
     for i in range(FLAGS.ensemble_size):
-      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%',
-                   i, metrics['test/nll_member_{}'.format(i)].result(),
+      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
+                   metrics['test/nll_member_{}'.format(i)].result(),
                    metrics['test/accuracy_member_{}'.format(i)].result() * 100)
     total_metrics = metrics.copy()
     total_metrics.update(training_diversity)
     total_metrics.update(test_diversity)
-    total_results = {name: metric.result()
-                     for name, metric in total_metrics.items()}
+    total_results = {
+        name: metric.result() for name, metric in total_metrics.items()
+    }
     total_results.update(corrupt_results)
     with summary_writer.as_default():
       for name, result in total_results.items():
@@ -390,13 +459,14 @@ def step_fn(inputs):
     if (FLAGS.checkpoint_interval > 0 and
         (epoch + 1) % FLAGS.checkpoint_interval == 0):
-      checkpoint_name = checkpoint.save(os.path.join(
-          FLAGS.output_dir, 'checkpoint'))
+      checkpoint_name = checkpoint.save(
+          os.path.join(FLAGS.output_dir, 'checkpoint'))
       logging.info('Saved checkpoint to %s', checkpoint_name)
 
   final_checkpoint_name = checkpoint.save(
       os.path.join(FLAGS.output_dir, 'checkpoint'))
   logging.info('Saved last checkpoint to %s', final_checkpoint_name)
+
 
 if __name__ == '__main__':
   app.run(main)
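Taken together, the training objective after this change is loss = NLL + L2 + coeff(step) * similarity. Since diversity_coeff defaults to 0., the penalty is inert unless the flag is set explicitly; a toy check with placeholder numbers:

    # Placeholder values, for illustration only.
    negative_log_likelihood = 0.75
    l2_loss = 0.25
    similarity_loss = 0.5
    similarity_coeff = 0.  # the flag default
    loss = (negative_log_likelihood + l2_loss +
            similarity_coeff * similarity_loss)
    assert loss == 1.0  # with the default coefficient, objective is unchanged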