-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:tonywu71/conditional-neural-processes
- Loading branch information
Showing
10 changed files
with
483 additions
and
341 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
191 changes: 191 additions & 0 deletions
191
dataloader/dataloader_for_plotting/load_regression_data_from_arbitrary_gp_varying_kernel.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
from functools import partial | ||
from typing import Optional, Tuple, Callable, Iterator | ||
|
||
import tensorflow as tf | ||
import tensorflow_probability as tfp | ||
|
||
tfd = tfp.distributions | ||
|
||
from dataloader.dataloader_for_plotting.regression_data_generator_base import RegressionDataGeneratorBase | ||
|
||
|
||
def gen_from_arbitrary_gp(
        batch_size: int,
        iterations: int,
        min_kernel_length_scale: float,
        max_kernel_length_scale: float,
        min_num_context: int,
        max_num_context: int,
        min_num_target: int,
        max_num_target: int,
        min_x_val_uniform: float,
        max_x_val_uniform: float,
        testing: bool) -> Iterator[Tuple]:
    """Yield batches of 1-D regression curves drawn from a GP whose kernel changes mid-curve.

    The data is generated batch-wise. For each of the ``iterations`` batches:
        - ``num_context`` is drawn uniformly from ``[min_num_context, max_num_context)``.
        - ``num_target`` is drawn uniformly from ``[min_num_target, max_num_target)`` when
          ``testing`` is False, and is fixed to ``max_num_target - 1`` when ``testing`` is True.
        - ``num_context + num_target`` x values per curve are drawn uniformly between
          ``min_x_val_uniform`` and ``max_x_val_uniform`` and sorted along the point axis.
        - Two length scales ``l1`` and ``l2`` are drawn uniformly between
          ``min_kernel_length_scale`` and ``max_kernel_length_scale``.
        - The sorted points are split at a random index. The left part is sampled from a
          zero-mean GP with an ExponentiatedQuadratic kernel of length scale ``l1``. The right
          part is sampled from a GP regression model with an ExponentiatedQuadratic kernel of
          length scale ``l2``, conditioned on the last point of the left part (presumably so
          that the two parts join up at the split).

    The split index is drawn from ``[2, num_total_points - 1)``, so ``num_context + num_target``
    must be at least 4 for that range to be non-empty.

    Args:
        batch_size: Number of curves per batch.
        iterations: Number of batches to generate.
        min_kernel_length_scale: Lower bound for the sampled kernel length scales.
        max_kernel_length_scale: Upper bound for the sampled kernel length scales.
        min_num_context: Inclusive lower bound for the number of context points.
        max_num_context: Exclusive upper bound for the number of context points.
        min_num_target: Inclusive lower bound for the number of additional target points.
        max_num_target: Exclusive upper bound for the number of additional target points.
        min_x_val_uniform: Lower bound of the x values.
        max_x_val_uniform: Upper bound of the x values.
        testing: If True, use the fixed number ``max_num_target - 1`` of additional target points.

    Yields:
        ``((context_x, context_y, target_x), target_y, l1, l2)`` where the targets are all the
        generated points and the context is a random subset of ``num_context`` of them.
        ``x_values`` is created with shape ``(batch_size, num_total_points, 1)``; the y tensors
        are expected to have the matching shape (TODO confirm against the TFP sample shape).
    """
    for _ in range(iterations):
        # NB: every iteration draws new point counts, new x_values and new kernel
        # length scales, so batches differ both in size and in values.
        num_context = tf.random.uniform(shape=[],
                                        minval=min_num_context,
                                        maxval=max_num_context,
                                        dtype=tf.int32)

        if testing:
            # If testing, we want to use a fixed number of points for the target
            num_target = max_num_target - 1  # -1 because max_num_target is exclusive
        else:
            num_target = tf.random.uniform(shape=[],
                                           minval=min_num_target,
                                           maxval=max_num_target,
                                           dtype=tf.int32)

        num_total_points = num_context + num_target

        x_values = tf.random.uniform(shape=(batch_size, num_total_points, 1),
                                     minval=min_x_val_uniform,  # type: ignore
                                     maxval=max_x_val_uniform)

        # Set kernel length scales (one per part of the curve):
        l1 = tf.random.uniform(shape=[],
                               minval=min_kernel_length_scale,  # type: ignore
                               maxval=max_kernel_length_scale,
                               dtype=tf.dtypes.float32)

        l2 = tf.random.uniform(shape=[],
                               minval=min_kernel_length_scale,  # type: ignore
                               maxval=max_kernel_length_scale,
                               dtype=tf.dtypes.float32)

        # Varying kernel:
        kernel_1 = tfp.math.psd_kernels.ExponentiatedQuadratic(length_scale=l1)
        kernel_2 = tfp.math.psd_kernels.ExponentiatedQuadratic(length_scale=l2)

        # maxval is exclusive, so n_samples_1 lies in [2, num_total_points - 2]:
        # each of the two splits ends up with at least two points.
        n_samples_1 = tf.random.uniform(shape=[], minval=2, maxval=num_total_points - 1, dtype=tf.int32)

        # Sort x_values so that the split below separates "left" from "right" points:
        x_values = tf.sort(x_values, axis=1)

        # Split x_values into two parts such that the first part has n_samples_1 points:
        x_values_1 = x_values[:, :n_samples_1, :]
        x_values_2 = x_values[:, n_samples_1:, :]

        # First part: plain GP prior with kernel_1.
        gp_1 = tfd.GaussianProcess(kernel_1, index_points=x_values_1, jitter=1.0e-4)
        y_values_1 = tf.expand_dims(gp_1.sample(), axis=-1)

        # Second part: GP with kernel_2 conditioned on the last point of the first part.
        gp_2 = tfd.GaussianProcessRegressionModel(
            kernel=kernel_2,
            index_points=x_values_2,
            observation_index_points=x_values_1[:, -1:, :],
            observations=y_values_1[:, -1:, 0],
            observation_noise_variance=1.0e-4)
        y_values_2 = tf.expand_dims(gp_2.sample(), axis=-1)

        y_values = tf.concat([y_values_1, y_values_2], axis=1)

        idx = tf.random.shuffle(tf.range(num_total_points))

        # Select the targets which will consist of the context points
        # as well as some new target points
        target_x = x_values
        target_y = y_values  # type: ignore

        # Select the observations
        context_x = tf.gather(x_values, indices=idx[:num_context], axis=1)
        context_y = tf.gather(y_values, indices=idx[:num_context], axis=1)

        # Defensive sanity checks: skip the batch as soon as ANY dimension disagrees
        # (a skipped batch means fewer than `iterations` batches are yielded).
        if context_x.shape != context_y.shape:
            continue
        if target_x.shape != target_y.shape:
            continue
        if context_x.shape[-1] != target_x.shape[-1]:
            continue

        yield (context_x, context_y, target_x), target_y, l1, l2
