diff --git a/edward2/jax/nn/random_feature.py b/edward2/jax/nn/random_feature.py
index b10f97cb..92b1e9ee 100644
--- a/edward2/jax/nn/random_feature.py
+++ b/edward2/jax/nn/random_feature.py
@@ -27,6 +27,9 @@
 [3]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
   Machines. In _Neural Information Processing Systems_, 2007.
   https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
+[4]: Zhiyun Lu, Eugene Ie, Fei Sha. Uncertainty Estimation with Infinitesimal
+  Jackknife. _arXiv preprint arXiv:2006.07584_, 2020.
+  https://arxiv.org/abs/2006.07584
 """
 import dataclasses
 import functools
@@ -47,8 +50,13 @@
 
 # Default config for random features.
 default_rbf_activation = jnp.cos
-default_rbf_kernel_init = nn.initializers.normal(stddev=1.)
 default_rbf_bias_init = nn.initializers.uniform(scale=2. * jnp.pi)
+# Using "he_normal" style random feature distribution. Effectively, this is
+# equivalent to approximating a RBF kernel but with the input standardized by
+# its dimensionality (i.e., input_scaled = input * sqrt(2. / dim_input)) and
+# empirically leads to better performance for neural network inputs.
+default_rbf_kernel_init = nn.initializers.variance_scaling(
+    scale=2.0, mode='fan_in', distribution='normal')
 
 # Default field value for kwargs, to be used for data class declaration.
 default_kwarg_dict = lambda: dataclasses.field(default_factory=dict)
@@ -149,7 +157,7 @@ class RandomFourierFeatures(nn.Module):
     dtype: the dtype of the computation (default: float32).
   """
   features: int
-  feature_scale: Optional[jnp.float32] = None
+  feature_scale: Optional[jnp.float32] = 1.
   activation: Callable[[Array], Array] = default_rbf_activation
   kernel_init: Initializer = default_rbf_kernel_init
   bias_init: Initializer = default_rbf_bias_init
diff --git a/edward2/jax/nn/random_feature_test.py b/edward2/jax/nn/random_feature_test.py
index 3d2f2f33..b1d34040 100644
--- a/edward2/jax/nn/random_feature_test.py
+++ b/edward2/jax/nn/random_feature_test.py
@@ -21,6 +21,8 @@
 
 import edward2.jax as ed
 
+import flax.linen as nn
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -94,6 +96,10 @@ def setUp(self):
     self.x_test = _generate_normal_data(
         self.num_test_sample, self.num_data_dim, seed=21)
 
+    # Uses classic RBF random feature distribution.
+    self.hidden_kwargs = dict(
+        kernel_init=nn.initializers.normal(stddev=1.), feature_scale=None)
+
     self.rbf_approx_maximum_tol = 5e-3
     self.rbf_approx_average_tol = 5e-4
     self.primal_dual_maximum_diff = 1e-6
@@ -105,6 +111,7 @@ def one_step_rfgp_result(self, train_data, test_data, **eval_kwargs):
         features=1,
         hidden_features=self.num_random_features,
         normalize_input=False,
+        hidden_kwargs=self.hidden_kwargs,
         covmat_kwargs=dict(ridge_penalty=self.ridge_penalty))
 
     # Computes posterior covariance on test data.
@@ -231,13 +238,19 @@ def setUp(self):
     self.x_test = _generate_normal_data(self.num_train_sample,
                                         self.num_data_dim)
 
+    # Uses classic RBF random feature distribution.
+    self.hidden_kwargs = dict(
+        kernel_init=nn.initializers.normal(stddev=1.), feature_scale=None)
+
     self.kernel_approx_tolerance = dict(atol=5e-2, rtol=1e-2)
 
   def test_random_feature_mutable_collection(self):
     """Tests if RFF variables are properly nested under a mutable collection."""
     rng = jax.random.PRNGKey(self.seed)
     rff_layer = ed.nn.RandomFourierFeatures(
-        features=self.num_random_features, collection_name=self.collection_name)
+        features=self.num_random_features,
+        collection_name=self.collection_name,
+        **self.hidden_kwargs)
 
     # Computes forward pass with mutable collection specified.
     init_vars = rff_layer.init(rng, self.x_train)
@@ -260,7 +273,8 @@ def test_random_feature_mutable_collection(self):
   def test_random_feature_nd_input(self, input_shape):
     rng = jax.random.PRNGKey(self.seed)
     x = jnp.ones(input_shape)
-    rff_layer = ed.nn.RandomFourierFeatures(features=self.num_random_features)
+    rff_layer = ed.nn.RandomFourierFeatures(
+        features=self.num_random_features, **self.hidden_kwargs)
     y, _ = rff_layer.init_with_output(rng, x)
 
     expected_output_shape = input_shape[:-1] + (self.num_random_features,)
@@ -270,7 +284,9 @@ def test_random_feature_kernel_approximation(self):
     """Tests if default RFF layer approximates a RBF kernel matrix."""
     rng = jax.random.PRNGKey(self.seed)
     rff_layer = ed.nn.RandomFourierFeatures(
-        features=self.num_random_features, collection_name=self.collection_name)
+        features=self.num_random_features,
+        collection_name=self.collection_name,
+        **self.hidden_kwargs)
 
     # Extracts random features by computing forward pass.
     init_vars = rff_layer.init(rng, self.x_train)
diff --git a/edward2/jax/nn/utils.py b/edward2/jax/nn/utils.py
index cf7d405d..157e6e59 100644
--- a/edward2/jax/nn/utils.py
+++ b/edward2/jax/nn/utils.py
@@ -15,11 +15,12 @@
 
 """JAX layer and utils."""
 
-from typing import Iterable, Callable
+from typing import Callable, Iterable, Optional
 
 from jax import random
 import jax.numpy as jnp
 
+Array = jnp.ndarray
 DType = type(jnp.float32)
 InitializeFn = Callable[[jnp.ndarray, Iterable[int], DType], jnp.ndarray]
 
@@ -48,3 +49,55 @@ def initializer(key, shape, dtype=jnp.float32):
     x = random.normal(key, shape, dtype) * (-random_sign_init) + 1.0
     return x.astype(dtype)
   return initializer
+
+
+def mean_field_logits(logits: Array,
+                      covmat: Optional[Array] = None,
+                      mean_field_factor: float = 1.,
+                      likelihood: str = 'logistic'):
+  """Adjusts the model logits so its softmax approximates the posterior mean [4].
+
+  Arguments:
+    logits: A float ndarray of shape (batch_size, num_classes).
+    covmat: A float ndarray of shape (batch_size, ). If None then it is assumed
+      to be a vector of 1.'s.
+    mean_field_factor: The scale factor for mean-field approximation, used to
+      adjust the influence of posterior variance in posterior mean
+      approximation. If covmat=None then it is used as the scaling parameter
+      for temperature scaling.
+    likelihood: Name of the likelihood for integration in Gaussian-approximated
+      latent posterior. Must be one of ('logistic', 'binary_logistic',
+      'poisson').
+
+  Returns:
+    A float ndarray of uncertainty-adjusted logits, shape
+    (batch_size, num_classes).
+
+  Raises:
+    (ValueError) If likelihood is not one of ('logistic', 'binary_logistic',
+      'poisson').
+  """
+  if likelihood not in ('logistic', 'binary_logistic', 'poisson'):
+    raise ValueError(
+        'Likelihood must be one of (\'logistic\', \'binary_logistic\', '
+        f'\'poisson\'), got {likelihood}.')
+
+  if mean_field_factor < 0:
+    return logits
+
+  # Defines predictive variance.
+  variances = 1. if covmat is None else covmat
+
+  # Computes scaling coefficient for mean-field approximation.
+  if likelihood == 'poisson':
+    logits_scale = jnp.exp(-variances * mean_field_factor / 2.)  # pylint:disable=invalid-unary-operand-type
+  else:
+    logits_scale = jnp.sqrt(1. + variances * mean_field_factor)
+
+  # Pads logits_scale to compatible dimension.
+  while logits_scale.ndim < logits.ndim:
+    logits_scale = jnp.expand_dims(logits_scale, axis=-1)
+
+  return logits / logits_scale
+
+
diff --git a/edward2/jax/nn/utils_test.py b/edward2/jax/nn/utils_test.py
new file mode 100644
index 00000000..8eebf0d3
--- /dev/null
+++ b/edward2/jax/nn/utils_test.py
@@ -0,0 +1,75 @@
+# coding=utf-8
+# Copyright 2021 The Edward2 Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for utils."""
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import edward2.jax as ed
+
+import jax
+import jax.numpy as jnp
+
+import numpy as np
+import tensorflow as tf
+
+
+class MeanFieldLogitsTest(parameterized.TestCase, tf.test.TestCase):
+
+  def testMeanFieldLogitsLikelihood(self):
+    """Tests if scaling is correct under different likelihoods."""
+    batch_size = 10
+    num_classes = 12
+    variance = 1.5
+    mean_field_factor = 2.
+
+    rng_key = jax.random.PRNGKey(0)
+    logits = jax.random.normal(rng_key, (batch_size, num_classes))
+    covmat = jnp.ones(batch_size) * variance
+
+    logits_logistic = ed.nn.utils.mean_field_logits(
+        logits, covmat, mean_field_factor=mean_field_factor)
+    logits_poisson = ed.nn.utils.mean_field_logits(
+        logits,
+        covmat,
+        mean_field_factor=mean_field_factor,
+        likelihood='poisson')
+
+    self.assertAllClose(logits_logistic, logits / 2., atol=1e-4)
+    self.assertAllClose(logits_poisson, logits * np.exp(1.5), atol=1e-4)
+
+  def testMeanFieldLogitsTemperatureScaling(self):
+    """Tests using mean_field_logits as a temperature scaling method."""
+    batch_size = 10
+    num_classes = 12
+
+    rng_key = jax.random.PRNGKey(0)
+    logits = jax.random.normal(rng_key, (batch_size, num_classes))
+
+    # Test if there's no change to logits when mean_field_factor < 0.
+    logits_no_change = ed.nn.utils.mean_field_logits(
+        logits, covmat=None, mean_field_factor=-1)
+
+    # Test if mean_field_logits functions as a temperature scaling method when
+    # mean_field_factor > 0, with temperature = sqrt(1. + mean_field_factor).
+    logits_scale_by_two = ed.nn.utils.mean_field_logits(
+        logits, covmat=None, mean_field_factor=3.)
+
+    self.assertAllClose(logits_no_change, logits, atol=1e-4)
+    self.assertAllClose(logits_scale_by_two, logits / 2., atol=1e-4)
+
+
+if __name__ == '__main__':
+  absltest.main()
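
Note on the new default_rbf_kernel_init: the comment in random_feature.py states that variance_scaling(scale=2.0, mode='fan_in', distribution='normal') amounts to the classic N(0, 1) RBF feature distribution applied to inputs rescaled by sqrt(2. / dim_input). Below is a minimal sketch of that equivalence; the toy sizes and the names dim_input, num_features, w_he, and w_rbf are illustrative only and not part of this change.

import jax
import jax.numpy as jnp
import flax.linen as nn

dim_input, num_features = 8, 4096
key = jax.random.PRNGKey(0)

# "he_normal"-style draw: entries have standard deviation sqrt(2 / dim_input).
w_he = nn.initializers.variance_scaling(
    scale=2.0, mode='fan_in', distribution='normal')(
        key, (dim_input, num_features))
print(jnp.std(w_he), jnp.sqrt(2. / dim_input))  # both ~0.5 for these sizes

# Rescaling recovers a unit-variance (classic RBF) feature matrix, so
# projecting x with w_he matches projecting the standardized input with w_rbf.
w_rbf = w_he / jnp.sqrt(2. / dim_input)
x = jax.random.normal(jax.random.PRNGKey(1), (3, dim_input))
x_scaled = x * jnp.sqrt(2. / dim_input)
print(jnp.max(jnp.abs(x @ w_he - x_scaled @ w_rbf)))  # ~0 up to float error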
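
Similarly, a sketch of how the new ed.nn.utils.mean_field_logits would typically be applied at evaluation time to the logits and per-example posterior variances produced by a Gaussian-process output head. The arrays logits and covmat_diag and the choice mean_field_factor = pi / 8 are illustrative assumptions; only mean_field_logits itself comes from this change.

import jax.numpy as jnp
import edward2.jax as ed

# Hypothetical eval-time outputs of a GP output layer: per-class logits and
# the diagonal of the posterior covariance (one variance per example).
logits = jnp.zeros((32, 10))         # shape (batch_size, num_classes)
covmat_diag = 0.7 * jnp.ones((32,))  # shape (batch_size,)

# Mean-field adjustment [4]: logits / sqrt(1 + mean_field_factor * variance).
# pi / 8 is a common choice of factor for the logistic likelihood.
adjusted = ed.nn.utils.mean_field_logits(
    logits, covmat_diag, mean_field_factor=jnp.pi / 8.)

# With covmat=None, the same call reduces to temperature scaling with
# temperature sqrt(1 + mean_field_factor).
tempered = ed.nn.utils.mean_field_logits(
    logits, covmat=None, mean_field_factor=1.)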