diff --git a/examples/basic-1-bin/dummyproblem.ipynb b/examples/basic-1-bin/dummyproblem.ipynb index 953780e..080b8c2 100644 --- a/examples/basic-1-bin/dummyproblem.ipynb +++ b/examples/basic-1-bin/dummyproblem.ipynb @@ -27,7 +27,7 @@ "def yields(x):\n", " s = 15 + x\n", " b = 45 - 2 * x\n", - " db = 1 + 0.2 * x ** 2\n", + " db = 1 + 0.2 * x**2\n", " return jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db])" ] }, @@ -284,7 +284,7 @@ "x = np.linspace(0, 10)\n", "s = 15 + x\n", "b = 45 - 2 * x\n", - "db = 1 + 0.2 * x ** 2\n", + "db = 1 + 0.2 * x**2\n", "\n", "\n", "def get_cls(s, b, db):\n", diff --git a/neos.ipynb b/neos.ipynb index e7a91a6..1b0d813 100644 --- a/neos.ipynb +++ b/neos.ipynb @@ -93,20 +93,22 @@ "source": [ "from __future__ import annotations\n", "\n", - "import relaxed\n", - "import jax\n", - "from jax import jit\n", - "from chex import Array\n", - "from typing import NamedTuple, Callable, Any\n", - "import pyhf\n", - "from sklearn.model_selection import train_test_split\n", - "from jax.random import PRNGKey, multivariate_normal\n", - "import numpy.random as npr\n", - "import optax\n", - "import jaxopt\n", "import time\n", "from functools import partial\n", "from pprint import pprint\n", + "from typing import Any, Callable, NamedTuple\n", + "\n", + "import jax\n", + "import jaxopt\n", + "import numpy.random as npr\n", + "import optax\n", + "import pyhf\n", + "import relaxed\n", + "from chex import Array\n", + "from jax import jit\n", + "from jax.random import PRNGKey, multivariate_normal\n", + "from sklearn.model_selection import train_test_split\n", + "\n", "\n", "def make_model(s, b_nom, b_up, b_down):\n", " m = {\n", @@ -138,8 +140,10 @@ " }\n", " return pyhf.Model(m, validate=False)\n", "\n", - "def nn_summary_stat(pars, data, nn, bandwidth, bins, reflect=False, sig_scale=2,\n", - " bkg_scale=10, LUMI=10):\n", + "\n", + "def nn_summary_stat(\n", + " pars, data, nn, bandwidth, bins, reflect=False, sig_scale=2, bkg_scale=10, LUMI=10\n", + "):\n", " s_data, b_nom_data, b_up_data, b_down_data = data\n", "\n", " nn_s, nn_b_nom, nn_b_up, nn_b_down = (\n", @@ -151,7 +155,7 @@ "\n", " num_points = len(s_data)\n", "\n", - " yields =s, b_nom, b_up, b_down = [\n", + " yields = s, b_nom, b_up, b_down = [\n", " relaxed.hist(nn_s, bins, bandwidth, reflect_infinities=reflect)\n", " * sig_scale\n", " / num_points\n", @@ -178,7 +182,9 @@ " return yields\n", "\n", "\n", - "@partial(jit, static_argnames=[\"model\", \"return_mle_pars\", \"return_constrained_pars\"]) # forward pass\n", + "@partial(\n", + " jit, static_argnames=[\"model\", \"return_mle_pars\", \"return_constrained_pars\"]\n", + ") # forward pass\n", "def hypotest(\n", " test_poi: float,\n", " data: Array,\n", @@ -189,7 +195,9 @@ ") -> tuple[Array, Array] | Array:\n", " # hard-code 1 as inits for now\n", " # TODO: need to parse different inits for constrained and global fits\n", - " init_pars = jnp.asarray(model.config.suggested_init())[model.config.par_slice('correlated_bkg_uncertainty')]\n", + " init_pars = jnp.asarray(model.config.suggested_init())[\n", + " model.config.par_slice(\"correlated_bkg_uncertainty\")\n", + " ]\n", " conditional_pars = relaxed.mle.fixed_poi_fit(\n", " data, model, poi_condition=test_poi, init_pars=init_pars, lr=lr\n", " )\n", @@ -210,27 +218,34 @@ " else:\n", " return CLs\n", "\n", + "\n", "class Pipeline(NamedTuple):\n", " \"\"\"Class to compose the pipeline for training a learnable summary statistic.\"\"\"\n", + "\n", " yields_from_pars: Callable[..., tuple[Array, ...]]\n", " model_from_yields: Callable[..., pyhf.Model]\n", " init_pars: Array\n", " data: Array | None = None\n", " yield_kwargs: dict[str, Any] | None = None\n", - " nuisance_parname: str = 'correlated_bkg_uncertainty'\n", + " nuisance_parname: str = \"correlated_bkg_uncertainty\"\n", " random_state: int = 0\n", " num_epochs: int = 20\n", " batch_size: int = 500\n", " learning_rate: float = 0.001\n", - " optimizer: str = 'adam'\n", - " loss: Callable[[dict], float] = lambda x: x['CLs']\n", + " optimizer: str = \"adam\"\n", + " loss: Callable[[dict], float] = lambda x: x[\"CLs\"]\n", " test_size: float = 0.2\n", " per_epoch_callback: Callable = lambda x: None\n", - " first_epoch_callback: Callable = lambda x: None \n", - " last_epoch_callback: Callable = lambda x: None \n", - " post_training_callback: Callable = lambda x: None \n", - " plot_setup: Callable = lambda x: None \n", - " possible_metrics: tuple[str, ...] = ('CLs', 'mu_uncert', '1-pull_width**2', 'gaussianity')\n", + " first_epoch_callback: Callable = lambda x: None\n", + " last_epoch_callback: Callable = lambda x: None\n", + " post_training_callback: Callable = lambda x: None\n", + " plot_setup: Callable = lambda x: None\n", + " possible_metrics: tuple[str, ...] = (\n", + " \"CLs\",\n", + " \"mu_uncert\",\n", + " \"1-pull_width**2\",\n", + " \"gaussianity\",\n", + " )\n", " animate: bool = True\n", "\n", " def run(self):\n", @@ -241,15 +256,26 @@ " model = self.model_from_yields(*yields)\n", " state: dict[str, Any] = {}\n", " state[\"yields\"] = yields\n", - " bonly_pars = jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(0.0)\n", + " bonly_pars = (\n", + " jnp.asarray(model.config.suggested_init())\n", + " .at[model.config.poi_index]\n", + " .set(0.0)\n", + " )\n", " data = jnp.asarray(model.expected_data(bonly_pars))\n", - " state[\"CLs\"], constrained = hypotest(1.0, data, model, return_constrained_pars=True, bonly_pars=bonly_pars, lr=1e-2)\n", + " state[\"CLs\"], constrained = hypotest(\n", + " 1.0,\n", + " data,\n", + " model,\n", + " return_constrained_pars=True,\n", + " bonly_pars=bonly_pars,\n", + " lr=1e-2,\n", + " )\n", " uncerts = relaxed.cramer_rao_uncert(model, bonly_pars, data)\n", " state[\"mu_uncert\"] = uncerts[model.config.poi_index]\n", " pull_width = uncerts[model.config.par_slice(self.nuisance_parname)][0]\n", " state[\"pull_width\"] = pull_width\n", - " state[\"1-pull_width**2\"] = (1-pull_width) **2\n", - " #state[\"gaussianity\"] = relaxed.gaussianity(model, bonly_pars, data, rng_key=PRNGKey(self.random_state))\n", + " state[\"1-pull_width**2\"] = (1 - pull_width) ** 2\n", + " # state[\"gaussianity\"] = relaxed.gaussianity(model, bonly_pars, data, rng_key=PRNGKey(self.random_state))\n", " state[\"pull\"] = jnp.array(\n", " [\n", " (constrained - jnp.array(model.config.suggested_init()))[\n", @@ -263,12 +289,10 @@ " loss = self.loss(state)\n", " state[\"loss\"] = loss\n", " return loss, state\n", - " \n", + "\n", " if self.data is not None:\n", " split = train_test_split(\n", - " *self.data, \n", - " test_size=self.test_size, \n", - " random_state=self.random_state\n", + " *self.data, test_size=self.test_size, random_state=self.random_state\n", " )\n", " train, test = split[::2], split[1::2]\n", "\n", @@ -282,70 +306,83 @@ " while True:\n", " perm = rng.permutation(num_train)\n", " for i in range(num_batches):\n", - " batch_idx = perm[i * self.batch_size : (i + 1) * self.batch_size]\n", + " batch_idx = perm[\n", + " i * self.batch_size : (i + 1) * self.batch_size\n", + " ]\n", " yield [points[batch_idx] for points in train]\n", "\n", " batches = data_stream()\n", " else:\n", + "\n", " def blank_data():\n", " while True:\n", " yield None\n", + "\n", " batches = blank_data()\n", "\n", - " solver = jaxopt.OptaxSolver(fun=pipeline, opt=optax.adam(self.learning_rate), has_aux=True)\n", + " solver = jaxopt.OptaxSolver(\n", + " fun=pipeline, opt=optax.adam(self.learning_rate), has_aux=True\n", + " )\n", " params, state = solver.init(self.init_pars)\n", "\n", " plot_kwargs = self.plot_setup(self)\n", "\n", " for epoch_num in range(self.num_epochs):\n", " batch_data = next(batches)\n", - " print(f'{epoch_num=}: ', end=\"\")\n", + " print(f\"{epoch_num=}: \", end=\"\")\n", " start = time.perf_counter()\n", " params, state = solver.update(params=params, state=state, data=batch_data)\n", " end = time.perf_counter()\n", - " t = end-start\n", - " print(f'took {t:.4f}s. state:')\n", + " t = end - start\n", + " print(f\"took {t:.4f}s. state:\")\n", " pprint(state.aux)\n", " if epoch_num == 0:\n", " plot_kwargs[\"camera\"] = self.first_epoch_callback(\n", " params,\n", - " this_batch=batch_data, \n", - " metrics=state.aux, \n", - " maxN = self.num_epochs,\n", - " **self.yield_kwargs, \n", - " **plot_kwargs\n", + " this_batch=batch_data,\n", + " metrics=state.aux,\n", + " maxN=self.num_epochs,\n", + " **self.yield_kwargs,\n", + " **plot_kwargs,\n", " )\n", - " elif epoch_num == self.num_epochs-1:\n", + " elif epoch_num == self.num_epochs - 1:\n", " plot_kwargs[\"camera\"] = self.last_epoch_callback(\n", " params,\n", - " this_batch=batch_data, \n", - " metrics=state.aux, \n", - " maxN = self.num_epochs,\n", - " **self.yield_kwargs, \n", - " **plot_kwargs\n", + " this_batch=batch_data,\n", + " metrics=state.aux,\n", + " maxN=self.num_epochs,\n", + " **self.yield_kwargs,\n", + " **plot_kwargs,\n", " )\n", " else:\n", " plot_kwargs[\"camera\"] = self.per_epoch_callback(\n", " params,\n", - " this_batch=batch_data, \n", + " this_batch=batch_data,\n", " metrics=state.aux,\n", - " maxN = self.num_epochs,\n", - " **self.yield_kwargs, \n", - " **plot_kwargs\n", + " maxN=self.num_epochs,\n", + " **self.yield_kwargs,\n", + " **plot_kwargs,\n", " )\n", " if self.animate:\n", - " plot_kwargs[\"camera\"].animate().save(\"animation.gif\", writer=\"imagemagick\", fps=8)\n", + " plot_kwargs[\"camera\"].animate().save(\n", + " \"animation.gif\", writer=\"imagemagick\", fps=8\n", + " )\n", "\n", "\n", - "from jax.example_libraries import stax\n", "import jax.numpy as jnp\n", + "from jax.example_libraries import stax\n", "\n", "rng_state = 0\n", "\n", - "def gen_blobs(rng = PRNGKey(rng_state), num_points=10000, sig_mean=jnp.asarray([-1, 1]),\n", + "\n", + "def gen_blobs(\n", + " rng=PRNGKey(rng_state),\n", + " num_points=10000,\n", + " sig_mean=jnp.asarray([-1, 1]),\n", " bup_mean=jnp.asarray([2.5, 2]),\n", " bdown_mean=jnp.asarray([-2.5, -1.5]),\n", - " b_mean=jnp.asarray([1, -1])):\n", + " b_mean=jnp.asarray([1, -1]),\n", + "):\n", " sig = multivariate_normal(\n", " rng, sig_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)\n", " )\n", @@ -355,12 +392,13 @@ " bkg_down = multivariate_normal(\n", " rng, bdown_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)\n", " )\n", - " \n", + "\n", " bkg_nom = multivariate_normal(\n", " rng, b_mean, jnp.asarray([[1, 0], [0, 1]]), shape=(num_points,)\n", " )\n", " return sig, bkg_nom, bkg_up, bkg_down\n", - " \n", + "\n", + "\n", "init_random_params, nn = stax.serial(\n", " stax.Dense(1024),\n", " stax.Relu,\n", @@ -375,18 +413,18 @@ "p = Pipeline(\n", " yields_from_pars=nn_summary_stat,\n", " model_from_yields=make_model,\n", - " init_pars=init_pars, \n", + " init_pars=init_pars,\n", " data=gen_blobs(),\n", - " yield_kwargs=dict(nn=nn, bandwidth=1e-1, bins=jnp.linspace(0,1,5)),\n", + " yield_kwargs=dict(nn=nn, bandwidth=1e-1, bins=jnp.linspace(0, 1, 5)),\n", " random_state=rng_state,\n", " loss=lambda x: x[\"CLs\"],\n", " first_epoch_callback=first_epoch,\n", " last_epoch_callback=last_epoch,\n", " per_epoch_callback=per_epoch,\n", " plot_setup=mpl_setup,\n", - " num_epochs=5\n", + " num_epochs=5,\n", ")\n", - "p.run()\n" + "p.run()" ] }, { @@ -399,6 +437,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", + "\n", "def make_kde(data, bw):\n", " @jax.jit\n", " def get_kde(x):\n", @@ -408,7 +447,10 @@ "\n", " return get_kde\n", "\n", - "def bar_plot(ax, data, colors=None, total_width=0.8, single_width=1, legend=True, bins=None):\n", + "\n", + "def bar_plot(\n", + " ax, data, colors=None, total_width=0.8, single_width=1, legend=True, bins=None\n", + "):\n", " \"\"\"Draws a bar plot with multiple bars per data point.\n", "\n", " Parameters\n", @@ -482,7 +524,20 @@ " if legend:\n", " ax.legend(bars, data.keys(), fontsize=\"x-small\")\n", "\n", - "def plot(network, axs, axins, metrics, maxN, this_batch, nn, bins, bandwidth, legend=False, reflect=False):\n", + "\n", + "def plot(\n", + " network,\n", + " axs,\n", + " axins,\n", + " metrics,\n", + " maxN,\n", + " this_batch,\n", + " nn,\n", + " bins,\n", + " bandwidth,\n", + " legend=False,\n", + " reflect=False,\n", + "):\n", " ax = axs[\"Data space\"]\n", " g = np.mgrid[-5:5:101j, -5:5:101j]\n", " if jnp.inf in bins:\n", @@ -523,8 +578,8 @@ " ax = axs[\"Losses\"]\n", " # ax.axhline(0.05, c=\"slategray\", linestyle=\"--\")\n", " ax.plot(metrics[\"loss\"], c=\"C9\", linewidth=2.0, label=r\"train $\\log(CL_s)$\")\n", - " #ax.plot(metrics[\"test_loss\"], c=\"C4\", linewidth=2.0, label=r\"test $\\log(CL_s)$\") \n", - " #ax.set_yscale(\"log\")\n", + " # ax.plot(metrics[\"test_loss\"], c=\"C4\", linewidth=2.0, label=r\"test $\\log(CL_s)$\")\n", + " # ax.set_yscale(\"log\")\n", " # ax.set_ylim(1e-4, 0.06)\n", " ax.set_xlim(0, maxN)\n", " ax.set_xlabel(\"epoch\")\n", @@ -540,7 +595,7 @@ " label=r\"$\\sigma_{\\mathsf{nuisance}}$\",\n", " )\n", " ax.plot(metrics[\"mu_uncert\"], c=\"steelblue\", linewidth=2.0, label=r\"$\\sigma_\\mu$\")\n", - " ax.plot(metrics[\"CLs\"], c=\"C9\", linewidth=2, label=r'$CL_s$')\n", + " ax.plot(metrics[\"CLs\"], c=\"C9\", linewidth=2, label=r\"$CL_s$\")\n", " # ax.set_ylim(1e-4, 0.06)\n", " ax.set_xlim(0, maxN)\n", " ax.set_xlabel(\"epoch\")\n", @@ -572,7 +627,7 @@ " total_width=0.8,\n", " single_width=1,\n", " legend=legend,\n", - " bins=bins\n", + " bins=bins,\n", " )\n", " ax.set_ylabel(\"frequency\")\n", " ax.set_xlabel(\"interval over nn output\")\n", @@ -622,7 +677,7 @@ " pbins = bins\n", " ax.stairs(yields, pbins, label=\"KDE hist\", color=\"C1\")\n", " if reflect:\n", - " ax.plot(x, 2*jnp.abs(kde(x)), label=\"KDE\", color=\"C0\")\n", + " ax.plot(x, 2 * jnp.abs(kde(x)), label=\"KDE\", color=\"C0\")\n", " else:\n", " ax.plot(x, kde(x), label=\"KDE\", color=\"C0\")\n", "\n", @@ -644,7 +699,7 @@ "\n", " width = jnp.diff(noinf)[0]\n", " else:\n", - " width = jnp.diff(bins)[0] \n", + " width = jnp.diff(bins)[0]\n", " xlim = (\n", " [(width / 2) - (1.1 * bandwidth), (width / 2) + (1.1 * bandwidth)]\n", " if (width / 2) - bandwidth < 0\n", @@ -689,37 +744,60 @@ "metadata": {}, "outputs": [], "source": [ - "def first_epoch(network, camera, axs, axins, metrics, maxN, this_batch, nn, bins, bandwidth, **kwargs):\n", + "def first_epoch(\n", + " network,\n", + " camera,\n", + " axs,\n", + " axins,\n", + " metrics,\n", + " maxN,\n", + " this_batch,\n", + " nn,\n", + " bins,\n", + " bandwidth,\n", + " **kwargs,\n", + "):\n", " plot(\n", - " axs=axs, \n", - " axins=axins, \n", - " network=network, \n", - " metrics=metrics, \n", - " maxN=maxN, \n", - " this_batch=this_batch, \n", - " nn=nn,\n", - " bins=bins, \n", - " bandwidth=bandwidth, \n", - " legend=True\n", + " axs=axs,\n", + " axins=axins,\n", + " network=network,\n", + " metrics=metrics,\n", + " maxN=maxN,\n", + " this_batch=this_batch,\n", + " nn=nn,\n", + " bins=bins,\n", + " bandwidth=bandwidth,\n", + " legend=True,\n", " )\n", " plt.tight_layout()\n", " camera.snap()\n", " return camera\n", "\n", + "\n", "def last_epoch(\n", - " network, camera, axs, axins, metrics, maxN, this_batch, nn, bins, bandwidth, **kwargs \n", + " network,\n", + " camera,\n", + " axs,\n", + " axins,\n", + " metrics,\n", + " maxN,\n", + " this_batch,\n", + " nn,\n", + " bins,\n", + " bandwidth,\n", + " **kwargs,\n", "):\n", " plot(\n", - " axs=axs, \n", - " axins=axins, \n", - " network=network, \n", - " metrics=metrics, \n", - " maxN=maxN, \n", - " this_batch=this_batch, \n", - " nn=nn,\n", - " bins=bins, \n", - " bandwidth=bandwidth, \n", - " ) \n", + " axs=axs,\n", + " axins=axins,\n", + " network=network,\n", + " metrics=metrics,\n", + " maxN=maxN,\n", + " this_batch=this_batch,\n", + " nn=nn,\n", + " bins=bins,\n", + " bandwidth=bandwidth,\n", + " )\n", " plt.tight_layout()\n", " camera.snap()\n", " fig2, axs2 = plt.subplot_mosaic(\n", @@ -734,40 +812,54 @@ " axins2 = axs2[\"Example KDE\"].inset_axes([0.01, 0.79, 0.3, 0.2])\n", " axins2.axis(\"off\")\n", " plot(\n", - " axs=axs2, \n", - " axins=axins2, \n", - " network=network, \n", - " metrics=metrics, \n", - " maxN=maxN, \n", - " this_batch=this_batch, \n", - " nn=nn,\n", - " bins=bins, \n", - " bandwidth=bandwidth,\n", - " legend=True \n", - " ) \n", - " plt.tight_layout()\n", - " fig2.savefig(\n", - " f\"random.pdf\"\n", + " axs=axs2,\n", + " axins=axins2,\n", + " network=network,\n", + " metrics=metrics,\n", + " maxN=maxN,\n", + " this_batch=this_batch,\n", + " nn=nn,\n", + " bins=bins,\n", + " bandwidth=bandwidth,\n", + " legend=True,\n", " )\n", + " plt.tight_layout()\n", + " fig2.savefig(f\"random.pdf\")\n", " return camera\n", "\n", - "def per_epoch(network, camera, axs, axins, metrics, maxN, this_batch, nn, bins, bandwidth, **kwargs):\n", + "\n", + "def per_epoch(\n", + " network,\n", + " camera,\n", + " axs,\n", + " axins,\n", + " metrics,\n", + " maxN,\n", + " this_batch,\n", + " nn,\n", + " bins,\n", + " bandwidth,\n", + " **kwargs,\n", + "):\n", " plot(\n", - " axs=axs, \n", - " axins=axins, \n", - " network=network, \n", - " metrics=metrics, \n", - " maxN=maxN, \n", - " this_batch=this_batch, \n", - " nn=nn,\n", - " bins=bins, \n", - " bandwidth=bandwidth, \n", + " axs=axs,\n", + " axins=axins,\n", + " network=network,\n", + " metrics=metrics,\n", + " maxN=maxN,\n", + " this_batch=this_batch,\n", + " nn=nn,\n", + " bins=bins,\n", + " bandwidth=bandwidth,\n", " )\n", " plt.tight_layout()\n", " camera.snap()\n", " return camera\n", "\n", + "\n", "from celluloid import Camera\n", + "\n", + "\n", "def mpl_setup(pipeline):\n", " plt.style.use(\"default\")\n", "\n", @@ -802,7 +894,9 @@ " axins_cpy = axins\n", " if pipeline.animate:\n", " camera = Camera(fig)\n", - " return dict(camera=camera, axs=axs, axins=axins, ax_cpy=ax_cpy, axins_cpy=axins_cpy, fig=fig)" + " return dict(\n", + " camera=camera, axs=axs, axins=axins, ax_cpy=ax_cpy, axins_cpy=axins_cpy, fig=fig\n", + " )" ] }, {