Skip to content

Commit

Permalink
Merge pull request #30 from simon-hirsch/batch_gram_updates
Browse files Browse the repository at this point in the history
Add batch updates for Gramians
  • Loading branch information
simon-hirsch authored Dec 6, 2024
2 parents a6e65a4 + 215ad52 commit 38bd9ce
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 15 deletions.
16 changes: 8 additions & 8 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 1,
"id": "e37ae52f-fa91-4128-9580-d12c61984cc7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.1.8\n"
"0.1.9\n"
]
}
],
Expand All @@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 2,
"id": "1fd422ca-82aa-44d2-8287-5bfcae8c6d4e",
"metadata": {},
"outputs": [],
Expand All @@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 3,
"id": "558ca51e-a311-4d4b-87b5-17986b364ed3",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -102,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 4,
"id": "2d8d950d-3749-4c86-8f05-8082b9e25ae4",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 5,
"id": "7e8a74f8-f170-440e-a27f-fa3a14c0b924",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -203,7 +203,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 6,
"id": "3a3740ca-2c78-427a-b1d1-b4e2a5abcdd7",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -238,7 +238,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "rolch310",
"language": "python",
"name": "python3"
},
Expand Down
48 changes: 41 additions & 7 deletions src/rolch/gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,15 @@ def update_gram(
Returns:
np.ndarray: Updated Gramian Matrix.
"""
new_gram = (1 - forget) * gram + w * np.outer(X, X)
if X.shape[0] == 1:
# Single Step Update
new_gram = (1 - forget) * gram + w * np.outer(X, X)
else:
# Batch Update
batch_size = X.shape[0]
f = init_forget_vector(size=batch_size, forget=forget)
weights = np.expand_dims((w * f) ** 0.5, axis=-1)
new_gram = gram * (1 - forget) ** batch_size + (X * weights).T @ (X * weights)
return new_gram


Expand All @@ -163,11 +171,31 @@ def update_y_gram(
Returns:
np.ndarray: Updated Gramian Matrix.
"""
new_gram = (1 - forget) * gram + w * np.outer(X, y)
if X.shape[0] == 1:
# Single Update
new_gram = (1 - forget) * gram + w * np.outer(X, y)
else:
# Batch update
batch_size = X.shape[0]
f = init_forget_vector(size=batch_size, forget=forget)
new_gram = gram * (1 - forget) ** batch_size + np.expand_dims(
((X * np.expand_dims((w * f) ** 0.5, axis=-1)).T @ (y * (w * f) ** 0.5)), -1
)
return new_gram


@nb.njit()  # njit for consistency with update_inverted_gram; bare jit()'s object-mode fallback is deprecated
def _update_inverted_gram(
    gram: np.ndarray, X: np.ndarray, forget: float = 0, w: float = 1
) -> np.ndarray:
    """Sherman-Morrison single-step update of an inverted Gramian.

    Computes ``inv((1 - forget) * G + w * x x^T)`` from ``gram = inv(G)``
    without re-inverting the full matrix.

    Args:
        gram (np.ndarray): Current inverted Gramian matrix, shape (D, D).
        X (np.ndarray): Single observation; callers pass a (1, D) row vector
            (see the batch loop in ``update_inverted_gram``).
        forget (float): Exponential forgetting factor. Defaults to 0.
        w (float): Observation weight. Defaults to 1.

    Returns:
        np.ndarray: Updated inverted Gramian matrix.
    """
    gamma = 1 - forget
    # Sherman-Morrison: (A + w x x^T)^-1 = A^-1 - (w A^-1 x x^T A^-1) / (1 + w x^T A^-1 x),
    # applied with A = gamma * G; the leading 1/gamma accounts for the discounting.
    new_gram = (1 / gamma) * (
        gram - ((w * gram @ np.outer(X, X) @ gram) / (gamma + w * X @ gram @ X.T))
    )
    return new_gram


@nb.njit()
def update_inverted_gram(
gram: np.ndarray, X: np.ndarray, forget: float = 0, w: float = 1
) -> np.ndarray:
Expand All @@ -185,8 +213,14 @@ def update_inverted_gram(
Returns:
np.ndarray: Updated inverted Gramian matrix.
"""
gamma = 1 - forget
new_gram = (1 / gamma) * (
gram - ((w * gram @ np.outer(X, X) @ gram) / (gamma + w * X @ gram @ X.T))
)
if X.shape[0] == 1:
new_gram = _update_inverted_gram(gram, X, forget=forget, w=w)
else:
new_gram = _update_inverted_gram(
gram, X=np.expand_dims(X[0, :], 0), forget=forget, w=w[0]
)
for i in range(1, X.shape[0]):
new_gram = _update_inverted_gram(
gram=new_gram, X=np.expand_dims(X[i, :], 0), forget=forget, w=w[i]
)
return new_gram
125 changes: 125 additions & 0 deletions tests/test_gram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import numpy as np
import pytest
import scipy.stats as st

from rolch.gram import (
init_gram,
init_inverted_gram,
init_y_gram,
update_gram,
update_inverted_gram,
update_y_gram,
)


def make_x_y_w(N, D, random_weights=True):
    """Draw a random regression problem with N observations and D features.

    Returns:
        tuple: ``(X, y, w)`` — design matrix of shape (N, D), response of
        shape (N, 1), and observation weights of shape (N,). Weights are
        uniform on [0, 1] if ``random_weights`` else all ones.
    """
    w = st.uniform().rvs(N) if random_weights else np.ones(N)
    X = st.multivariate_normal().rvs((N, D))
    if D == 1:
        # guarantee X stays two-dimensional for the single-feature case
        X = X.reshape(-1, 1)
    y = st.multivariate_normal().rvs((N, 1))
    return X, y, w


# Parameter grids consumed by the pytest.mark.parametrize decorators below.
N = [100, 1000]  # number of observations
D = [1, 2, 10]  # number of features (columns of X)
RANDOM_WEIGHTS = [True, False]  # uniform random weights vs. unit weights
FORGET = [0, 0.0001, 0.001, 0.01, 0.1]  # exponential forgetting factors
BATCH_SIZE = [10, 25]  # rows fed into a single batch update


@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
@pytest.mark.parametrize("random_weights", RANDOM_WEIGHTS)
@pytest.mark.parametrize("forget", FORGET)
def test_single_update_x_gram(N, D, random_weights, forget):
    """A one-row update of the X'X Gramian matches a full re-initialisation."""
    X, _, w = make_x_y_w(N, D, random_weights=random_weights)
    expected = init_gram(X, w, forget)
    partial = init_gram(X[:-1], w[:-1], forget)
    updated = update_gram(partial, X[[-1]], forget=forget, w=w[-1])
    assert np.allclose(expected, updated)


@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
@pytest.mark.parametrize("random_weights", RANDOM_WEIGHTS)
@pytest.mark.parametrize("forget", FORGET)
@pytest.mark.parametrize("batchsize", BATCH_SIZE)
def test_batch_update_x_gram(N, D, random_weights, forget, batchsize):
    """A batch update of the X'X Gramian matches a full re-initialisation."""
    X, _, w = make_x_y_w(N, D, random_weights=random_weights)
    expected = init_gram(X, w, forget)
    partial = init_gram(X[:-batchsize], w[:-batchsize], forget)
    updated = update_gram(partial, X[-batchsize:, :], forget=forget, w=w[-batchsize:])
    assert np.allclose(expected, updated)


# INVERTED GRAM
@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
@pytest.mark.parametrize("random_weights", RANDOM_WEIGHTS)
@pytest.mark.parametrize("forget", FORGET)
def test_single_update_inv_gram(N, D, random_weights, forget):
    """A one-row update of the inverted Gramian matches a full re-initialisation."""
    X, _, w = make_x_y_w(N, D, random_weights=random_weights)
    expected = init_inverted_gram(X, w, forget)
    partial = init_inverted_gram(X[:-1], w[:-1], forget)
    updated = update_inverted_gram(partial, X[[-1]], forget=forget, w=w[-1])
    assert np.allclose(expected, updated)


@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
@pytest.mark.parametrize("random_weights", RANDOM_WEIGHTS)
@pytest.mark.parametrize("forget", FORGET)
@pytest.mark.parametrize("batchsize", BATCH_SIZE)
def test_batch_update_inv_gram(N, D, random_weights, forget, batchsize):
    """A batch update of the inverted Gramian matches a full re-initialisation."""
    X, _, w = make_x_y_w(N, D, random_weights=random_weights)
    expected = init_inverted_gram(X, w, forget)
    partial = init_inverted_gram(X[:-batchsize], w[:-batchsize], forget)
    updated = update_inverted_gram(
        partial, X[-batchsize:, :], forget=forget, w=w[-batchsize:]
    )
    assert np.allclose(expected, updated)


# Y-GRAM
@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
@pytest.mark.parametrize("random_weights", RANDOM_WEIGHTS)
@pytest.mark.parametrize("forget", FORGET)
def test_single_update_y_gram(N, D, random_weights, forget):
    """A one-row update of the X'y Gramian matches a full re-initialisation."""
    X, y, w = make_x_y_w(N, D, random_weights=random_weights)
    expected = init_y_gram(X, y, w, forget)
    partial = init_y_gram(X[:-1], y[:-1], w[:-1], forget)
    updated = update_y_gram(partial, X[[-1]], y[[-1]], forget=forget, w=w[-1])
    assert np.allclose(expected, updated)


@pytest.mark.parametrize("N", N)
@pytest.mark.parametrize("D", D)
@pytest.mark.parametrize("random_weights", RANDOM_WEIGHTS)
@pytest.mark.parametrize("forget", FORGET)
@pytest.mark.parametrize("batchsize", BATCH_SIZE)
def test_batch_update_y_gram(N, D, random_weights, forget, batchsize):
    """A batch update of the X'y Gramian matches a full re-initialisation."""
    X, y, w = make_x_y_w(N, D, random_weights=random_weights)
    expected = init_y_gram(X, y, w, forget)
    partial = init_y_gram(X[:-batchsize], y[:-batchsize], w[:-batchsize], forget)
    updated = update_y_gram(
        partial,
        X[-batchsize:, :],
        y[-batchsize:],
        forget=forget,
        w=w[-batchsize:],
    )
    assert np.allclose(expected, updated)

0 comments on commit 38bd9ce

Please sign in to comment.