|
||
|
||
class RegressionDataGeneratorArbitraryGPWithVaryingKernel(RegressionDataGeneratorBase):
    """Regression data generator backed by ``gen_from_arbitrary_gp``.

    On top of the settings handled by the base class, it stores the bounds of the
    kernel length scales and builds the train/test datasets at construction time.
    """

    def __init__(self,
                 iterations: int=250,
                 batch_size: int=32,
                 min_num_context: int=3,
                 max_num_context: int=10,
                 min_num_target: int=2,
                 max_num_target: int=10,
                 min_x_val_uniform: int=-2,
                 max_x_val_uniform: int=2,
                 n_iterations_test: Optional[int]=None,
                 min_kernel_length_scale: float=0.1,
                 max_kernel_length_scale: float=1.):
        # Everything except the kernel length scales is handled by the base class.
        base_settings = {
            "iterations": iterations,
            "batch_size": batch_size,
            "min_num_context": min_num_context,
            "max_num_context": max_num_context,
            "min_num_target": min_num_target,
            "max_num_target": max_num_target,
            "min_x_val_uniform": min_x_val_uniform,
            "max_x_val_uniform": max_x_val_uniform,
            "n_iterations_test": n_iterations_test,
        }
        super().__init__(**base_settings)

        # The length-scale bounds must be set before the datasets are built, because
        # load_regression_data() calls get_gp_curve_generator(), which reads them.
        self.min_kernel_length_scale = min_kernel_length_scale
        self.max_kernel_length_scale = max_kernel_length_scale

        datasets = self.load_regression_data()
        self.train_ds, self.test_ds = datasets

    def get_gp_curve_generator(self, testing: bool=False) -> Callable:
        """Return ``gen_from_arbitrary_gp`` with every argument bound from this instance.

        The returned callable takes no arguments; ``testing`` is forwarded unchanged
        to the generator function.
        """
        bound_arguments = {
            "batch_size": self.batch_size,
            "iterations": self.iterations,
            "min_kernel_length_scale": self.min_kernel_length_scale,
            "max_kernel_length_scale": self.max_kernel_length_scale,
            "min_num_context": self.min_num_context,
            "max_num_context": self.max_num_context,
            "min_num_target": self.min_num_target,
            "max_num_target": self.max_num_target,
            "min_x_val_uniform": self.min_x_val_uniform,
            "max_x_val_uniform": self.max_x_val_uniform,
            "testing": testing,
        }
        return partial(gen_from_arbitrary_gp, **bound_arguments)
