Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ParallelDynamicalSystem for StroboscopicMap and other API #225

Merged
merged 9 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/derived_systems/parallel_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,23 @@ function ParallelDynamicalSystem(ds::CoreDynamicalSystem, states::Vector{<:Abstr
pds = CoupledODEs(prob, ds.diffeq; internalnorm = inorm)
end
M = ds isa CoupledODEs && isinplace(ds)
prob = referrenced_sciml_prob(ds)
prob = referrenced_sciml_prob(ds)
return ParallelDynamicalSystemAnalytic{typeof(pds), M}(pds, dynamic_rule(ds), prob)
end

function ParallelDynamicalSystem(smap::StroboscopicMap,states)
f, st = parallel_rule(smap.ds, states)
T = eltype(first(st))
prob = ODEProblem{true}(f, st, (T(initial_time(smap)), T(Inf)), current_parameters(smap))
inorm = prob.u0 isa Matrix ? matrixnorm : vectornorm
cont_pds = CoupledODEs(prob, smap.ds.diffeq; internalnorm = inorm)
pds = StroboscopicMap(cont_pds,smap.period)

M = smap.ds isa CoupledODEs && isinplace(smap.ds)
prob = referrenced_sciml_prob(smap.ds)
return ParallelDynamicalSystemAnalytic{typeof(pds), M}(pds, dynamic_rule(smap), prob)
end

function ParallelDynamicalSystem(ds::CoreDynamicalSystem, mappings::Vector{<:Dict})
# convert to vector of arrays:
u = Array(current_state(ds))
Expand Down Expand Up @@ -164,6 +177,14 @@ function set_state!(pdsa::PDSAM, u::AbstractArray, i::Int = 1)
return pdsa
end


function set_state!(pdsa::PDSAM{<: StroboscopicMap}, u::AbstractArray, i::Int = 1)
current_state(pdsa, i) .= u
u_modified!(pdsa.ds.ds.integ, true)
return pdsa
end


# We make one more extension here: for continuous time, in place systems
# the state is a matrix (each column a parallel state) for performance.
# re-init will not work because there is no way to do the recursive copy. we do it ourselves
Expand Down Expand Up @@ -210,7 +231,7 @@ current_states(pdtds::PDTDS) = [current_state(ds) for ds in pdtds.systems]
initial_states(pdtds::PDTDS) = [initial_state(ds) for ds in pdtds.systems]

# Set stuff
set_parameter!(pdtds::PDTDS) = for ds in pdtds.systems; set_parameter!(ds, args...); end
set_parameter!(pdtds::PDTDS,index,value) = for ds in pdtds.systems; set_parameter!(ds, index,value); end
function set_state!(pdtds::PDTDS, u, i::Int = 1)
# We need to set state in all systems, in case this does
# some kind of resetting, e.g., the `u_modified!` stuff.
Expand Down
38 changes: 36 additions & 2 deletions test/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ for (ds, idt, iip) in zip(
end

@testset "parallel stroboscopic" begin
# Generic Parallel

@inbounds function duffing_rule(x, p, t)
ω, f, d, β = p
dx1 = x[2]
Expand All @@ -122,10 +122,44 @@ states = [u0, u0 .+ 0.01]
pds_cont_oop = ParallelDynamicalSystem(duffing_oop, states)
pds_cont_iip = ParallelDynamicalSystem(duffing_iip, deepcopy(states))

#generic ds test
@testset "IIP=$iip" for (ds, iip) in zip((pds_cont_oop, pds_cont_iip,), (true, false))
test_dynamical_system(ds, u0, p0; idt = true, iip = true, test_trajectory = false)
end

#tests for multistate stuff

#alteration
states = [ones(2) for i in 1:2]
p = p0 .+ 0.1
for i in 1:2
set_state!(pds_cont_oop,states[i],i)
set_state!(pds_cont_iip,states[i],i)
end
set_parameters!(pds_cont_oop,p)
set_parameters!(pds_cont_iip,p)

#obtaining info
@test all(current_states(pds_cont_oop) .== states)
@test all(current_states(pds_cont_iip) .== states)
@test all(current_parameters(pds_cont_oop) .== p)
@test all(current_parameters(pds_cont_iip) .== p)

#time evolution
step!(pds_cont_oop)
@test all(current_states(pds_cont_oop)[1] .== current_states(pds_cont_oop)[2])
step!(pds_cont_iip)
@test all(current_states(pds_cont_iip)[1] .== current_states(pds_cont_iip)[2])

#reinit!
reinit!(pds_cont_oop)
reinit!(pds_cont_iip)
@test all(current_states(pds_cont_oop) .== initial_states(pds_cont_oop))
@test all(current_states(pds_cont_iip) .== initial_states(pds_cont_iip))

end

# TODO: Test that Lyapunovs of this match the original system
# But test this in ChaosTools.jl
# But test this in ChaosTools.jl

#benchmarks, comparison with old version
Loading