Skip to content

Commit

Permalink
Fixed benchmarking demo
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholaskl97 committed Mar 7, 2025
1 parent c626cd9 commit 078cd5c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions docs/src/demos/benchmarking.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ opt = [OptimizationOptimisers.Adam(0.05), OptimizationOptimJL.BFGS()]
optimization_args = [[:maxiters => 300], [:maxiters => 300]]

# Run benchmark
endpoint_check = (x) -> ([sin(x[1]), cos(x[1]), x[2]], [0, -1, 0], atol = 5e-3)
cm, time = benchmark(
open_loop_pendulum_dynamics,
lb,
Expand All @@ -97,7 +98,8 @@ cm, time = benchmark(
state_syms,
parameter_syms,
policy_search = true,
endpoint_check = (x) -> ([sin(x[1]), cos(x[1]), x[2]], [0, -1, 0], atol=5e-3),
endpoint_check,
classifier = (V, V̇, x) ->< zero(V̇) || endpoint_check(x),
init_params = ps
)
```
Expand All @@ -112,7 +114,6 @@ using NeuralPDE, NeuralLyapunov, Lux
import Boltz.Layers: PeriodicEmbedding
using Random, StableRNGs
rng = StableRNG(0)
Random.seed!(200)
# Define dynamics and domain
Expand Down Expand Up @@ -142,7 +143,7 @@ chain = [Chain(
PeriodicEmbedding([1], [2π]),
Dense(3, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, 1, use_bias = false)
Dense(dim_hidden, 1)
) for _ in 1:dim_output]
ps = Lux.initialparameters(StableRNG(0), chain)
Expand Down Expand Up @@ -209,6 +210,7 @@ endpoint_check = (x) -> ≈([sin(x[1]), cos(x[1]), x[2]], [0, -1, 0], atol=5e-3)
parameter_syms,
policy_search = true,
endpoint_check,
classifier = (V, V̇, x) -> V̇ < zero(V̇) || endpoint_check(x),
verbose = true,
init_params = ps
);
Expand Down Expand Up @@ -238,7 +240,6 @@ all(endpoint_check.(endpoints) .== actual)
```
Similarly, the `predicted` labels are the results of the neural Lyapunov classifier.
In this case, we used the default classifier, which just checks for negative values of ``\dot{V}``.
```@example benchmarking
classifier = (V, V̇, x) ->< zero(V̇) || endpoint_check(x)
Expand Down

0 comments on commit 078cd5c

Please sign in to comment.