Skip to content

Commit

Permalink
Feature: Add bandwidth_decorator for passing kernel_bandwidth as list
Browse files Browse the repository at this point in the history
  • Loading branch information
mjalali committed Nov 7, 2023
1 parent 3d1837e commit 9ecfd7b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ In this work, we propose an information-theoretic diversity evaluation method fo

### Installation

Using PIP
Using PyPi

```shell
pip install rke
pip install rke-score
```

Manually
```shell
git clone https://github.com/mjalali/renyi-kernel-entropy-score
python setup.py install
pip install -e .
```

### Example
Expand All @@ -70,8 +70,8 @@ fake_features = np.random.normal(loc=0.0, scale=1.0,
kernel = RKE(kernel_bandwidth=[0.2, 0.3, 0.4])


print(kernel.compute_rke_mc)
print(kernel.compute_rrke)
print(kernel.compute_rke_mc(fake_features))
print(kernel.compute_rrke(real_features, fake_features))
```

### Guide to evaluate your model
Expand Down
41 changes: 36 additions & 5 deletions rke_score/rke.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ def gaussian_kernel(x, y, sigma):
return kernel


def bandwidth_decorator(function):
"""
This decorator is only for RKE class is used when the `kernel_bandwidth` is a list.
"""
def wrap_bandwidth_list(self, *args, **kwargs):
output = {}
if self.kernel_bandwidth is not None: # Gaussian kernel
for bandwidth in self.kernel_bandwidth:
self.kernel_function = partial(gaussian_kernel, sigma=bandwidth)
output[bandwidth] = function(self, *args, **kwargs)
else: # Specified kernel
return function(self, *args, **kwargs)

return output
return wrap_bandwidth_list


class RKE:
def __init__(self, kernel_bandwidth=None, kernel_function=None):
"""
Expand All @@ -32,17 +49,30 @@ def __init__(self, kernel_bandwidth=None, kernel_function=None):
"""
if kernel_function is None and kernel_bandwidth is None:
raise ValueError('Expected either kernel_function or kernel_bandwidth args')
if kernel_function is None:
kernel_function = partial(gaussian_kernel, sigma=kernel_bandwidth)
self.kernel_function = kernel_function

def compute_rke_mc_frobenius_norm(self, X, **kwargs):
if kernel_function is not None and kernel_bandwidth is not None:
raise ValueError('`kernel_function` is mutually exclusive with `kernel_bandwidth`')

if kernel_function is None: # Gaussian kernel
# Make `kernel_bandwidth` into a list if the input is float or int
if isinstance(kernel_bandwidth, (float, int)):
self.kernel_bandwidth = [kernel_bandwidth]
else:
self.kernel_bandwidth = kernel_bandwidth
self.kernel_function = partial(gaussian_kernel, sigma=self.kernel_bandwidth[0])

else: # Specified kernel
self.kernel_bandwidth = None
self.kernel_function = kernel_function

@bandwidth_decorator
def compute_rke_mc_frobenius_norm(self, X):
f_norm = 0
for i in range(X.shape[0]):
for j in range(X.shape[0]):
f_norm += self.kernel_function(X[i], X[j])**2
return f_norm / X.shape[0]**2

@bandwidth_decorator
def compute_rke_mc(self, X, n_samples=1_000_000):
"""
Computing RKE-MC = exp(-RKE(X))
Expand All @@ -67,6 +97,7 @@ def __compute_relative_kernel(self, X, Y):
output[i][j] = self.kernel_function(X[i], Y[j])
return output / np.sqrt(X.shape[0] * Y.shape[0])

@bandwidth_decorator
def compute_rrke(self, X, Y, x_samples=500, y_samples=None):
if y_samples is None:
y_samples = x_samples
Expand Down

0 comments on commit 9ecfd7b

Please sign in to comment.