Add diversity regularization to BatchEnsemble on ImageNet. #346

Open · wants to merge 1 commit into base: main
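This PR augments the BatchEnsemble training objective on ImageNet with a scheduled diversity penalty. Schematically, the objective assembled in train_step below is (a sketch, not a verbatim excerpt; names follow the diff):

```python
# similarity_loss measures how alike the ensemble members are, either on
# their fast weights or on their outputs (--use_output_similarity), via
# cosine similarity or a DPP log-determinant; similarity_coeff follows
# the configured scheduler (LinearAnnealing, ExponentialDecay, or Fixed).
loss = negative_log_likelihood + l2_loss + similarity_coeff * similarity_loss
```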
162 changes: 116 additions & 46 deletions baselines/imagenet/batchensemble.py
@@ -25,7 +25,8 @@
import edward2 as ed
import batchensemble_model # local file import
import utils # local file import
import tensorflow as tf
from edward2.google.rank1_pert.ensemble_keras import utils as be_utils
import tensorflow.compat.v2 as tf

flags.DEFINE_integer('ensemble_size', 4, 'Size of ensemble.')
flags.DEFINE_integer('per_core_batch_size', 128, 'Batch size per TPU core/GPU.')
@@ -39,16 +40,18 @@
                   'fast weights lr multiplier.')
flags.DEFINE_string('data_dir', None, 'Path to training and testing data.')
flags.mark_flag_as_required('data_dir')
flags.DEFINE_string('output_dir', '/tmp/imagenet',
                    'The directory where the model weights and '
                    'training/evaluation summaries are stored.')
flags.DEFINE_string(
    'output_dir', '/tmp/imagenet', 'The directory where the model weights and '
    'training/evaluation summaries are stored.')
flags.DEFINE_integer('train_epochs', 135, 'Number of training epochs.')
flags.DEFINE_integer('corruptions_interval', 135,
                     'Number of epochs between evaluating on the corrupted '
                     'test data. Use -1 to never evaluate.')
flags.DEFINE_integer('checkpoint_interval', 27,
                     'Number of epochs between saving checkpoints. Use -1 to '
                     'never save checkpoints.')
flags.DEFINE_integer(
    'corruptions_interval', 135,
    'Number of epochs between evaluating on the corrupted '
    'test data. Use -1 to never evaluate.')
flags.DEFINE_integer(
    'checkpoint_interval', 27,
    'Number of epochs between saving checkpoints. Use -1 to '
    'never save checkpoints.')
flags.DEFINE_string('alexnet_errors_path', None,
                    'Path to AlexNet corruption errors file.')
flags.DEFINE_integer('num_bins', 15, 'Number of bins for ECE computation.')
@@ -60,6 +63,22 @@
flags.DEFINE_integer('num_cores', 32, 'Number of TPU cores or number of GPUs.')
flags.DEFINE_string('tpu', None,
                    'Name of the TPU. Only used if use_gpu is False.')
flags.DEFINE_string('similarity_metric', 'cosine', 'Similarity metric in '
                    '[cosine, dpp_logdet].')
flags.DEFINE_string('dpp_kernel', 'linear', 'Kernel for DPP log determinant.')
flags.DEFINE_bool('use_output_similarity', False,
                  'If true, compute similarity on the ensemble outputs.')
flags.DEFINE_enum('diversity_scheduler', 'LinearAnnealing',
                  ['LinearAnnealing', 'ExponentialDecay', 'Fixed'],
                  'Diversity coefficient scheduler.')
flags.DEFINE_float('annealing_epochs', 200,
                   'Number of epochs over which to linearly anneal.')
flags.DEFINE_float('diversity_coeff', 0., 'Diversity loss coefficient.')
flags.DEFINE_float('diversity_decay_epoch', 4, 'Diversity decay epoch.')
flags.DEFINE_float('diversity_decay_rate', 0.97, 'Rate of exponential decay.')
flags.DEFINE_integer('diversity_start_epoch', 100,
                     'Diversity loss starting epoch.')

FLAGS = flags.FLAGS

# Number of images in ImageNet-1k train dataset.
@@ -68,7 +87,7 @@
IMAGENET_VALIDATION_IMAGES = 50000
NUM_CLASSES = 1000

_LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
_LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
]
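utils.LearningRateSchedule is a local helper whose implementation is outside this diff. A minimal sketch of how such (multiplier, epoch to start) tuples are typically consumed, assuming a piecewise-constant schedule and ignoring any warmup the helper may add:

```python
def lr_at_epoch(epoch, base_lr, schedule=_LR_SCHEDULE):
  """Returns base_lr scaled by the last multiplier whose epoch has started."""
  multiplier = 1.0
  for m, start_epoch in schedule:
    if epoch >= start_epoch:
      multiplier = m
  return base_lr * multiplier

# e.g. lr_at_epoch(45, 1.6) -> 1.6 * 0.1 = 0.16, since the 0.1
# multiplier starts at epoch 30 and the next drop is at epoch 60.
```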

@@ -147,22 +166,53 @@ def main(argv):
  logging.info('Model number of weights: %s', model.count_params())
  # Scale learning rate and decay epochs by vanilla settings.
  base_lr = FLAGS.base_learning_rate * batch_size / 256
  learning_rate = utils.LearningRateSchedule(steps_per_epoch,
                                             base_lr,
                                             FLAGS.train_epochs,
                                             _LR_SCHEDULE)
  optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                      momentum=0.9,
                                      nesterov=True)
  learning_rate = utils.LearningRateSchedule(steps_per_epoch, base_lr,
                                             FLAGS.train_epochs, _LR_SCHEDULE)
  optimizer = tf.keras.optimizers.SGD(
      learning_rate=learning_rate, momentum=0.9, nesterov=True)
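For the flag defaults above (per_core_batch_size=128, num_cores=32), the linear scaling rule works out as follows; the base_learning_rate default is defined outside this hunk, so the 0.1 used here is purely illustrative:

```python
per_core_batch_size = 128
num_cores = 32
batch_size = per_core_batch_size * num_cores  # 4096
base_lr = 0.1 * batch_size / 256              # 1.6, with an assumed 0.1 default
```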

  if FLAGS.diversity_scheduler == 'ExponentialDecay':
    diversity_schedule = be_utils.ExponentialDecay(
        initial_coeff=FLAGS.diversity_coeff,
        start_epoch=FLAGS.diversity_start_epoch,
        decay_epoch=FLAGS.diversity_decay_epoch,
        steps_per_epoch=steps_per_epoch,
        decay_rate=FLAGS.diversity_decay_rate,
        staircase=True)

  elif FLAGS.diversity_scheduler == 'LinearAnnealing':
    diversity_schedule = be_utils.LinearAnnealing(
        initial_coeff=FLAGS.diversity_coeff,
        annealing_epochs=FLAGS.annealing_epochs,
        steps_per_epoch=steps_per_epoch)
  else:
    diversity_schedule = lambda x: FLAGS.diversity_coeff
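be_utils is imported from an internal path (edward2.google.rank1_pert.ensemble_keras), so the scheduler classes are not visible in this PR. A sketch of the semantics their constructor arguments suggest, written as plain step-to-coefficient functions; in particular, whether LinearAnnealing decays the coefficient toward zero or ramps it up is an assumption here:

```python
def linear_annealing(step, initial_coeff, annealing_epochs, steps_per_epoch):
  # Assumed semantics: decay linearly from initial_coeff to 0 over
  # annealing_epochs, then stay at 0.
  frac = min(step / (annealing_epochs * steps_per_epoch), 1.0)
  return initial_coeff * (1.0 - frac)


def exponential_decay(step, initial_coeff, start_epoch, decay_epoch,
                      steps_per_epoch, decay_rate, staircase=True):
  # Assumed semantics: zero before start_epoch, then initial_coeff decayed
  # by decay_rate once per decay_epoch epochs (staircase-style).
  start = start_epoch * steps_per_epoch
  if step < start:
    return 0.0
  exponent = (step - start) / (decay_epoch * steps_per_epoch)
  if staircase:
    exponent = float(int(exponent))
  return initial_coeff * decay_rate**exponent
```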

  metrics = {
      'train/negative_log_likelihood': tf.keras.metrics.Mean(),
      'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'train/loss': tf.keras.metrics.Mean(),
      'train/ece': ed.metrics.ExpectedCalibrationError(
          num_bins=FLAGS.num_bins),
      'test/negative_log_likelihood': tf.keras.metrics.Mean(),
      'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece': ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
      'train/similarity_loss':
          tf.keras.metrics.Mean(),
      'train/weights_similarity':
          tf.keras.metrics.Mean(),
      'train/outputs_similarity':
          tf.keras.metrics.Mean(),
      'train/negative_log_likelihood':
          tf.keras.metrics.Mean(),
      'train/accuracy':
          tf.keras.metrics.SparseCategoricalAccuracy(),
      'train/loss':
          tf.keras.metrics.Mean(),
      'train/ece':
          ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
      'test/negative_log_likelihood':
          tf.keras.metrics.Mean(),
      'test/accuracy':
          tf.keras.metrics.SparseCategoricalAccuracy(),
      'test/ece':
          ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
      'test/weights_similarity':
          tf.keras.metrics.Mean(),
      'test/outputs_similarity':
          tf.keras.metrics.Mean()
  }
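The new similarity entries are plain running means and follow the same update_state/result protocol as the existing metrics. A minimal example with made-up values:

```python
m = tf.keras.metrics.Mean()
m.update_state(0.37)  # e.g. similarity_coeff * similarity_loss at step 1
m.update_state(0.35)  # step 2
assert abs(float(m.result()) - 0.36) < 1e-6  # running mean across steps
```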
  if FLAGS.corruptions_interval > 0:
    corrupt_metrics = {}
@@ -208,6 +258,7 @@ def main(argv):
  @tf.function
  def train_step(iterator):
    """Training StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
@@ -225,10 +276,20 @@ def step_fn(inputs):
        diversity_results = ed.metrics.average_pairwise_diversity(
            per_probs, FLAGS.ensemble_size)

        similarity_coeff, similarity_loss = be_utils.scaled_similarity_loss(
            FLAGS.diversity_coeff, diversity_schedule, optimizer.iterations,
            FLAGS.similarity_metric, FLAGS.dpp_kernel,
            model.trainable_variables, FLAGS.use_output_similarity, per_probs)
        weights_similarity = be_utils.fast_weights_similarity(
            model.trainable_variables, FLAGS.similarity_metric,
            FLAGS.dpp_kernel)
        outputs_similarity = be_utils.outputs_similarity(
            per_probs, FLAGS.similarity_metric, FLAGS.dpp_kernel)

        negative_log_likelihood = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels,
                                                            logits,
                                                            from_logits=True))
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
        filtered_variables = []
        for var in model.trainable_variables:
          # Apply l2 on the slow weights and bias terms. This excludes BN
@@ -239,7 +300,7 @@

        l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
            tf.concat(filtered_variables, axis=0))
        loss = negative_log_likelihood + l2_loss
        loss = (negative_log_likelihood + l2_loss +
                similarity_coeff * similarity_loss)
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / strategy.num_replicas_in_sync
@@ -252,14 +313,18 @@
          # Apply different learning rate on the fast weights. This excludes BN
          # and slow weights, but pay caution to the naming scheme.
          if ('batch_norm' not in var.name and 'kernel' not in var.name):
            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier,
                                   var))
            grads_and_vars.append((grad * FLAGS.fast_weight_lr_multiplier, var))
          else:
            grads_and_vars.append((grad, var))
        optimizer.apply_gradients(grads_and_vars)
      else:
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      metrics['train/similarity_loss'].update_state(similarity_coeff *
                                                    similarity_loss)
      metrics['train/weights_similarity'].update_state(weights_similarity)
      metrics['train/outputs_similarity'].update_state(outputs_similarity)

      metrics['train/ece'].update_state(labels, probs)
      metrics['train/loss'].update_state(loss)
      metrics['train/negative_log_likelihood'].update_state(
@@ -273,6 +338,7 @@ def step_fn(inputs):
  @tf.function
  def test_step(iterator, dataset_name):
    """Evaluation StepFn."""

    def step_fn(inputs):
      """Per-Replica StepFn."""
      images, labels = inputs
@@ -287,6 +353,8 @@ def step_fn(inputs):
            probs, tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
        diversity_results = ed.metrics.average_pairwise_diversity(
            per_probs_tensor, FLAGS.ensemble_size)
        outputs_similarity = be_utils.outputs_similarity(
            per_probs_tensor, FLAGS.similarity_metric, FLAGS.dpp_kernel)
        for k, v in diversity_results.items():
          test_diversity['test/' + k].update_state(v)
@@ -310,6 +378,11 @@ def step_fn(inputs):
            negative_log_likelihood)
        metrics['test/accuracy'].update_state(labels, probs)
        metrics['test/ece'].update_state(labels, probs)
        weights_similarity = be_utils.fast_weights_similarity(
            model.trainable_variables, FLAGS.similarity_metric,
            FLAGS.dpp_kernel)
        metrics['test/weights_similarity'].update_state(weights_similarity)
        metrics['test/outputs_similarity'].update_state(outputs_similarity)
      else:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
@@ -334,12 +407,8 @@ def step_fn(inputs):
      eta_seconds = (max_steps - current_step) / steps_per_sec
      message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                 'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                     current_step / max_steps,
                     epoch + 1,
                     FLAGS.train_epochs,
                     steps_per_sec,
                     eta_seconds / 60,
                     time_elapsed / 60))
                     current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                     steps_per_sec, eta_seconds / 60, time_elapsed / 60))
      if step % 20 == 0:
        logging.info(message)

@@ -352,8 +421,7 @@ def step_fn(inputs):
      logging.info('Testing on dataset %s', dataset_name)
      for step in range(steps_per_eval):
        if step % 20 == 0:
          logging.info('Starting to run eval step %s of epoch: %s', step,
                       epoch)
          logging.info('Starting to run eval step %s of epoch: %s', step, epoch)
        test_step(test_iterator, dataset_name)
      logging.info('Done with testing on %s', dataset_name)

@@ -371,15 +439,16 @@ def step_fn(inputs):
                 metrics['test/negative_log_likelihood'].result(),
                 metrics['test/accuracy'].result() * 100)
    for i in range(FLAGS.ensemble_size):
      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%',
                   i, metrics['test/nll_member_{}'.format(i)].result(),
      logging.info('Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                   metrics['test/nll_member_{}'.format(i)].result(),
                   metrics['test/accuracy_member_{}'.format(i)].result() * 100)

    total_metrics = metrics.copy()
    total_metrics.update(training_diversity)
    total_metrics.update(test_diversity)
    total_results = {name: metric.result()
                     for name, metric in total_metrics.items()}
    total_results = {
        name: metric.result() for name, metric in total_metrics.items()
    }
    total_results.update(corrupt_results)
    with summary_writer.as_default():
      for name, result in total_results.items():
@@ -390,13 +459,14 @@ def step_fn(inputs):

    if (FLAGS.checkpoint_interval > 0 and
        (epoch + 1) % FLAGS.checkpoint_interval == 0):
      checkpoint_name = checkpoint.save(os.path.join(
          FLAGS.output_dir, 'checkpoint'))
      checkpoint_name = checkpoint.save(
          os.path.join(FLAGS.output_dir, 'checkpoint'))
      logging.info('Saved checkpoint to %s', checkpoint_name)

  final_checkpoint_name = checkpoint.save(
      os.path.join(FLAGS.output_dir, 'checkpoint'))
  logging.info('Saved last checkpoint to %s', final_checkpoint_name)


if __name__ == '__main__':
  app.run(main)
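A hypothetical invocation exercising the new flags; the data path is a placeholder and the coefficient value is purely illustrative:

```
python batchensemble.py \
    --data_dir=/path/to/imagenet \
    --output_dir=/tmp/imagenet \
    --similarity_metric=cosine \
    --diversity_scheduler=LinearAnnealing \
    --diversity_coeff=0.1 \
    --annealing_epochs=200
```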