Skip to content

Commit

Permalink
More realistic mock experiment (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert authored Nov 7, 2023
1 parent 7efc97d commit 1595afd
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 25 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tests = [
"pytest",
"pytest-cov",
"pytest-subtests",
"invrs-opt",
]
dev = [
"bump-my-version",
Expand Down
97 changes: 72 additions & 25 deletions tests/experiment/test_mock_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import time
import unittest


NUM_WORK_UNITS = 40


def run_analysis(experiment_path, timeout):
from invrs_utils.experiment import data

df = None
done = False
max_time = time.time() + timeout
while not done and time.time() < max_time:
Expand All @@ -34,10 +34,11 @@ def run_analysis(experiment_path, timeout):
return df


def run_experiment(experiment_path, workers):
def run_experiment(experiment_path, workers, steps):
from invrs_utils.experiment import sweep

sweeps = sweep.product(
sweep.sweep("steps", [steps]),
sweep.sweep("seed", [0, 1, 2, 3, 4]),
sweep.product(
sweep.zip(
Expand All @@ -61,7 +62,7 @@ def _run_work_unit(path_and_kwargs):
run_work_unit(wid_path, **kwargs)


def run_work_unit(wid_path, seed, a, b, c):
def run_work_unit(wid_path, seed, steps, a, b, c):
print(f"Launching {wid_path}")
if not os.path.exists(wid_path):
os.makedirs(wid_path)
Expand All @@ -70,8 +71,12 @@ def run_work_unit(wid_path, seed, a, b, c):
with open(wid_path + "/setup.json", "w") as f:
json.dump(work_unit_config, f, indent=4)

import invrs_opt
import jax
import jax.numpy as jnp
from jax import random
from jax import random, tree_util
from totypes import types

from invrs_utils.experiment import checkpoint

mngr = checkpoint.CheckpointManager(
Expand All @@ -80,34 +85,61 @@ def run_work_unit(wid_path, seed, a, b, c):
max_to_keep=1,
)

dummy_loss = []
dummy_distance = []

key = random.PRNGKey(seed)
for i in range(100):
dummy_loss.append(jnp.exp(-(a + b + c) * i / 10))
dummy_distance.append(jnp.exp(-(a + b + c) * i / 10) - 0.1)
dummy_params = random.uniform(random.fold_in(key, i), shape=(50, 50))
def dummy_loss(x):
leaves = tree_util.tree_leaves(x)
leaf_sums = tree_util.tree_map(lambda leaf: jnp.sum(jnp.abs(leaf) ** 2), leaves)
return jnp.sum(jnp.asarray(tree_util.tree_leaves(leaf_sums)))

opt = invrs_opt.lbfgsb()

if mngr.latest_step() is None:
key = random.PRNGKey(seed)
k1, k2 = random.split(key)
params = {
"bounded_array": types.BoundedArray(
array=random.normal(k1, (20, 20)),
lower_bound=-2,
upper_bound=2,
),
"density2d": types.Density2DArray(
array=random.normal(k2, shape=(30, 40)),
lower_bound=0,
upper_bound=2,
),
}
state = opt.init(params)
scalars = {}
latest_step = -1
else:
ckpt = mngr.restore(mngr.latest_step())
params = ckpt["params"]
state = ckpt["state"]
scalars = ckpt["scalars"]
latest_step = mngr.latest_step()

def _log_scalar(name, value):
if name not in scalars:
scalars[name] = jnp.zeros((0,))
scalars[name] = jnp.concatenate([scalars[name], jnp.asarray([value])])

for i in range(latest_step + 1, steps):
params = opt.params(state)
value, grad = jax.value_and_grad(dummy_loss)(params)
state = opt.update(grad=grad, value=value, params=params, state=state)
_log_scalar("loss", value)
_log_scalar("distance", value - 0.01)
mngr.save(
step=i,
pytree={
"scalars": {
"loss": jnp.asarray(dummy_loss),
"distance": jnp.asarray(dummy_distance),
},
"params": dummy_params,
},
pytree={"scalars": scalars, "params": params, "state": state},
force_save=False,
)

mngr.save(
step=i,
pytree={
"scalars": {"loss": dummy_loss, "distance": dummy_distance},
"params": dummy_params,
},
pytree={"scalars": scalars, "params": params, "state": state},
force_save=True,
)

with open(wid_path + "/completed.txt", "w") as f:
os.utime(wid_path, None)
print(f"Completed {wid_path}")
Expand All @@ -120,12 +152,27 @@ def test_mock_experiment(self):
# spawn multiple workers that carry out the experiment.
p = mp.Process(
target=run_experiment,
kwargs={"experiment_path": tmpdir, "workers": 5},
kwargs={"experiment_path": tmpdir, "workers": 5, "steps": 50},
)
p.start()

# Run the analysis. This will repeatedly summarize the experiment, and
# return once all work units have been finished.
df = run_analysis(experiment_path=tmpdir, timeout=100)
p.join()
self.assertIsNotNone(df)
self.assertEqual(len(df), NUM_WORK_UNITS)
self.assertTrue((df["wid.latest_step"] == 49).all())

self.assertEqual(len(df), NUM_WORK_UNITS)
# Relaunch the experiment with a larger step budget (100 instead of 50). Each
# work unit resumes from its latest checkpoint and runs through step 99.
p = mp.Process(
target=run_experiment,
kwargs={"experiment_path": tmpdir, "workers": 5, "steps": 100},
)
p.start()
p.join()

df = run_analysis(experiment_path=tmpdir, timeout=100)
self.assertIsNotNone(df)
self.assertEqual(len(df), NUM_WORK_UNITS)
self.assertTrue((df["wid.latest_step"] == 99).all())

0 comments on commit 1595afd

Please sign in to comment.