Skip to content

Commit

Permalink
improve benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
teddygroves committed Jan 29, 2025
1 parent b07a693 commit 3702bf9
Show file tree
Hide file tree
Showing 12 changed files with 136 additions and 121 deletions.
53 changes: 32 additions & 21 deletions benchmarks/analyse_results.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from pathlib import Path
from matplotlib import pyplot as plt
from matplotlib.legend_handler import HandlerPatch
import matplotlib.patches as mpatches

import matplotlib
import matplotlib.patches as mpatches
import numpy as np
import polars as pl
from matplotlib import pyplot as plt
from matplotlib import ticker
from matplotlib.legend_handler import HandlerPatch

HERE = Path(__file__).parent
CSV_FILE_METHIONINE = HERE / "methionine.csv"
Expand Down Expand Up @@ -77,26 +79,35 @@ def mm_fig(results_df: pl.DataFrame):


def performance_fig(results: pl.DataFrame):
f, ax = plt.subplots(figsize=[8, 5])
ax.set_ylim(ymin=0.0, ymax=3.0)
f, ax = plt.subplots(figsize=[5, 8])

models = results.group_by(["dim"]).agg(pl.col("model").first()).sort("dim")
model_names = [models["model"].to_list()[0]] + [
f"{m} dim {d}" for d, m in models.iter_rows() if m == "Rosenbrock"
models = (
results[["model", "dim"]]
.group_by(["model", "dim"])
.first()
.sort(["dim", "model"])
)
model_names = [
f"{m} dim {d}" if m == "Rosenbrock" else m for m, d in models.iter_rows()
]
models = models.with_columns(
xtick_loc=np.linspace(*ax.get_xlim(), len(model_names)) # type: ignore
ytick_loc=np.linspace(*ax.get_ylim(), len(model_names)) # type: ignore
)
results = results.join(models, on=["model", "dim"])
ax.scatter(results["xtick_loc"], results["perf_ratio"])
ax.axhline(1.0, linestyle="--", color="black", label="y=1")
ax.text(0.1, 1.05, "↑ grapeNUTS did better")
ax.text(0.1, 0.95, "↓ NUTS did better", verticalalignment="top")
ax.set_xticks(models["xtick_loc"], model_names, rotation=90)
ax.scatter(
results["perf_ratio"],
results["ytick_loc"],
color="black",
marker="|",
)
# ax.axvline(1.0, linestyle="--", color="gray")
ax.grid(visible=True, which="major", axis="x")
ax.set_yticks(models["ytick_loc"], model_names)
ax.set(
xlabel="Problem",
ylabel="Performance ratio grapeNUTS:NUTS",
xlabel="Performance ratio grapeNUTS:NUTS",
)
ax.semilogx()
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
return f, ax


Expand Down Expand Up @@ -171,20 +182,20 @@ def trajectory_fig(result: pl.DataFrame):
def main():
matplotlib.rcParams["savefig.dpi"] = 300
df_methionine = pl.read_csv(CSV_FILE_METHIONINE).with_columns(
model=pl.lit("Methionine"), dim=0
model=pl.lit("Methionine cycle"), dim=0
)
df_rb = pl.read_csv(CSV_FILE_ROSENBROCK).with_columns(model=pl.lit("Rosenbrock"))
df_linear = pl.read_csv(CSV_FILE_LINEAR).with_columns(
model=pl.lit("Small enzyme network"), dim=0
model=pl.lit("Toy reaction network"), dim=0
)
df_trajectory = pl.read_csv(CSV_FILE_TRAJECTORY)
df_performance = pl.concat([df_methionine, df_rb, df_linear], how="align")
df_performance = pl.concat([df_linear, df_methionine, df_rb], how="align")

f, _ = performance_fig(df_performance)
f.savefig(HERE / "performance.png", bbox_inches="tight")
f.savefig(HERE / "performance.png", bbox_inches="tight", dpi=300)

f, _ = trajectory_fig(df_trajectory)
f.savefig(HERE / "trajectory.png", bbox_inches="tight")
f.savefig(HERE / "trajectory.png", bbox_inches="tight", dpi=300)


if __name__ == "__main__":
Expand Down
7 changes: 7 additions & 0 deletions benchmarks/linear.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
neff_n,neff_gn,time_n,time_gn,perf_n,perf_gn,perf_ratio,rep
24072.294987635672,18974.690930263594,0.4788801670074463,0.4104759693145752,50267.88880831092,46226.07009601095,0.9195944208496036,0
18546.163737834835,23596.392322760545,0.47185778617858887,0.41691112518310547,39304.56226659674,56598.1354237,1.4399889519136135,1
18703.8835241752,21254.82381524762,0.5169639587402344,0.40186071395874023,36180.24662638732,52891.021881352375,1.4618756590447726,2
19247.542369171537,19974.481730038064,0.528895378112793,0.44822001457214355,36391.965529838264,44564.011156675064,1.2245563136768864,3
21932.88030739336,21220.67757871731,0.525414228439331,0.46139979362487793,41743.97859864186,45991.95290488124,1.1017625642989763,4
22924.323742632783,22761.506368702838,0.5135111808776855,0.4438629150390625,44642.306918129565,51280.48682936193,1.1486970627079434,5
2 changes: 1 addition & 1 deletion benchmarks/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ERROR_SD = 0.05
PARAM_SD = 0.02
HERE = Path(__file__).parent
CSV_OUTPUT_FILE = HERE / "linear_pathway.csv"
CSV_OUTPUT_FILE = HERE / "linear.csv"
TRUE_PARAMS = OrderedDict(
log_km=jnp.array([2.0, 2.0]),
log_vmax=jnp.array(3.0),
Expand Down
7 changes: 0 additions & 7 deletions benchmarks/linear_pathway.csv

This file was deleted.

7 changes: 7 additions & 0 deletions benchmarks/methionine.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
neff_n,neff_gn,time_n,time_gn,perf_n,perf_gn,perf_ratio,rep
20.27737146139156,39.21248579845454,1301.8673107624054,137.1648519039154,0.015575605358365312,0.2858785268541176,18.354248215500565,0
43.47606460201753,52.918514420523984,1314.7235043048859,143.45895671844482,0.03306859918428551,0.3688756396324757,11.154861370957882,1
11.713577435417788,36.1954448023344,1286.6312198638916,130.41230010986328,0.009104067470597308,0.2775462496393535,30.485961416226626,2
24.68993575956241,38.47905969445961,1364.909896850586,136.5894739627838,0.018089059077476356,0.28171321389629117,15.573679796704694,3
52.47129754448197,48.556463695294354,1266.5149881839752,147.97230100631714,0.041429669632034344,0.3281456283714978,7.920546586202278,4
53.804180683663056,23.76518734518563,1417.9942479133606,135.16810774803162,0.037943863850532694,0.17581948686806048,4.633673775571281,5
10 changes: 5 additions & 5 deletions benchmarks/methionine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
SEED = 1234
PARAM_SD = 0.01
HERE = Path(__file__).parent
CSV_OUTPUT_FILE = HERE / "methionine_pathway.csv"
CSV_OUTPUT_FILE = HERE / "methionine.csv"
DEFAULT_GUESS = jnp.full((5,), 0.01)
N_WARMUP = 5
N_SAMPLE = 5
N_TEST = 2
N_WARMUP = 200
N_SAMPLE = 100
N_TEST = 6
INIT_STEPSIZE = 0.0001
MAX_TREEDEPTH = 10
TARGET_ACCEPT = 0.95
Expand Down Expand Up @@ -225,12 +225,12 @@ def run_comparison(n_test: int):
keys = jax.random.split(key, n_test)
results = []
for i, keyi in enumerate(keys):
print(f"Starting methionine rep {i}...")
compare_key, param_key = jax.random.split(keyi)
params = generate_random_params(param_key, TRUE_PARAMS, PARAM_SD)
result = compare_single(compare_key, params)
result["rep"] = i
results.append(result)
print(results)
return pl.from_records(results)


Expand Down
3 changes: 0 additions & 3 deletions benchmarks/methionine_pathway.csv

This file was deleted.

Binary file modified benchmarks/performance.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 30 additions & 30 deletions benchmarks/rosenbrock.csv
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
neff_n,neff_gn,time_n,time_gn,perf_n,perf_gn,perf_ratio,rep,dim
815.8725693265565,747.9061376295996,0.7671618461608887,0.5332858562469482,1063.4947155016,1402.448853403056,1.318717275187952,0,3
933.6261760214178,800.1695941752808,1.0438041687011719,0.5813190937042236,894.4457246067038,1376.4722384680672,1.5389108590946812,1,3
563.2690241344768,706.7098671534744,0.6170749664306641,0.4874701499938965,912.8048531811037,1449.7500352838485,1.588236554868769,2,3
863.0291450297117,851.08978888785,0.7249410152435303,0.5081350803375244,1190.4818831912735,1674.9282264127894,1.406932982401111,3,3
761.0703994524756,890.1527548450604,0.819774866104126,0.5358800888061523,928.3895261016775,1661.1043653966804,1.7892321258423545,4,3
699.6110865496336,747.1214968109043,0.7218139171600342,0.5170590877532959,969.2402292577603,1444.9441359927862,1.4908008276743916,5,3
946.9785551839869,1140.3001338993158,0.8764090538024902,0.6602158546447754,1080.5211916459737,1727.1626027715533,1.598453242865641,0,4
994.4928042178033,989.5726606363419,0.8976240158081055,0.6525299549102783,1107.9168858048986,1516.5168329665514,1.3688001802272434,1,4
1029.9715292945789,1232.6572305493132,0.938391923904419,0.7282321453094482,1097.5920647410514,1692.6707211276964,1.5421674185727996,2,4
1000.3101159743064,1103.923473356662,0.8224389553070068,0.7095808982849121,1216.2727817299246,1555.7401221268683,1.2791046099988472,3,4
1053.1662099315238,1134.0464864488015,0.9139180183410645,0.7243170738220215,1152.3639853859336,1565.6768664374442,1.358665218883155,4,4
1160.8154574893363,1076.7717863541673,1.0613470077514648,0.7534339427947998,1093.7190655001723,1429.1522125482866,1.3066904085600046,5,4
1202.6948925923186,1271.5925261048094,1.0775160789489746,0.9050240516662598,1116.1734994854512,1405.0372736101901,1.2587982730802185,0,5
1366.5412330459715,1147.2881443078966,1.2379977703094482,0.920604944229126,1103.83174010434,1246.2328727428085,1.1290061949342098,1,5
1275.0851584749612,1219.0506909817154,1.3068222999572754,0.9426817893981934,975.7142639183983,1293.1730565835535,1.3253604097066938,2,5
1188.8232383584213,1172.8142023381902,1.3885371685028076,0.9618020057678223,856.1695468622362,1219.3925520064947,1.4242419115179108,3,5
1127.3659089544478,1032.8851111795698,1.2755961418151855,0.939612865447998,883.7953267483197,1099.2666758421835,1.2438023177681095,4,5
1232.6560163252846,1143.8731036584581,1.1626739501953125,0.9061949253082275,1060.1906201805039,1262.2815155021842,1.1906175092242122,5,5
1135.1303680216727,1354.6212274186726,1.592085838317871,1.119549036026001,712.9831449421108,1209.9704290105003,1.6970533421357967,0,6
1325.7763014754355,1258.4063196825066,1.3597478866577148,1.1100869178771973,975.016261826461,1133.61062040884,1.1626581676548555,1,6
1399.1998671661706,1624.819311075795,1.8864622116088867,1.1712651252746582,741.7057487585982,1387.2344322510069,1.8703298910286696,2,6
1407.5265009904317,1191.5440234733355,1.3895790576934814,1.0770809650421143,1012.9157410638735,1106.2715451727863,1.0921654194166837,3,6
1088.2650805926578,1384.2941802320329,1.7636387348175049,1.1499660015106201,617.0566903007308,1203.769657897359,1.9508250648910164,4,6
1181.7490992811279,1162.7188326540447,1.3521599769592285,1.0949628353118896,873.9713639052349,1061.879723363264,1.2150051674672537,5,6
1345.6598235336494,1506.9763931949703,1.7867820262908936,1.365283727645874,753.1191850676092,1103.782578434018,1.4656147397638382,0,7
1333.8150270487874,1288.5614190889294,1.6042728424072266,1.3258731365203857,831.4140785724362,971.858757520816,1.1689226614847894,1,7
1296.065759130828,1521.7268861130756,2.332318067550659,1.437284231185913,555.6985460786328,1058.7515350791043,1.905262381106335,2,7
1546.9896220415646,1202.4422799060758,1.8471908569335938,1.35001802444458,837.4822862698799,890.6860931732971,1.0635282772849863,3,7
1152.329063554528,1344.0600124139287,1.5863208770751953,1.3473069667816162,726.4161243840863,997.5900411355819,1.3733038235920474,4,7
1569.3214511403014,1140.123172008908,2.082958936691284,1.3904142379760742,753.4096920955725,819.9881307807257,1.0883695011939234,5,7
2484.527485755945,2239.284870166519,0.9704198837280273,0.5692501068115234,2560.26028260182,3933.745191035928,1.5364629986129097,0,3
2596.8320332964286,2428.84976461932,0.7447302341461182,0.5270538330078125,3486.9432100791046,4608.352339946491,1.3216023497675335,1,3
2723.7324360882694,2725.861370465307,0.992279052734375,0.5480010509490967,2744.9258639317363,4974.190041687548,1.8121400315572425,2,3
2350.790208511543,2700.6433318933814,0.7026560306549072,0.5768849849700928,3345.5775030074105,4681.424204572407,1.3992873279319267,3,3
2119.3119865564913,2254.5312566486787,0.6429357528686523,0.5589208602905273,3296.304455150518,4033.7217964575025,1.2237103251050614,4,3
2330.421273513586,2264.0839208700736,0.8159811496734619,0.5945878028869629,2855.974398974989,3807.820997802235,1.3332826089648648,5,3
4097.484636114787,4327.828518135458,0.948634147644043,0.7528088092803955,4319.35182418954,5748.907909662213,1.330965418808158,0,4
4162.750617603454,4425.727060668486,1.387017011642456,0.8006398677825928,3001.225350995568,5527.737549375016,1.8418268883212492,1,4
4103.453733975872,4403.466583062734,1.0865201950073242,0.7768878936767578,3776.6934777942256,5668.0849565341505,1.5008061919403082,2,4
4316.446209722598,3744.2136765182595,1.2622199058532715,0.7926173210144043,3419.7259841221116,4723.860527961154,1.3813564449006082,3,4
3976.2547303531464,4420.30436091055,1.2244141101837158,0.7700138092041016,3247.475422964975,5740.552062929165,1.7676968460897504,4,4
4032.9536547732723,4337.739795274984,0.9424960613250732,0.7427756786346436,4279.013802034638,5839.905532782841,1.3647783818799584,5,4
5865.224701156879,6064.8522293497235,1.4781160354614258,0.9600808620452881,3968.0407765320824,6317.022314588787,1.5919751510491345,0,5
7090.381338048299,4905.098866132659,1.7004039287567139,1.008202075958252,4169.8217806592465,4865.194173966144,1.1667630968144063,1,5
5654.5818926998,5948.447145918388,1.2057430744171143,0.9405319690704346,4689.707129716141,6324.55604012852,1.3486036260245815,2,5
5442.179877499188,5332.5668691126075,1.5342731475830078,0.9921491146087646,3547.0736655154506,5374.763521525094,1.5152669575989899,3,5
6279.544341114237,5697.365107202284,1.251161813735962,0.9442539215087891,5018.970585717889,6033.721414784988,1.2021830596008474,4,5
5620.281439769922,6776.364253580559,1.2560698986053467,0.9398901462554932,4474.497355609186,7209.740713398776,1.6112962284714987,5,5
7469.436170200057,7902.032189738578,1.3285129070281982,1.109236240386963,5622.403915449137,7123.849638181595,1.2670469331822307,0,6
8001.902348931977,8107.952698662346,1.3668029308319092,1.1361048221588135,5854.466776758806,7136.623787280215,1.2190049169996735,1,6
6415.749042948314,7914.99780278376,1.401257038116455,1.1499860286712646,4578.566864200911,6882.690402707791,1.5032412121186776,2,6
8442.629340907077,7352.496032524874,1.4687731266021729,1.1232349872589111,5748.0826602799225,6545.821770088869,1.1387835139048084,3,6
8055.3970679762015,7678.972896728729,1.5518858432769775,1.1408910751342773,5190.7149632645305,6730.680135984893,1.2966768901045285,4,6
7104.265255290618,7771.757600251429,1.490936040878296,1.1302399635314941,4764.969831372219,6876.201382906479,1.4430734351420365,5,6
1742.2567820337977,7000.0,72.57898283004761,0.45693111419677734,24.004976566198245,15319.595848282397,638.1841617731109,0,7
9236.737285452782,8913.521465679327,1.6470563411712646,1.3288938999176025,5608.027518284103,6707.474137876625,1.1960487205185686,1,7
9607.269430521472,10550.509841407467,1.7503888607025146,1.3637700080871582,5488.648634718582,7736.2823488146305,1.4095058481023155,2,7
11249.71811160961,9604.996995836098,1.584542989730835,1.363759994506836,7099.661028143258,7043.0259243008995,0.9920228439614434,3,7
10447.407426768466,8432.367931256344,1.698009967803955,1.3454980850219727,6152.736217608988,6267.0976830997415,1.0185870905961243,4,7
9438.321349342834,10837.868616790474,1.7301969528198242,1.3544340133666992,5455.056046631301,8001.769381035348,1.466853743139218,5,7
Loading

0 comments on commit 3702bf9

Please sign in to comment.