|
||
|
||
def draw_single_example_from_arbitrary_gp(min_kernel_length_scale, max_kernel_length_scale, num_context, num_target):
    """Draw a single test-mode example and strip its batch axis.

    A generator with ``batch_size=1`` and x values between -2 and 2 is created, the
    first element of its test dataset is taken, and axis 0 (the batch axis) is
    squeezed out of the four data tensors.

    NOTE(review): the integer bounds passed below are ``[num_context - 1, num_context)``
    with an exclusive upper bound, and in test mode the generator uses
    ``max_num_target - 1`` additional target points. The example therefore appears to
    contain ``num_context - 1`` context points and ``num_context + num_target - 2``
    points in total, not ``num_context`` / ``num_target`` -- confirm whether this
    off-by-one is intended.

    Args:
        min_kernel_length_scale: Lower bound for the sampled kernel length scales.
        max_kernel_length_scale: Upper bound for the sampled kernel length scales.
        num_context: Upper (exclusive) bound for the number of context points.
        num_target: Upper (exclusive) bound for the number of additional target points.

    Returns:
        ``((context_x, context_y, target_x), target_y, l1, l2)`` where the data tensors
        no longer carry the batch axis and ``l1``/``l2`` are the sampled length scales.
    """
    data_generator = RegressionDataGeneratorArbitraryGPWithVaryingKernel(
        iterations=1,
        n_iterations_test=1,
        batch_size=1,
        min_num_context=num_context - 1,
        max_num_context=num_context,
        min_num_target=num_target - 1,
        max_num_target=num_target,
        min_x_val_uniform=-2,
        max_x_val_uniform=2,
        min_kernel_length_scale=min_kernel_length_scale,
        max_kernel_length_scale=max_kernel_length_scale
    )

    # The constructor has already built the datasets, so reuse its test dataset
    # instead of constructing a second pair through load_regression_data().
    (context_x, context_y, target_x), target_y, l1, l2 = next(iter(data_generator.test_ds))

    # batch_size is 1, so axis 0 can be dropped safely.
    context_x = tf.squeeze(context_x, axis=0)
    context_y = tf.squeeze(context_y, axis=0)
    target_x = tf.squeeze(target_x, axis=0)
    target_y = tf.squeeze(target_y, axis=0)

    return (context_x, context_y, target_x), target_y, l1, l2
