Removes unnecessary ViT-GP hyper-parameters.
PiperOrigin-RevId: 388484029
jereliu authored and edward-bot committed Aug 9, 2021
1 parent 194a984 commit 7438002
Showing 4 changed files with 158 additions and 6 deletions.
12 changes: 10 additions & 2 deletions edward2/jax/nn/random_feature.py
@@ -27,6 +27,9 @@
[3]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
Machines. In _Neural Information Processing Systems_, 2007.
https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
[4]: Zhiyun Lu, Eugene Ie, Fei Sha. Uncertainty Estimation with Infinitesimal
Jackknife. _arXiv preprint arXiv:2006.07584_, 2020.
https://arxiv.org/abs/2006.07584
"""
import dataclasses
import functools
@@ -47,8 +50,13 @@

# Default config for random features.
default_rbf_activation = jnp.cos
-default_rbf_kernel_init = nn.initializers.normal(stddev=1.)
default_rbf_bias_init = nn.initializers.uniform(scale=2. * jnp.pi)
+# Using "he_normal" style random feature distribution. Effectively, this is
+# equivalent to approximating a RBF kernel but with the input standardized by
+# its dimensionality (i.e., input_scaled = input * sqrt(2. / dim_input)) and
+# empirically leads to better performance for neural network inputs.
+default_rbf_kernel_init = nn.initializers.variance_scaling(
+    scale=2.0, mode='fan_in', distribution='normal')

# Default field value for kwargs, to be used for data class declaration.
default_kwarg_dict = lambda: dataclasses.field(default_factory=dict)
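The equivalence described in the comment on the new default_rbf_kernel_init can be checked directly: with scale=2.0 and mode='fan_in', variance_scaling draws kernel entries with standard deviation sqrt(2. / dim_input), which matches applying a standard-normal RFF kernel to an input rescaled by sqrt(2. / dim_input). Below is a minimal sketch, not part of this commit; the shapes are arbitrary, and both initializers reuse the same PRNG key so that the two draws differ only by that scale factor.

import jax
import jax.numpy as jnp
import flax.linen as nn

dim_input, num_features = 8, 4096
key = jax.random.PRNGKey(0)

# "he_normal"-style init: entries ~ N(0, 2 / dim_input).
kernel_he = nn.initializers.variance_scaling(
    scale=2.0, mode='fan_in', distribution='normal')(key, (dim_input, num_features))
# Classic RBF init: entries ~ N(0, 1).
kernel_rbf = nn.initializers.normal(stddev=1.)(key, (dim_input, num_features))

x = jax.random.normal(jax.random.PRNGKey(1), (dim_input,))
# Projecting x with the rescaled kernel matches projecting the rescaled input
# x * sqrt(2 / dim_input) with the standard-normal kernel.
print(jnp.allclose(x @ kernel_he,
                   (x * jnp.sqrt(2. / dim_input)) @ kernel_rbf, atol=1e-5))

Under these assumptions the check should print True.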
@@ -149,7 +157,7 @@ class RandomFourierFeatures(nn.Module):
dtype: the dtype of the computation (default: float32).
"""
features: int
-feature_scale: Optional[jnp.float32] = None
+feature_scale: Optional[jnp.float32] = 1.
activation: Callable[[Array], Array] = default_rbf_activation
kernel_init: Initializer = default_rbf_kernel_init
bias_init: Initializer = default_rbf_bias_init
22 changes: 19 additions & 3 deletions edward2/jax/nn/random_feature_test.py
@@ -21,6 +21,8 @@

import edward2.jax as ed

import flax.linen as nn

import jax
import jax.numpy as jnp
import numpy as np
@@ -94,6 +96,10 @@ def setUp(self):
self.x_test = _generate_normal_data(
self.num_test_sample, self.num_data_dim, seed=21)

# Uses classic RBF random feature distribution.
self.hidden_kwargs = dict(
kernel_init=nn.initializers.normal(stddev=1.), feature_scale=None)

self.rbf_approx_maximum_tol = 5e-3
self.rbf_approx_average_tol = 5e-4
self.primal_dual_maximum_diff = 1e-6
@@ -105,6 +111,7 @@ def one_step_rfgp_result(self, train_data, test_data, **eval_kwargs):
features=1,
hidden_features=self.num_random_features,
normalize_input=False,
hidden_kwargs=self.hidden_kwargs,
covmat_kwargs=dict(ridge_penalty=self.ridge_penalty))

# Computes posterior covariance on test data.
@@ -231,13 +238,19 @@ def setUp(self):
self.x_test = _generate_normal_data(self.num_train_sample,
self.num_data_dim)

# Uses classic RBF random feature distribution.
self.hidden_kwargs = dict(
kernel_init=nn.initializers.normal(stddev=1.), feature_scale=None)

self.kernel_approx_tolerance = dict(atol=5e-2, rtol=1e-2)

def test_random_feature_mutable_collection(self):
"""Tests if RFF variables are properly nested under a mutable collection."""
rng = jax.random.PRNGKey(self.seed)
rff_layer = ed.nn.RandomFourierFeatures(
-features=self.num_random_features, collection_name=self.collection_name)
+features=self.num_random_features,
+collection_name=self.collection_name,
+**self.hidden_kwargs)

# Computes forward pass with mutable collection specified.
init_vars = rff_layer.init(rng, self.x_train)
@@ -260,7 +273,8 @@ def test_random_feature_nd_input(self, input_shape):
def test_random_feature_nd_input(self, input_shape):
rng = jax.random.PRNGKey(self.seed)
x = jnp.ones(input_shape)
-rff_layer = ed.nn.RandomFourierFeatures(features=self.num_random_features)
+rff_layer = ed.nn.RandomFourierFeatures(
+features=self.num_random_features, **self.hidden_kwargs)
y, _ = rff_layer.init_with_output(rng, x)

expected_output_shape = input_shape[:-1] + (self.num_random_features,)
@@ -270,7 +284,9 @@ def test_random_feature_kernel_approximation(self):
"""Tests if default RFF layer approximates a RBF kernel matrix."""
rng = jax.random.PRNGKey(self.seed)
rff_layer = ed.nn.RandomFourierFeatures(
-features=self.num_random_features, collection_name=self.collection_name)
+features=self.num_random_features,
+collection_name=self.collection_name,
+**self.hidden_kwargs)

# Extracts random features by computing forward pass.
init_vars = rff_layer.init(rng, self.x_train)
55 changes: 54 additions & 1 deletion edward2/jax/nn/utils.py
@@ -15,11 +15,12 @@

"""JAX layer and utils."""

-from typing import Iterable, Callable
+from typing import Callable, Iterable, Optional

from jax import random
import jax.numpy as jnp

Array = jnp.ndarray
DType = type(jnp.float32)
InitializeFn = Callable[[jnp.ndarray, Iterable[int], DType], jnp.ndarray]

@@ -48,3 +49,55 @@ def initializer(key, shape, dtype=jnp.float32):
x = random.normal(key, shape, dtype) * (-random_sign_init) + 1.0
return x.astype(dtype)
return initializer


def mean_field_logits(logits: Array,
covmat: Optional[Array] = None,
mean_field_factor: float = 1.,
likelihood: str = 'logistic'):
"""Adjust the model logits so its softmax approximates the posterior mean [4].

Arguments:
logits: A float ndarray of shape (batch_size, num_classes).
covmat: A float ndarray of shape (batch_size, ). If None then it is assumed
to be a vector of 1.'s.
mean_field_factor: The scale factor for mean-field approximation, used to
adjust the influence of posterior variance in posterior mean
approximation. If covmat=None then it is used as the scaling parameter for
temperature scaling.
likelihood: name of the likelihood for integration in Gaussian-approximated
latent posterior. Must be one of ('logistic', 'binary_logistic',
'poisson').

Returns:
A float ndarray of uncertainty-adjusted logits, shape
(batch_size, num_classes).

Raises:
(ValueError) If likelihood is not one of ('logistic', 'binary_logistic',
'poisson').
"""
if likelihood not in ('logistic', 'binary_logistic', 'poisson'):
raise ValueError(
f'Likelihood must be one of (\'logistic\', \'binary_logistic\', \'poisson\'), got {likelihood}.'
)

if mean_field_factor < 0:
return logits

# Defines predictive variance.
variances = 1. if covmat is None else covmat

# Computes scaling coefficient for mean-field approximation.
if likelihood == 'poisson':
logits_scale = jnp.exp(-variances * mean_field_factor / 2.) # pylint:disable=invalid-unary-operand-type
else:
logits_scale = jnp.sqrt(1. + variances * mean_field_factor)

# Pads logits_scale to compatible dimension.
while logits_scale.ndim < logits.ndim:
logits_scale = jnp.expand_dims(logits_scale, axis=-1)

return logits / logits_scale
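A short usage sketch of the new mean_field_logits helper, not part of this commit: with the default logistic likelihood it divides each row of logits by sqrt(1. + covmat * mean_field_factor) before a softmax is taken. The batch size, number of classes, and the mean-field factor of pi / 8 below are illustrative values only; pi / 8 is a common choice for the logistic likelihood.

import jax
import jax.numpy as jnp
import edward2.jax as ed

batch_size, num_classes = 32, 10
logits = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_classes))
covmat = jnp.ones(batch_size) * 0.5  # per-example posterior variance

# Divides each row of logits by sqrt(1. + 0.5 * pi / 8.) before the softmax.
adjusted_logits = ed.nn.utils.mean_field_logits(
    logits, covmat, mean_field_factor=jnp.pi / 8.)
probs = jax.nn.softmax(adjusted_logits, axis=-1)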


75 changes: 75 additions & 0 deletions edward2/jax/nn/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# coding=utf-8
# Copyright 2021 The Edward2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for utils."""
from absl.testing import absltest
from absl.testing import parameterized

import edward2.jax as ed

import jax
import jax.numpy as jnp

import numpy as np
import tensorflow as tf


class MeanFieldLogitsTest(parameterized.TestCase, tf.test.TestCase):

def testMeanFieldLogitsLikelihood(self):
"""Tests if scaling is correct under different likelihood."""
batch_size = 10
num_classes = 12
variance = 1.5
mean_field_factor = 2.

rng_key = jax.random.PRNGKey(0)
logits = jax.random.normal(rng_key, (batch_size, num_classes))
covmat = jnp.ones(batch_size) * variance

logits_logistic = ed.nn.utils.mean_field_logits(
logits, covmat, mean_field_factor=mean_field_factor)
logits_poisson = ed.nn.utils.mean_field_logits(
logits,
covmat,
mean_field_factor=mean_field_factor,
likelihood='poisson')
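    # With variance = 1.5 and mean_field_factor = 2, the expected scaling is:
    #   logistic: logits_scale = sqrt(1 + 1.5 * 2) = 2, so logits are halved;
    #   poisson:  logits_scale = exp(-1.5 * 2 / 2) = exp(-1.5), so logits are
    #   multiplied by exp(1.5).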

self.assertAllClose(logits_logistic, logits / 2., atol=1e-4)
self.assertAllClose(logits_poisson, logits * np.exp(1.5), atol=1e-4)

def testMeanFieldLogitsTemperatureScaling(self):
"""Tests using mean_field_logits as temperature scaling method."""
batch_size = 10
num_classes = 12

rng_key = jax.random.PRNGKey(0)
logits = jax.random.normal(rng_key, (batch_size, num_classes))

# Test if there's no change to logits when mean_field_factor < 0.
logits_no_change = ed.nn.utils.mean_field_logits(
logits, covmat=None, mean_field_factor=-1)

# Test if mean_field_logits functions as a temperature scaling method when
# mean_field_factor > 0, with temperature = sqrt(1. + mean_field_factor).
logits_scale_by_two = ed.nn.utils.mean_field_logits(
logits, covmat=None, mean_field_factor=3.)
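    # Here the temperature is sqrt(1. + 3.) = 2., so the logits should be halved.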

self.assertAllClose(logits_no_change, logits, atol=1e-4)
self.assertAllClose(logits_scale_by_two, logits / 2., atol=1e-4)


if __name__ == '__main__':
absltest.main()
