---
execute:
  echo: false
format:
  html:
    default-image-extension: png
  pdf:
    default-image-extension: pdf
jupyter: python3
---
# Data Analysis in High-Energy Physics as a Differentiable Program
This is the title track of this thesis, and rightly so; it dominated both the time I spent and the headspace I gave to any of the topics I've written about. I feel incredibly privileged to have worked on something like this, which is fairly self-contained, and which draws upon themes from both machine learning and statistical inference to make headway on a long-standing issue: *systematic-aware optimization*. What's even cooler is that it goes further than this, opening up a whole variety of possibilities to optimize with the whole statistical inference procedure in the loop, and to rethink the ways in which we can improve our workflows. I hope you enjoy it!
## Motivation
Given the success of the Standard Model, analysis of data from the LHC usually occurs for two reasons:
- Precisely measuring Standard Model processes to look for small deviations from their predicted values
- Searching for new physics signatures as predicted by models beyond the Standard Model
When analyzing data in this way, we'll have lots of free parameters to tune. These can be as simple as a threshold value applied to the $p_T$, or as complicated as the weights and biases of a neural network for identifying $b$-jets. We can of course choose any values for these quantities, but the physics that follows may suffer as a result. As such, we're likely to try some kind of optimization to improve the answers to our physics questions. How do we do this in practice?
In either case above, there is a notion of <span style="color:#13becf">signal</span> (what you’re looking for) and <span style="color:#ff7f0e">background</span> (everything else).
Generally, we then try to choose a parameter configuration that can separate (or discriminate) the signal from the background, allowing us to extract just the data we think is relevant to the physics process we're looking at. As an example, machine learning models are often trained using the **binary cross-entropy** loss as an objective, which corresponds to optimizing the ability of the model to identify whether an event originated from signal or background processes. A closely related goal is the **Asimov significance** in the case of signal and background event counts $s$ and $b$ with *no uncertainty* on either quantity. The formula for this stems from assuming a Poisson likelihood function as in @sec-hifa, and is equal to
$$
Z_A = \sqrt{2\sum_{i\in bins}((s_i + b_i)(\log{(1 + s_i / b_i)}) - s_i)}~.
$$ {#eq-asimov-significance}
As indicated in the sum, these counts can be spread across different bins in the case where your data is a histogram, but the formula is more commonly reduced to the 1-bin scenario that just deals with the overall numbers of signal and background events. In this case, we can then Taylor expand the logarithm to get
$$Z_A = \sqrt{2\left((s+b)\left(s/b + \mathcal{O}\left((s/b)^2\right)\right) - s\right)} \approx s/\sqrt{b}~~~\mathrm{for}~s\ll b.$$
This makes it much clearer to see that optimizing with respect to $Z_A$ is just a fancier way of trying to increase the amount of signal compared to the amount of background, which is directly analogous to separating signal from background, just as binary cross-entropy would do.
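As a quick numerical sanity check (a minimal sketch; `asimov_significance` here is just an illustrative helper, not a particular library function), we can compare the full single-bin expression from @eq-asimov-significance with the $s/\sqrt{b}$ approximation:
```{python}
#| echo: true
#| eval: false
import jax.numpy as jnp

def asimov_significance(s, b):
    # single-bin version of @eq-asimov-significance
    return jnp.sqrt(2 * ((s + b) * jnp.log(1 + s / b) - s))

s, b = 10.0, 400.0                # a regime where s << b
print(asimov_significance(s, b))  # ~0.498
print(s / jnp.sqrt(b))            # 0.5 -- the approximation holds well here
```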
Now, this is all very sensible of course (we want to discover our signal), but this approach has some shortcomings that distance the efficacy of the resulting configuration from our physics goals. A recent review of deep learning in LHC physics [@deeplhc] lets us in on why:
> (...) tools are often optimized for performance on a particular task that is **several steps removed from the ultimate physical goal** of searching for a new particle or testing a new physical theory.
> (...) sensitivity to high-level physics questions **must account for systematic uncertainties**, which involve a nonlinear trade-off between the typical machine learning performance metrics and the systematic uncertainty estimates.
This is the crux of the issue: we're not accounting for uncertainty. Our data analysis process comes with many sources of systematic error, which we endeavour to model in the likelihood function as nuisance parameters. However, optimizing with respect to any of the above quantities isn't going to be aware of that process. We need something better.
Okay, I hear you: blah blah this is all just talk... let's prove this scenario needs addressing with an example!
### A simplified analysis example, both with and without uncertainty {#sec-simple-anal}
Let's define an analysis with a predicted number of signal and background events (e.g. from simulation), with some uncertainty on the background estimate. We'll abstract the analysis configuration into a single parameter $\phi$ like so:
$$s = 15 + \phi $$
$$b = 45 - 2 \phi $$
$$\sigma_b = 0.5 + 0.1\phi^2 $$
Note that $s$ increases with $\phi$ while $b$ decreases with $\phi$, so increasing $\phi$ corresponds to increasing the signal/background ratio. However, our uncertainty scales like $\phi^2$, so we're also going to compromise on our certainty of the background count as we do that. This kind of tradeoff between the $s/b$ ratio and the uncertainty is important for the discovery of a new signal, so it may be that we can't get away with optimizing $s/b$ alone, as the $p$-value may be worse!
Let's start by visualizing the model itself, which we do for three values of $\phi$ as an example in @fig-simple-model.
```{python}
#| label: fig-simple-model
#| fig-cap: "Plot of the predicted counts from our model at three values of $\\phi$."
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from jaxopt import OptaxSolver
import relaxed
from functools import partial
import matplotlib.lines as mlines
from plothelp import autogrid
subplot_settings = dict(figsize=[7, 3], dpi=150, tight_layout=True)
# model definition
def yields(phi, uncertainty=True):
s = 15 + phi
b = 45 - 2 * phi
db = (
0.5 + 0.1 * phi**2 if uncertainty else jnp.zeros_like(phi) + 0.001
) # small enough to be negligible
return jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db])
# just plotting code
def yield_plot(dct):
ax, phi, i = list(dct.values())
s, b, db = yields(phi)
s, b, db = s.ravel(), b.ravel(), db.ravel() # everything is [[x]] for pyhf
ax.set_ylim((0, 80))
b1 = ax.bar(0.5, b, facecolor="C1", label="b")
b2 = ax.bar(0.5, s, bottom=b, facecolor="C9", label="s")
b3 = ax.bar(
0.5, db, bottom=b - db / 2, facecolor="k", alpha=0.5, label=r"$\sigma_b$"
)
ax.set_title(r"$\phi = $" + f'{phi}')
ax.set_xlabel("x")
if i ==0 :
ax.set_ylabel("yield")
ax.set_xticks([])
if i==2:
ax.legend([b1, b2, b3], ["b", "s", r"$\sigma_b$"], frameon=False, bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0)
autogrid(
[0,5,10],
yield_plot,
subplot_kwargs={**subplot_settings, **dict(sharex=True, sharey=True)},
);
```
Using this very simple histogram, we can form a statistical model as if we're using @sec-hifa principles, which would look something like
$$
p(x, y \,|\, \mu, \gamma) = \mathrm{Poisson}(x \,|\, \mu x^{\mathrm{sig}} + \gamma x^{\mathrm{bkg}})\,\mathrm{Normal}(y \,|\, \gamma, 1)~,
$$ {#eq-simplemodel}
where $\gamma$ is a nuisance parameter that continuously parametrizes the effect of $\sigma_b$ on the background yield, just like in the HistFactory approach, and the constraint term $\mathrm{Normal}(y | \gamma, 1)$ is attached to penalize fitting a value of $\gamma$ that differs greatly from the auxiliary information provided by $\sigma_b$.
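To make the structure concrete, here is a minimal sketch of the log-likelihood implied by @eq-simplemodel (illustrative only -- the calculation below actually uses `relaxed.dummy_pyhf` for compatibility with the rest of the pipeline, and the argument layout here is an assumption for the sketch):
```{python}
#| echo: true
#| eval: false
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson

def log_likelihood(pars, data, s, b):
    # pars = (mu, gamma); data = (observed counts n, auxiliary measurement y)
    mu, gamma = pars
    n, y = data
    expected = mu * s + gamma * b                      # Poisson mean per bin
    main = poisson.logpmf(n, expected).sum()           # main measurement term
    constraint = norm.logpdf(y, loc=gamma, scale=1.0)  # constraint term on gamma
    return main + constraint
```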
Using this likelihood, we can calculate the expected discovery $p$-value by doing a hypothesis test using the observed data as the Asimov dataset for the nominal model $\mu, \gamma = 1$. We can plot this across all the values of $\phi$, and see what value gives us the lowest $p$-value (in practice, scanning over the space like this is computationally prohibitive for a realistic analysis configuration and a complicated model). We do this in @fig-simple-model-pval, where we include the result using a model both with and without uncertainty. Notice how much the curves differ; if we optimized the model without uncertainty (i.e. optimized for signal/background separation only), we'd end up at the *worst* solution! This is pathologically constructed of course, but it goes to show that these objectives don't talk to each other directly.
```{python}
#| label: fig-simple-model-pval
#| fig-cap: "Plot of the calculated $p$-value from using our statistical model across of $\\phi$, both including the uncertainty and neglecting it."
# our analysis pipeline, from phi to p-value
def pipeline(phi, return_yields=False, uncertainty=True):
# grab the yields at the value of phi we want
y = yields(phi, uncertainty=uncertainty)
# use a dummy version of pyhf for simplicity + compatibility with jax
model = relaxed.dummy_pyhf.uncorrelated_background(*y)
# calculate expected discovery significance
nominal_pars = jnp.array([1.0, 1.0]) # sets gamma, mu =1 in gamma*b + mu*s
data = model.expected_data(nominal_pars) # Asimov data
# do the hypothesis test (and fit model pars with gradient descent)
pvalue = relaxed.infer.hypotest(
0.0, # value of mu for the alternative hypothesis (background-only)
data,
model,
test_stat="q0", # discovery significance test
lr=1e-3, # learning rate for the minimization loop
expected_pars=nominal_pars, # optionally providing MLE pars in advance
)
if return_yields:
return pvalue, y
else:
return pvalue
# calculate p-values for a range of phi values
phis = jnp.linspace(0, 10, 100)
# with uncertainty
pipe = partial(pipeline, return_yields=True, uncertainty=True)
pvals, ys = jax.vmap(pipe)(phis) # map over phi grid
# without uncertainty
pipe_no_uncertainty = partial(pipeline, uncertainty=False)
pvals_no_uncertainty = jax.vmap(pipe_no_uncertainty)(phis)
fig, ax = plt.subplots(**subplot_settings)
axs = [ax]
axs[0].plot(phis, pvals, label="with uncertainty", color="C0")
axs[0].plot(phis, pvals_no_uncertainty, label="no uncertainty", color="C2")
axs[0].set_ylabel("$p$-value")
# plot vertical dotted line at minimum of p-values + s/b
best_phi = phis[jnp.argmin(pvals)]
axs[0].axvline(x=best_phi, linestyle="dotted", color="C2", label="optimal p-value")
axs[0].axvline(
x=phis[jnp.argmin(pvals_no_uncertainty)],
linestyle="dotted",
color="C4",
label=r"optimal $s/b$",
)
axs[0].legend(loc="upper left", ncol=2)
axs[0].set_xlabel("$\phi$")
plt.suptitle("Discovery p-values, with and without uncertainty")
plt.tight_layout()
```
If we optimize this analysis, then, we want to arrive at the value of $\phi$ at the dotted green line (around $\phi \approx 4.3$), which gives us the benefit of rejecting the background hypothesis more strongly when the signal exists in the data. This is made possible if we use the $p$-value as our objective -- it clearly accounts for the uncertainty!
The reason for this makes sense: in these physics likelihoods, we're careful to include all the details of the systematic uncertainties that we're able to quantify by constructing nuisance parameters that vary the shape and normalization of the model. From here, to calculate the $p$-value, we then construct the **profile likelihood ratio** as a test statistic, which accounts for these systematic uncertainties by fitting the value of the nuisance parameters depending on the hypothesis you test (see @sec-hyptests for more).
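For reference, the discovery test statistic used here (`test_stat="q0"` above) follows the usual profile likelihood ratio construction:
$$
q_0 = \begin{cases} -2\ln\dfrac{L(\mu=0,\, \hat{\hat{\theta}}(0))}{L(\hat{\mu},\, \hat{\theta})} & \hat{\mu} \geq 0\,, \\[6pt] 0 & \hat{\mu} < 0\,, \end{cases}
$$
where $\hat{\mu}, \hat{\theta}$ are the unconditional maximum-likelihood estimates and $\hat{\hat{\theta}}(0)$ denotes the nuisance parameters re-fitted (profiled) under the background-only hypothesis -- it's this profiling step that propagates the systematic uncertainties into the $p$-value.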
All this makes the $p$-value seem like a good candidate for an objective function! So why haven't we used this already?
As emphasized in @sec-gradient-descent, if we want to perform optimization using gradient-based methods,^[We don't have to use gradient-based methods! They're just very well implemented and studied, and they enable paradigms like the one explored here.] then we need the objective that we optimize to be *differentiable*. This is not immediately the case for the $p$-value -- we would have to be able to differentiate through all stages of the full calculation, including model building, profiling, and even histograms, which are not generally known for their smoothness. But say we were able to decompose this complicated pipeline into bite-size chunks, each of which we can find a way to take gradients of. What becomes possible then? This begins our view of **data analysis in high-energy physics as a differentiable program**.
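As a reminder of what automatic differentiation buys us here (a generic, minimal sketch with toy stand-in functions): once each bite-size chunk is differentiable, the chain rule across the whole composition comes for free.
```{python}
#| echo: true
#| eval: false
import jax
import jax.numpy as jnp

# two toy "pipeline stages" composed into one function
g = lambda x: jnp.log(1.0 + x)
f = lambda x: jnp.sqrt(x) + 1.0
pipeline = lambda x: f(g(x))

# jax applies the chain rule through the composition for us
print(jax.grad(pipeline)(2.0))
```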
In the following sections, we'll take a collider physics analysis apart step-by-step, then see how we can employ tricks and substitutes to recover gradients for each piece. After that, we'll explore the ways that we can use the result to perform gradient-based optimization of different parts of the analysis with respect to physics goals. We'll then do it all at once by *optimizing a toy physics analysis from end-to-end*, exploring the common example of a summary statistic based on a neural network, accounting for uncertainties all the while.
<!-- ## Related work
The only real analog to what's done in this section is INFERNO [@inferno], a similar method for inference aware optimization. In terms of -->
<!-- ### How do we optimize in an uncertainty-aware way?
Attempts:
- Asimov sig with assumptions on bkg uncert: [@asimovuncert]
- Learning to pivot: [@pivot]
- Directly incorporate NPs: [@uncert] -->
## Making HEP Analysis Differentiable
The goal of this section is to study components within a HEP analysis chain that are not typically differentiable, and show that when we overcome this, we can employ gradient-based optimization methods -- both to optimize free parameters jointly, and to use objectives we care about. From there, we'll examine the typical steps needed to calculate the sensitivity of a physics analysis, and see how we can make that whole chain differentiable at once, opening up a way to incorporate the full inference procedure when finding the best analysis configuration.
First, we're going to jump right in with an example to illustrate how we can take advantage of gradient descent to optimize a typical problem faced in collider physics analyses: choosing the best selection criteria.
### A simple example: cut optimization with gradient descent
We begin with a toy signal and background distribution over some variable $x$, where the signal lies as a peak on top of an exponentially decaying background, as shown in @fig-exp-bkg.
```{python}
#| label: fig-exp-bkg
#| fig-cap: "Histogram of a situation with a simple exponentially falling background and a small signal peak."
import relaxed
from functools import partial
from jax.random import PRNGKey
# generate background data from an exponential distribution with a little noise
def generate_background(key, n_samples, n_features, noise_std):
key, subkey = jax.random.split(key, 2)
data = jax.random.exponential(subkey, (n_samples, n_features))
key, subkey = jax.random.split(key, 2)
data += jax.random.normal(subkey, (n_samples, n_features)) * noise_std
return data
# generate signal data from a normal distribution close to the background
def generate_signal(key, n_samples, n_features):
key, subkey = jax.random.split(key, 2)
data = jax.random.normal(subkey, (n_samples, n_features)) / 2 + 2
return data
# get 1000 samples from the background and 100 samples from the signal
bkg = generate_background(PRNGKey(0), 1000, 1, 0.1).ravel()
sig = generate_signal(PRNGKey(1), 100, 1).ravel()
sig = sig[sig>0]
bkg = bkg[bkg>0]
fig, ax = plt.subplots(**subplot_settings)
# plot!
ax.hist(
[bkg, sig], stacked=True, bins=30, histtype="step", label=["background", "signal"]
)
ax.set_xlabel("x")
ax.set_ylabel("count")
ax.legend();
```
A quintessential operation for data filtering in HEP is the simple threshold, also called a **cut**: we keep all data above (or below) a certain value of the quantity we're concerned with. To increase the significance (e.g. as defined by @eq-asimov-significance), we can try to remove data such that we increase the overall ratio of signal to background. In @fig-exp-bkg, it looks like there's not much signal for low values of $x$, which motivates us to put a cut at say $x=1$. We can see the result of applying this cut in @fig-compare-cut, where we've increased the Asimov significance compared to using no cut at all.
```{python}
#| label: fig-compare-cut
#| fig-cap: "Comparing the significance resulting from applying a cut to no cut at all."
def significance_after_cut(cut):
# treat analysis as a one-bin counting experiment
s = len(sig[sig > cut]) + 1e-1
b = len(bkg[bkg > cut]) + 1e-1
return relaxed.metrics.asimov_sig(s, b) # stat-only significance
cut = 1 # change me to change the plot!
def make_cut_plot(cut, ax):
significance = significance_after_cut(cut)
ax.hist(
[bkg, sig], stacked=True, bins=30, histtype="step", label=["background", "signal"]
)
ax.axvline(x=cut, color="k", linestyle="--", alpha=0.5, label=f"cut = {cut:.2f}")
ax.axvspan(0,cut, hatch='//', color="grey", alpha=0.3,zorder=-999)
ax.text(
0.7,
0.2,
f"significance = {significance:.2f}",
ha="center",
va="center",
transform=ax.transAxes,
)
ax.set_xlabel("x")
ax.legend()
fig, axs = plt.subplots(1,2,**subplot_settings, sharey=True)
ax = axs[0]
# plot!
ax.hist(
[bkg, sig], stacked=True, bins=30, histtype="step", label=["background", "signal"]
)
ax.set_xlabel("x")
ax.set_ylabel("count")
significance = significance_after_cut(0)
ax.text(
0.7,
0.2,
f"significance = {significance:.2f}",
ha="center",
va="center",
transform=ax.transAxes,
)
ax.legend()
make_cut_plot(cut, axs[1])
```
We had a nice go at a guess, but how do we pick the *best* cut? For this simple problem, it suffices to scan over the different significances we'll get by cutting at each value of $x$, then just use the value with the highest significance. Doing this leads to the optimal cut being around $x=1.54$.
<!--
```{python}
#| label: fig-cut-scan
#| fig-cap: "A scan over all cut values to find the best resulting Asimov significance."
cut_values = jnp.linspace(0, 8, 100)
significances_hard = jnp.array([significance_after_cut(cut) for cut in cut_values])
fig, ax = plt.subplots(**subplot_settings)
ax.plot(cut_values, significances_hard, label="significance")
optimal_cut = cut_values[jnp.argmax(significances_hard)]
ax.axvline(x=optimal_cut, color="k", linestyle="--", alpha=0.5, label="optimal cut")
ax.text(
0.7,
0.5,
f"optimal cut = {optimal_cut:.2f}",
ha="center",
va="center",
transform=plt.gca().transAxes,
)
ax.set_xlabel("x")
ax.set_ylabel(r"$Z_A$")
```
-->
In reality, though, this could be an expensive procedure to do for a wide range of $x$ and for many different cut variables. This prompts the search for some kind of intelligent optimization that can handle high-dimensional parameter spaces. Gradient descent is just that! But, to make it work, we need to be able to calculate the gradient of the significance with respect to the cut value -- something only possible if the cut itself is differentiable (it isn't).
To see this, note that cuts are step functions, i.e. logical less-than or greater-than statements. These can be viewed as applying weights to the data -- 0 on one side of the threshold, and 1 on the other. If we change the cut value, the events either keep their weight (no change in significance) or sharply gain/lose their weight (a discrete jump in significance). We would then like to replace this thresholding with a *smooth* weight assignment such that the weights applied vary smoothly with the cut value. What kind of operation can do this? We have such a candidate in the *sigmoid function* $1/(1+e^{-x})$.
Normally, the sigmoid serves to map values on the real line to $[0,1]$; we leverage this as a cut by applying it to the data, which assigns each point a weight in $[0,1]$. (A normal cut does this too, but the weights are all 0 or 1, and you drop the 0s. One could similarly threshold on a minimum weight value here.)
Practically, we introduce a slope term and a shift term that control how "hard" the cut is and where it sits in $x$: $1/(1+e^{-\mathrm{slope}\,(x - \mathrm{cut~value})})$. The slope allows us to control the degree to which we approximate the cut as a thresholding operation, with higher values of the slope meaning less approximation (but also a higher variance of the gradients, as we get closer to the discrete situation outlined previously). See the sigmoid plotted with different slopes in @fig-sigmoid.
```{python}
#| label: fig-sigmoid
#| fig-cap: "Comparing the sigmoid to a regular hard cut for different values of the sigmoid slope."
# plot significance for all cut values
cut_val = 5 # translates on the x-axis
fig, ax = plt.subplots(**subplot_settings)
ax.plot(cut_values, cut_values > cut_val, label="hard cut")
ax.plot(
cut_values, relaxed.cut(cut_values, cut_val, slope=1), label="sigmoid (slope=1)", color="C1"
)
ax.plot(
cut_values,
relaxed.cut(cut_values, cut_val, slope=10),
label="sigmoid (slope=10)",
color="C2",
alpha=0.4,
)
ax.plot(
cut_values,
relaxed.cut(cut_values, cut_val, slope=0.5),
label="sigmoid (slope=0.5)",
color="C3",
alpha=0.4,
)
ax.set_ylabel("weight applied at x")
ax.set_xlabel("x")
ax.legend();
```
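Written as a standalone function, such a relaxed cut might look like the minimal sketch below (the `relaxed.cut` helper used in the plot above plays this role; this is not necessarily how it is implemented internally):
```{python}
#| echo: true
#| eval: false
import jax
import jax.numpy as jnp

def soft_cut(x, cut_val, slope):
    # smooth per-event weights: ~0 below the cut, ~1 above, transition width set by slope
    return jax.nn.sigmoid(slope * (x - cut_val))

# the weighted "count" above the cut is now differentiable w.r.t. the cut position
count_above = lambda cut_val: soft_cut(jnp.linspace(0, 8, 100), cut_val, slope=2.0).sum()
print(jax.grad(count_above)(1.0))  # a well-defined (negative) gradient
```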
Now that we have a differentiable cut, we can see what the significance scan looks like for both the differentiable and standard cases, shown in @fig-cut-scan-2. It's an interesting plot; there's a clear smoothing out of the overall envelope of the significance in comparison to using the hard cut. However, the important thing is the **coincidence of the maxima**: when optimizing, we'll use the differentiable cut, but we'll plug the cut position from the optimization back in to the hard cut for our actual physics results. This is a very important distinction - *we don't use approximate operations in the final calculation!* Moreover, since we can control the degree to which we approximate the significance landscape, one could even imagine fine-tuning the slope when we're close to a local minimum during optimization, allowing us to make jumps more in line with the true optimum value (though this is not explored here).
```{python}
#| label: fig-cut-scan-2
#| fig-cap: "A scan over all cut values to find the best resulting Asimov significance -- both for the regular cut, and for the sigmoid."
# plot significance for all cut values
def significance_after_soft_cut(cut, slope):
s_weights = (
relaxed.cut(sig, cut, slope) + 1e-4
) # add small offset to avoid 0 weights
b_weights = relaxed.cut(bkg, cut, slope) + 1e-4
return relaxed.metrics.asimov_sig(s_weights.sum(), b_weights.sum())
# choosing the cut slope: increasing the slope reduces bias, but also adds noise to the gradients.
# I increased it until gradients were nan in the next step, then went a touch lower.
# I'll think about a more principled way to do this (suggestions welcome!)
slope = 2.7
fig, ax = plt.subplots(**subplot_settings)
# plot significance for all cut values
cut_values = jnp.linspace(0, 8, 100)
soft = partial(significance_after_soft_cut, slope=slope)
significances = jax.vmap(soft, in_axes=(0))(cut_values)
ax.plot(cut_values, significances_hard, label="hard cut")
ax.plot(cut_values, significances, label="sigmoid (slope=2.7)")
ax.set_xlabel("cut value")
ax.set_ylabel("$Z_A$")
ax.legend();
```
Now that we've done the groundwork, we can do the optimization and see if we converge to the correct result! Using gradient descent with the Adam optimizer and a learning rate of 1e-2, we find the cut shown in @fig-optimized-cut (we optimize $1/Z_A$ since we're doing minimization). The significance (calculated with the *hard* cut) is extremely close to the best possible value, so I'd call this a success!
```{python}
#| label: fig-optimized-cut
#| fig-cap: "The resulting cut from optimization compared to the true best cut. Significances in both cases are shown."
from jaxopt import OptaxSolver
from optax import adam
# define something to minimise (1/significance)
def loss(cut):
s_weights = relaxed.cut(sig, cut, slope) + 1e-4
b_weights = relaxed.cut(bkg, cut, slope) + 1e-4
return 1 / relaxed.metrics.asimov_sig(s_weights.sum(), b_weights.sum())
fig, ax = plt.subplots(**subplot_settings)
# play with the keyword arguments to the optimiser if you want :)
solver = OptaxSolver(loss, adam(learning_rate=1e-2), maxiter=10000, tol=1e-6)
init = 6.0
cut_opt = solver.run(init).params
significance = significance_after_cut(cut_opt)
ax.hist(
[bkg, sig], stacked=True, bins=30, histtype="step", label=["background", "signal"]
)
ax.axvline(
x=cut_opt,
color="r",
linestyle="-",
alpha=0.5,
label=f"optimised cut = {cut_opt:.2f}",
)
significance = significance_after_cut(cut_opt)
ax.axvline(
x=optimal_cut,
color="k",
linestyle="--",
alpha=0.5,
label=f"true best cut = {optimal_cut:.2f}",
)
ax.text(
0.65,
0.3,
f"significance at optimised cut = {significance:.2f}",
ha="center",
va="center",
transform=plt.gca().transAxes,
)
ax.text(
0.65,
0.15,
f"significance at best cut = {significance_after_cut(optimal_cut):.2f}",
ha="center",
va="center",
transform=plt.gca().transAxes,
)
ax.set_xlabel("x")
ax.set_ylabel("count")
ax.legend();
```
### Examining a typical analysis
Now that we've looked at an example of the kind of thing we may want to do, we can zoom out and look at the big picture. Given a pre-filtered dataset, a commonly used analysis pipeline in HEP involves the
following stages:
1. Construction of a learnable 1-D summary statistic from data (with
parameters $\varphi$)
2. Binning of the summary statistic, e.g. through a histogram
3. Statistical model building, using the summary statistic as a
template
4. Calculation of a test statistic, used to perform a frequentist
hypothesis test of signal versus background
5. A $p$-value (or $\mathrm{CL_s}$ value) resulting from that
hypothesis test, used to characterize the sensitivity of the
analysis
We can express this workflow as a direct function of the input dataset
$\mathcal{D}$ and observable parameters $\varphi$:
$$
\mathrm{CL}_s = f(\mathcal{D},\varphi) = (f_{\mathrm{sensitivity}} \circ f_{\mathrm{test\,stat}} \circ f_{\mathrm{likelihood}} \circ f_{\mathrm{histogram}} \circ f_{\mathrm{observable}})(\mathcal{D},\varphi).
$$ {#eq-neos}
Is this going to be differentiable? To calculate $\partial \text{CL}_s / \partial \varphi$, we'll have to split this up by the chain rule into the different components, which can be written verbosely as
$$
\frac{\partial\,\mathrm{CL}_s}{\partial \varphi} = \frac{\partial f_{\mathrm{sensitivity}}}{\partial f_{\mathrm{test\,stat}}}\frac{\partial f_{\mathrm{test\,stat}}}{\partial f_{ \mathrm{likelihood}}} \frac{\partial f_{\mathrm{likelihood}}}{\partial f_{\mathrm{histogram}}} \frac{\partial f_{\mathrm{histogram}}}{\partial f_{\mathrm{observable}}} \frac{\partial f_{\mathrm{observable}}}{\partial \varphi}~.
$${#eq-analysis-chain-rule}
In the case of an observable that has well-defined gradients with respect to $\varphi$ (e.g. a neural network), the last term in @eq-analysis-chain-rule is possible to calculate through automatic differentiation. But none of the other terms are differentiable by default! We're going to have to figure out some way to either *relax* (make differentiable) these operations, or use tricks to make the gradient easier to calculate. This is explored in the following sections, starting with the histogram.
### Binned density estimation (histograms) {#sec-bkde}
Histograms are discontinuous by nature. They are defined for 1-D data as a set of two quantities: intervals (or *bins*) over the domain of that data, and counts of the number of data points that fall into each bin. For small changes in the underlying data distribution, bin counts will either remain static, or jump in integer intervals as data migrate between bins, both of which result in ill-defined gradients. Similarly to the cut example with the sigmoid, we're assigning a number (there the weight, and here a count in a bin) in a discrete way to the data -- to make this differentiable, we need to come up with a smooth version of this that allows gradients to be calculated across the result.
To say a little more to that effect, we'll look at the types of gradients that we may be interested in. Say we have a data distribution that depends on some latent parameter $\mu$, e.g. data drawn from $\mathrm{Normal}(\mu, 1)$. We can then make a histogram of the resulting data. What happens to that histogram when we shift the value of $\mu$? Well, shifting the mean will just translate the histogram along the $x$-axis; an example of this is shown in @fig-hist-mus for a couple of values of $\mu$ (with the random seed kept constant).
```{python}
#| label: fig-hist-mus
#| fig-cap: "Translating a histogram from left to right by varying the center of the distribution the data is drawn from."
import jax
import jax.numpy as jnp
from jax.random import normal, PRNGKey
rng = PRNGKey(7)
from matplotlib.colors import to_rgb
import matplotlib.pyplot as plt
subplot_settings = dict(figsize=[7,3],dpi=150,facecolor='w', tight_layout=True)
from functools import partial
lo, hi = -2, 2
grid_points = 500
mu_grid = jnp.linspace(lo, hi, grid_points)
num_samples = 100
points = jnp.tile(
normal(rng, shape = (num_samples,)),
reps = (grid_points,1)
) + mu_grid.reshape(-1,1)
bins = jnp.linspace(lo-3,hi+3,6)
make_hists = jax.vmap(partial(jnp.histogram, bins = bins))
hists, _ = make_hists(points)
centers = bins[:-1] + jnp.diff(bins) / 2.0
width = (bins[-1] - bins[0])/(len(bins) - 1)
fig, axs = plt.subplots(1,3, **subplot_settings)
# first mu value
axs[0].bar(
centers,
hists[0],
width = width,
label=f'$\mu$={mu_grid[0]}'
)
axs[0].legend()
# axs[0].axis('off')
axs[0].set_xlabel('x')
# middle mu value
axs[1].bar(
centers, hists[len(hists)//2],
width = width,
label=f'$\mu$={mu_grid[len(hists)//2]:.2f}',
color = 'C1'
)
axs[1].legend()
# axs[1].axis('off')
axs[1].set_xlabel('x')
# last mu value
axs[2].bar(
centers,
hists[-1],
width = width,
label=f'$\mu$={mu_grid[-1]}',
color = 'C2'
)
axs[2].legend()
# axs[2].axis('off');
axs[2].set_xlabel('x');
```
Let us now shift our focus to a single bin: we'll choose the bin centered on 0, and monitor its height as we vary $\mu$, shown in @fig-bin-height-mu.
```{python}
#| label: fig-bin-height-mu
#| fig-cap: "Demonstrating the shift in the central histogram bin as $\\mu$ is varied from -2 to 2."
middle = len(bins)//2 - 1
mu_width = mu_grid[1]-mu_grid[0]
fig, ax = plt.subplots(**subplot_settings)
ax.bar(mu_grid, hists[:,middle], width=mu_width,alpha=0.7, color='C1')
ax.set_title("Height of central histogram bin as a function of $\mu$")
ax.set_ylabel("bin height")
ax.set_xlabel('$\mu$');
```
We can see that the bin height jumps around in discrete intervals as we translate the underlying data, which would produce ill-defined gradient estimates if we used something numerical like finite differences. To exploit the magic of automatic differentiation here, we want to make some other function such that this envelope becomes smooth; varying $\mu$ by a very small amount should also vary the bin height by a small amount instead of leaving it static or jumping discontinuously.
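We can see this failure mode directly (a minimal sketch; `hard_bin_count` is a hypothetical helper for illustration): differentiating a hard bin count with respect to a shift in the data gives a gradient of exactly zero almost everywhere, since the count only changes in discrete jumps.
```{python}
#| echo: true
#| eval: false
import jax
import jax.numpy as jnp
from jax.random import normal, PRNGKey

def hard_bin_count(mu, data, lo=-1.0, hi=1.0):
    # number of shifted points landing in [lo, hi) -- one "hard" histogram bin
    shifted = data + mu
    in_bin = (shifted >= lo) & (shifted < hi)
    # mu only enters through the boolean condition, which carries no gradient
    return jnp.where(in_bin, 1.0, 0.0).sum()

data = normal(PRNGKey(0), shape=(100,))
print(jax.grad(hard_bin_count)(0.5, data))  # 0.0 -- flat between the discrete jumps
```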
The solution that we developed to address this involves a **kernel density estimate** (KDE). We discussed this in @sec-kde, but just to recap: a KDE is essentially the average of a set of normal distributions centered at each data point, with their width controlled by a global parameter called the **bandwidth**. There's a neat way to take this and cast it into a bin-like form (i.e. defined over intervals): we can calculate the "count" in an interval by taking the area under the KDE between the interval endpoints. We can do this using the cumulative distribution function (cdf), as $P(a \leqslant X \leqslant b) = P(X \leqslant b) - P(X \leqslant a)$. Since the KDE is the mean over some normal distributions, its cdf is also just the mean of the cdfs of each normal distribution. Moreover, to turn this into a histogram-like object, we can multiply the result by the total number of events, which just changes the mean into a sum. We put this all together in @fig-bkde-code, where a pseudocoded implementation of a **binned KDE** (bKDE) can be found.
```{python}
#| label: fig-bkde-code
#| fig-cap: none
#| echo: true
#| eval: false
from jax import Array
from jax.scipy.stats import norm

def bKDE(data: Array, bins: Array, bandwidth: float) -> Array:
    edge_hi = bins[1:]   # ending bin edges ||<-
    edge_lo = bins[:-1]  # starting bin edges ->||
    # get cumulative counts (area under kde) for each set of bin edges
    cdf_hi = norm.cdf(edge_hi.reshape(-1, 1), loc=data, scale=bandwidth)
    cdf_lo = norm.cdf(edge_lo.reshape(-1, 1), loc=data, scale=bandwidth)
    return (cdf_hi - cdf_lo).sum(axis=1)  # sum cdf differences over each kernel
```
```{python}
#| label: fig-bin-height-bKDE
#| fig-cap: "Demonstrating the shift in the central histogram bin as $\\mu$ is varied from -2 to 2 for both a regular histogram and a bKDE."
import jax.scipy as jsc
def kde_hist(events, bins, bandwidth=None, density=False):
edge_hi = bins[1:] # ending bin edges ||<-
edge_lo = bins[:-1] # starting bin edges ->||
# get cumulative counts (area under kde) for each set of bin edges
cdf_up = jsc.stats.norm.cdf(edge_hi.reshape(-1, 1), loc=events, scale=bandwidth)
cdf_dn = jsc.stats.norm.cdf(edge_lo.reshape(-1, 1), loc=events, scale=bandwidth)
# sum kde contributions in each bin
counts = (cdf_up - cdf_dn).sum(axis=1)
if density: # normalize by bin width and counts for total area = 1
db = jnp.array(jnp.diff(bins), float) # bin spacing
return counts / db / counts.sum(axis=0)
return counts
# make hists as before
bins = jnp.linspace(lo-3,hi+3,6)
make_kde_hists = jax.vmap(partial(kde_hist, bins = bins, bandwidth = .5))
kde_hists = make_kde_hists(points)
middle = len(bins)//2 - 1
mu_width = mu_grid[1]-mu_grid[0]
fig, axs = plt.subplots(2,1, sharex=True, **subplot_settings)
axs[0].bar(
mu_grid,
hists[:,middle],
# fill=False,
color = 'C1',
width = mu_width,
alpha = .7,
label = 'histogram',
)
axs[0].legend(frameon=False)
axs[0].set_ylabel("bin height")
axs[0].set_title("Height of central histogram bin as a function of $\mu$")
axs[1].bar(
mu_grid,
kde_hists[:,middle],
color = 'C0',
width = mu_width,
alpha = .7,
label = 'bKDE',
)
axs[1].set_ylabel("bin height")
axs[1].legend(frameon=False)
plt.xlabel('$\mu$');
```
Using this, we can remake the plot from @fig-bin-height-mu for the bKDE, which we can see in @fig-bin-height-bKDE, showing that the variation of the bin height with $\mu$ is much better behaved.
#### Choosing the bandwidth {-}
I'll show a few studies here that illustrate what happens to the accuracy of the bKDE histogram from the perspective of both the distribution and the resulting gradients.
We know what happens to a KDE when we change the bandwidth: small bandwidth gives a function with high variance, and a large bandwidth oversmooths the distribution. How do these effects impact the bKDE? We can quantify this *relative to the bin width* by examining the shape of the bKDE relative to a "hard" histogram, which is shown in @fig-bkde-bandwidth. For low bandwidths, we recover something almost resembling a regular histogram. In fact, in the limit of zero bandwidth, we will *exactly* get a histogram! The reason is that zero bandwidth would turn each normal distribution into an infinite spike at each data point, which, when integrated over to get the counts, would have a contribution of 1 if the event lies in the bin, and 0 otherwise^[For non-uniform bin widths, an extension of the bKDE to non-uniform bandwidths could be interesting -- one could keep the bin width/bandwidth ratio fixed for each bin, and if the event falls in a given bin, the resulting bandwidth from using that ratio is applied to that event. This would make the analogy between bin width and bandwidth more general in some ways, albeit at the cost of someone's coding time.].
![Illustration of the bias/smoothness tradeoff when tuning the bandwidth of a bKDE, defined over 200 samples from a bi-modal Gaussian mixture. All distributions are normalized to unit area. The individual kernels that make up the KDE are scaled down for visibility.](images/relaxed_hist){#fig-bkde-bandwidth}
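The zero-bandwidth limit described above is also easy to check numerically (a minimal sketch reusing the `kde_hist` function defined earlier): with a bandwidth far smaller than the bin width, the bKDE counts agree with `jnp.histogram` to high precision.
```{python}
#| echo: true
#| eval: false
import jax.numpy as jnp
from jax.random import normal, PRNGKey

data = normal(PRNGKey(0), shape=(1000,))
check_bins = jnp.linspace(-3, 3, 7)

hard, _ = jnp.histogram(data, bins=check_bins)
soft = kde_hist(data, check_bins, bandwidth=1e-4)  # bandwidth << bin width

print(hard)
print(jnp.round(soft))  # essentially identical counts
```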
This idea of a bias/variance tradeoff with the bandwidth is the gist of it, but there's an additional factor that will influence the value of the bandwidth chosen: the number of data samples available. We may expect that as we add more samples to a KDE, there will be a lot more kernels centered on the new points, so we'd want to reduce the bandwidth in order to faithfully represent the envelope of the distribution. We then can inspect the degree to which this also continues to hold for the bKDE; it may be that good defaults for KDEs differ slightly compared to those for bKDEs.
First, let's examine the distribution accuracy as a function of bandwidth and number of data samples. We can define this by looking at the "true" histogram, which can be calculated using the cumulative distribution of $\mathrm{Normal}(\mu, 1)$ in a way analogous to the bKDE (i.e. the integral under the curve over the intervals defined by the bins), which we then normalize to the number of data samples available. We can then plot the true height of the central bin as it varies with $\mu$, and compare it to that obtained from the histogram and bKDE estimates across a number of different settings for the sample size and bandwidth. These plots are shown in @fig-bin-height-all, which looks at bandwidths of 0.05, 0.5, and 0.8 in tandem with sample sizes of 20, 100, and 5000. As expected, we see that the low bandwidth case has the histogram and bKDE predictions for the bin mostly agreeing, while they diverge for larger bandwidths. The best-case scenario appears to be when we have a large number of samples and a low bandwidth, which is when we'd expect all three estimates to converge. If we choose a bandwidth too large though, we're going to introduce a bias as we oversmooth the data features.
```{python}
#| label: fig-bin-height-all
#| fig-cap: "Demonstrating the shift in the central histogram bin as a function of bandwidth and the number of samples for a histogram and bKDE, which are compared to the true bin height."
def true_hist(bins, mu):
edge_hi = bins[1:] # ending bin edges ||<-
edge_lo = bins[:-1] # starting bin edges ->||
# get cumulative counts (area under curve) for each set of bin edges
cdf_up = jsc.stats.norm.cdf(edge_hi.reshape(-1, 1), loc=mu)
cdf_dn = jsc.stats.norm.cdf(edge_lo.reshape(-1, 1), loc=mu)
counts = (cdf_up - cdf_dn).T
return counts
truth = true_hist(bins,mu_grid)*100
# make hists as before (but normalize)
bins = jnp.linspace(lo-3,hi+3,6)
make_kde_hists = jax.vmap(partial(kde_hist, bins = bins, bandwidth = .5, density=False))
kde_hists = make_kde_hists(points)
make_hists = jax.vmap(partial(jnp.histogram, bins = bins, density = False))
hists, _ = make_hists(points)
def make_points(num_samples, grid_points=300, lo=-2, hi=+2):
mu_grid = jnp.linspace(lo, hi, grid_points)
rngs = [PRNGKey(i) for i in range(9)]
points = jnp.asarray(
[
jnp.tile(
normal(rng, shape = (num_samples,)),
reps = (grid_points,1)
) + mu_grid.reshape(-1,1) for rng in rngs
]
)
return points, mu_grid
def make_kdes(points, bandwidth, bins):
make_kde_hists = jax.vmap(
partial(kde_hist, bins = bins, bandwidth = bandwidth)
)
return make_kde_hists(points)
def make_mu_scan(bandwidth, num_samples, grid_points=500, lo=-2, hi=+2):
points, mu_grid = make_points(num_samples, grid_points, lo, hi)
bins = jnp.linspace(lo-3,hi+3,6)
truth = true_hist(bins,mu_grid)*num_samples
get_kde_hists = jax.vmap(partial(make_kdes, bins=bins, bandwidth=bandwidth))
kde_hists = get_kde_hists(points)
make_hists = jax.vmap(jax.vmap(partial(jnp.histogram, bins = bins)))
hists, _ = make_hists(points)
study_bin = len(bins)//2 - 1
h = jnp.array([truth[:,study_bin],
hists[:,:,study_bin].mean(axis=0),
kde_hists[:,:,study_bin].mean(axis=0)])
stds = jnp.array([hists[:,:,study_bin].std(axis=0),
kde_hists[:,:,study_bin].std(axis=0)])
return h, stds
bws = jnp.array([0.05,0.5,0.8])
lo_samp = jax.vmap(partial(make_mu_scan, num_samples = 20))
mid_samp = jax.vmap(partial(make_mu_scan, num_samples = 100))
hi_samp = jax.vmap(partial(make_mu_scan, num_samples = 5000))
lo_hists, lo_stds = lo_samp(bws)
mid_hists, mid_stds = mid_samp(bws)
hi_hists, hi_stds = hi_samp(bws)
# colors = fade('C0','C9',num_points=7)
fig, axarr = plt.subplots(3,len(bws), sharex=True, sharey='row', figsize=[7,4], dpi=150)
up, mid, down = axarr
for i,res in enumerate(zip(lo_hists, lo_stds)):
hists, stds = res
up[i].plot(mu_grid,hists[0],alpha=.4, color='C3',label="actual", linestyle=':')
up[i].fill_between(mu_grid, hists[1]+stds[0], hists[1]-stds[0], alpha=.2,color='C1',label='histogram variance')
up[i].plot(mu_grid,hists[1],alpha=.4, color='C1',label="histogram")
up[i].fill_between(mu_grid, hists[2]+stds[1], hists[2]-stds[1], alpha=.2,color='C0')
up[i].plot(mu_grid,hists[2],alpha=.6,color='C0',label="bKDE")
up[i].set_title(f'bw = {bws[i]:.2f}', color='C0')
for i,res in enumerate(zip(mid_hists, mid_stds)):
hists, stds = res
mid[i].plot(mu_grid,hists[0],alpha=.4, color='C3',label="true bin height", linestyle=':')
mid[i].fill_between(mu_grid, hists[1]+stds[0], hists[1]-stds[0], alpha=.2,color='C1',label='histogram $\pm$ std')
mid[i].plot(mu_grid,hists[1],alpha=.4,color='C1',label="histogram")
mid[i].fill_between(mu_grid, hists[2]+stds[1], hists[2]-stds[1], alpha=.2,color='C0',label='bKDE $\pm$ std')
mid[i].plot(mu_grid,hists[2],alpha=.6,color='C0',label="bKDE")
for i,res in enumerate(zip(hi_hists, hi_stds)):
hists, stds = res
down[i].plot(mu_grid,hists[0],alpha=.4, color='C3',label="actual", linestyle=':')
down[i].fill_between(mu_grid, hists[1]+stds[0], hists[1]-stds[0], alpha=.2,color='C1')
down[i].plot(mu_grid,hists[1],alpha=.4, color='C1',label="histogram")
down[i].fill_between(mu_grid, hists[2]+stds[1], hists[2]-stds[1], alpha=.2,color='C0')
down[i].plot(mu_grid,hists[2],alpha=.6,color='C0',label="bKDE")
#down[0].set_ylabel('n=1e6', rotation=0, size='large')
down[1].set_xlabel("$\mu$",size='large')
down[0].set_ylabel("yield",size='large',labelpad=11)
mid[0].set_ylabel("yield",size='large',labelpad=11)
up[0].set_ylabel("yield",size='large',labelpad=11)
# mid[0].set_ylabel("frequency",size='large',labelpad=11)
mid[-1].legend(bbox_to_anchor=(1.1, 1.05), frameon=False)
rows = [f"{s} samples" for s in [20, 100, 5000]]
for ax, row in zip(axarr[:,0], rows):
ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad , 0),
xycoords=ax.yaxis.label, textcoords='offset points',
size='large', ha='right', va='center', color='C2')
# plt.suptitle("Height of central histogram bin as a function of $\mu$")
fig.tight_layout()
# tight_layout doesn't take these labels into account. We'll need
# to make some room. These numbers are are manually tweaked.
# You could automatically calculate them, but it's a pain.
fig.subplots_adjust(left=0.15, top=0.95)
plt.subplots_adjust(hspace=0.2);
```
So far things all seem to follow intuition somewhat, but we've only checked half the picture; the whole reason we're using the bKDE construct in the first place is so we can access *gradients* of the histogram yields. To study these, we can derive the "true" gradients from the definition of the bin height: as before, a bin defined by $(a,b)$ for a given $\mu$ value is just
$$\operatorname{yield}_{\mathsf{true}}(\mu; a,b) = \Phi(b;\mu, \sigma) - \Phi(a;\mu, \sigma) ~,$$
where $\Phi(x; \mu, \sigma)$ is the cumulative distribution function of a normal distribution with mean $\mu$ and standard deviation $\sigma$. We can then just take the gradient of this expression with respect to $\mu$ by hand. First we write the explicit definition of the cdf:
$$\Phi(x; \mu, \sigma) = \frac{1}{2}\left[1+\operatorname{erf}\left(\frac{x-\mu}{\sigma \sqrt{2}}\right)\right]~,$$
where the convenient shorthand of the error function $\operatorname{erf}$ is given by
$$
\operatorname{erf}(x) \equiv \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2}\, dt~.
$$
Then, the derivative is as follows:
$$\frac{\partial}{\partial\mu}\Phi(x;\mu, \sigma) = -\frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{(x-\mu)^2}{2\sigma^2}}~,$$
since $\frac{d}{dx} \operatorname{erf}(x)=\frac{2}{\sqrt{\pi}} e^{-x^{2}}$.
As mentioned, we have $\sigma=1$ in this particular example, making this expression simpler:
$$ \frac{\partial}{\partial\mu}\Phi(x;\mu, \sigma=1) = -\frac{1}{\sqrt{2\pi}}\, e^{-\frac{(x-\mu)^2}{2}}~.$$
Putting this all together gives us
$$\Rightarrow \frac{\partial}{\partial\mu}\operatorname{yield}_{\mathsf{true}}(\mu; a,b) = -\frac{1}{\sqrt{2\pi}}\left[\left(e^{-\frac{(b-\mu)^2}{2}}\right) - \left( e^{-\frac{(a-\mu)^2}{2}}\right)\right]~,$$
which we can use to quantify the accuracy of the gradients obtained from a bKDE against the true rate of change of the expected count in the interval $(a,b)$.
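As a sanity check on the algebra (a minimal sketch), we can compare this closed-form gradient with automatic differentiation applied directly to the CDF difference:
```{python}
#| echo: true
#| eval: false
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

a, b = -1.0, 1.0  # e.g. the central bin from earlier

def true_yield(mu):
    return norm.cdf(b, loc=mu) - norm.cdf(a, loc=mu)

def analytic_grad(mu):
    # the closed-form expression derived above (sigma = 1)
    return -(jnp.exp(-((b - mu) ** 2) / 2) - jnp.exp(-((a - mu) ** 2) / 2)) / jnp.sqrt(2 * jnp.pi)

mu = 0.7
print(jax.grad(true_yield)(mu), analytic_grad(mu))  # these agree
```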
The comparative plots between the true gradient, the histogram gradient, and the bKDE gradient are shown in @fig-bin-grad-all, where the histogram gradient is calculated using the finite differences method, and the bKDE gradient with automatic differentiation. A similar trend to @fig-bin-height-all can be seen, where the estimate from the bKDE improves with more samples, and becomes much less noisy. This is in contrast to the histogram, which struggles with gradients unless the sample size is large (here 5000), and produces very high variance estimates in general. The bKDE, however, is able to avoid this high variance while keeping a reasonably low bias depending on how many samples are present; the central plot, for instance, shows a case where a bKDE with bandwidth 0.5 does a far better job of estimating the true gradient from just 100 samples than the erratic estimates obtained from the regular histogram.
```{python}
#| label: fig-bin-grad-all
#| fig-cap: "Variation of the gradient of the central histogram bin as a function of bandwidth and the number of samples for a histogram and bKDE, which are compared to the true gradient of the yield."
def true_grad(mu,bins):
b = bins[1:] # ending bin edges ||<-
a = bins[:-1] # starting bin edges ->||
return -(1/((2*jnp.pi)**0.5))*(jnp.exp(-((b-mu)**2)/2) - jnp.exp(-((a-mu)**2)/2))
def gen_points(mu, jrng, nsamples):
points = normal(jrng, shape = (nsamples,))+mu
return points
def bin_height(mu, jrng, bw, nsamples, bins):
points = gen_points(mu, jrng, nsamples)
return kde_hist(points, bins, bandwidth=bw)[2]
def kde_grads(bw, nsamples, lo=-2, hi=+2, grid_size=300):
bins = jnp.linspace(lo-3,hi+3,6)
mu_grid = jnp.linspace(lo,hi,grid_size)
rngs = [PRNGKey(i) for i in range(9)]
grad_fun = jax.grad(bin_height)
grads = []
for i,jrng in enumerate(rngs):
get_grads = jax.vmap(partial(
grad_fun, jrng=jrng, bw=bw, nsamples=nsamples, bins=bins
))
grads.append(get_grads(mu_grid))
return jnp.asarray(grads)
def get_hist(mu, jrng, nsamples, bins):
points = gen_points(mu, jrng, nsamples)
hist, _ = jnp.histogram(points, bins)
return hist[2]
def hist_grad_numerical(bin_heights, mu_width):
# in mu plane
lo = bin_heights[:-1]
hi = bin_heights[1:]
bin_width = (bins[1]-bins[0])
grad_left = -(lo-hi)/mu_width
# grad_right = -grad_left
return grad_left
def hist_grads(nsamples, lo=-2, hi=+2, grid_size=300):
bins = jnp.linspace(lo-3,hi+3,6)
mu_grid = jnp.linspace(lo,hi,grid_size)
rngs = [PRNGKey(i) for i in range(9)]
grad_fn = partial(hist_grad_numerical, mu_width=mu_grid[1]-mu_grid[0])
grads = []
for jrng in rngs:
get_heights = jax.vmap(partial(
get_hist, jrng=jrng, nsamples=nsamples, bins=bins
))
grads.append(grad_fn(get_heights(mu_grid)))
return jnp.asarray(grads)
def both_grads(bw, nsamples, lo=-2, hi=+2, grid_size=300):
bins = jnp.linspace(lo-3,hi+3,6)
mu_grid = jnp.linspace(lo,hi,grid_size)
hist_grad_fun = partial(hist_grad_numerical, mu_width=mu_grid[1]-mu_grid[0])
grad_fun = jax.grad(bin_height)
hist_grads = []
kde_grads = []
rngs = [PRNGKey(i) for i in range(3)]
for jrng in rngs:
get_heights = jax.vmap(partial(
get_hist, jrng=jrng, nsamples=nsamples, bins=bins
))
hist_grads.append(hist_grad_fun(get_heights(mu_grid)))
get_grads = jax.vmap(partial(
grad_fun, jrng=jrng, bw=bw, nsamples=nsamples, bins=bins
))
kde_grads.append(get_grads(mu_grid))
hs = jnp.array(hist_grads)
ks = jnp.array(kde_grads)
h = jnp.array([hs.mean(axis=0),hs.std(axis=0)])
k = jnp.array([ks.mean(axis=0),ks.std(axis=0)])
return h,k
bws = jnp.array([0.05,0.5,0.8])
samps = [20,100,5000]
grid_size = 60
lo_samp = jax.vmap(partial(both_grads, nsamples = samps[0],grid_size=grid_size))
mid_samp = jax.vmap(partial(both_grads, nsamples = samps[1],grid_size=grid_size))
hi_samp = jax.vmap(partial(both_grads, nsamples = samps[2],grid_size=grid_size))
lo_hist, lo_kde = lo_samp(bws)
mid_hist, mid_kde = mid_samp(bws)
hi_hist, hi_kde = hi_samp(bws)
mu_grid = jnp.linspace(-2,2,grid_size)
true_grad_many = jax.vmap(partial(true_grad, bins = bins))
true = [true_grad_many(mu_grid)[:,2]*s for s in samps]
fig, axarr = plt.subplots(3,len(bws), sharex=True, sharey='row', figsize=[7,4], dpi=150)
up, mid, down = axarr
for i,res in enumerate(zip(lo_hist, lo_kde)):
hist_grads, hist_stds = res[0]
kde_grads, kde_stds = res[1]
up[i].plot(mu_grid,true[0],alpha=.4, color='C3',label="actual", linestyle=':')
y = jnp.array(up[i].get_ylim())
up[i].plot(mu_grid[:-1], hist_grads,alpha=.7, color='C1',label="histogram",linewidth=0.6)
up[i].fill_between(mu_grid, kde_grads+kde_stds, kde_grads-kde_stds, alpha=.2,color='C0',label='kde histogram $\pm$ std')
up[i].plot(mu_grid,kde_grads,alpha=.6,color='C0',label="kde histogram")
up[i].set_title(f'bw = {bws[i]:.2f}', color='C0')
up[i].set_ylim(y*1.3)
for i,res in enumerate(zip(mid_hist, mid_kde)):
hist_grads, hist_stds = res[0]
kde_grads, kde_stds = res[1]
mid[i].plot(mu_grid,true[1],alpha=.4, color='C3',label="true gradient", linestyle=':')
y = jnp.array(mid[i].get_ylim())
mid[i].plot(mu_grid[:-1], hist_grads,alpha=.7, color='C1',label="histogram",linewidth=0.6)
mid[i].fill_between(mu_grid, kde_grads+kde_stds, kde_grads-kde_stds, alpha=.2,color='C0',label='bKDE $\pm$ std')
mid[i].plot(mu_grid,kde_grads,alpha=.6,color='C0',label="bKDE")
mid[i].set_ylim(y*1.3)
for i,res in enumerate(zip(hi_hist, hi_kde)):
hist_grads, hist_stds = res[0]
kde_grads, kde_stds = res[1]
down[i].plot(mu_grid,true[2],alpha=.4, color='C3',label="actual", linestyle=':')
y = jnp.array(down[i].get_ylim())
down[i].plot(mu_grid[:-1], hist_grads,alpha=.7, color='C1',label="histogram")
down[i].fill_between(mu_grid, kde_grads+kde_stds, kde_grads-kde_stds, alpha=.2,color='C0')
down[i].plot(mu_grid,kde_grads,alpha=.6,color='C0',label="kde histogram")