98 changes: 98 additions & 0 deletions
98
dataloader/dataloader_for_plotting/regression_data_generator_base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import Callable, Optional, Tuple | ||
from abc import ABC, abstractmethod | ||
|
||
import tensorflow as tf | ||
import tensorflow_probability as tfp | ||
|
||
tfd = tfp.distributions | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
|
||
class RegressionDataGeneratorBase(ABC):
    """Abstract base class for regression data generators.

    Subclasses provide ``get_gp_curve_generator``; this class stores the sampling
    settings, wraps the generator into ``tf.data`` datasets and offers plotting helpers.
    """

    def __init__(self,
                 iterations: int,
                 batch_size: int,
                 min_num_context: int,
                 max_num_context: int,
                 min_num_target: int,
                 max_num_target: int,
                 min_x_val_uniform: int,
                 max_x_val_uniform: int,
                 n_iterations_test: Optional[int]=None):
        self.iterations = iterations
        self.batch_size = batch_size

        # Each (min, max) pair must describe a non-empty range.
        assert min_num_context < max_num_context, "min_num_context must be smaller than max_num_context"
        self.min_num_context, self.max_num_context = min_num_context, max_num_context

        assert min_num_target < max_num_target, "min_num_target must be smaller than max_num_target"
        self.min_num_target, self.max_num_target = min_num_target, max_num_target

        assert min_x_val_uniform < max_x_val_uniform, "min_val_uniform must be smaller than max_val_uniform"
        self.min_x_val_uniform, self.max_x_val_uniform = min_x_val_uniform, max_x_val_uniform

        # Default: one tenth of the number of training iterations.
        self.n_iterations_test = (
            self.iterations // 10 if n_iterations_test is None else n_iterations_test
        )

        # The following attributes will be set when calling load_regression_data() from
        # the child class:
        self.train_ds: tf.data.Dataset = None
        self.test_ds: tf.data.Dataset = None

    @abstractmethod
    def get_gp_curve_generator(self, testing: bool=False) -> Callable:
        """Return a generator function producing regression data from a Gaussian Process."""

    def load_regression_data(self) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """Build and return the ``(train, test)`` datasets from the curve generator.

        Both datasets share the same element structure:
        ``((float32, float32, float32), float32, float32, float32)``.
        """
        element_types = ((tf.float32, tf.float32, tf.float32), tf.float32, tf.float32, tf.float32)

        # Training generator first (testing=False), then the test one (testing=True).
        raw_train, raw_test = (
            tf.data.Dataset.from_generator(
                self.get_gp_curve_generator(testing=is_testing),
                output_types=element_types
            )
            for is_testing in (False, True)
        )

        # No need to shuffle as the data is already generated randomly
        prefetched_train = raw_train.prefetch(tf.data.experimental.AUTOTUNE)
        prefetched_test = raw_test.prefetch(tf.data.experimental.AUTOTUNE)
        return prefetched_train, prefetched_test

    @staticmethod
    def plot_first_elt_of_batch(context_x, context_y, target_x, target_y,
                                ax: Optional[plt.Axes]=None,
                                figsize=(8, 5)):
        """Scatter-plot the first element of a batch and return the axes used.

        A new figure of size ``figsize`` is created when ``ax`` is None. The inputs are
        converted with ``.numpy()`` and indexed as ``[0, :, 0]``.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=figsize)

        ctx_x, ctx_y, tgt_x, tgt_y = (
            tensor.numpy() for tensor in (context_x, context_y, target_x, target_y)
        )

        # Targets are drawn first so the context crosses are plotted on top of them.
        ax.scatter(tgt_x[0, :, 0], tgt_y[0, :, 0], c="blue", label='Target')
        ax.scatter(ctx_x[0, :, 0], ctx_y[0, :, 0], marker="x", c="red", label='Observations')
        ax.legend()

        return ax

    def plot_first_elt_of_random_batch(self, figsize=(8, 5)):
        """Plot the first element of the next batch taken from the training set."""
        first_batch = next(iter(self.train_ds.take(1)))
        # The two trailing entries of each element are not needed for plotting.
        (context_x, context_y, target_x), target_y, _l1, _l2 = first_batch
        return RegressionDataGeneratorBase.plot_first_elt_of_batch(
            context_x, context_y, target_x, target_y, figsize=figsize
        )
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.