From 159c3ecab7280e36174ad95c3682e6724c1701a6 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Fri, 10 Apr 2020 20:35:33 -0400 Subject: [PATCH 1/8] Change src, but not yet tests --- src/dynamic/update.jl | 5 +++-- src/gen_fn_interface.jl | 14 ++++++++++---- src/inference/elliptical_slice.jl | 6 ++---- src/inference/hmc.jl | 4 +--- src/inference/involution_dsl.jl | 2 +- src/inference/mala.jl | 5 +---- src/inference/map_optimize.jl | 4 +--- src/inference/mh.jl | 5 +---- src/inference/particle_filter.jl | 5 +++-- src/modeling_library/call_at/call_at.jl | 2 +- src/modeling_library/custom_determ.jl | 2 +- src/modeling_library/map/update.jl | 2 +- src/modeling_library/recurse/recurse.jl | 4 ++-- src/modeling_library/unfold/update.jl | 2 +- src/static_ir/update.jl | 4 ++-- 15 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 24e023f2..60595d8a 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -87,8 +87,9 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, prev_call = get_call(state.prev_trace, key) prev_subtrace = prev_call.subtrace get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key) - (subtrace, weight, _, discard) = update(prev_subtrace, - args, map((_) -> UnknownChange(), args), constraints) + (subtrace, weight, _, discard) = update(prev_subtrace; + args=args, argdiffs=map((_) -> UnknownChange(), args), + constraints=constraints) else (subtrace, weight) = generate(gen_fn, args, constraints) end diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 40b494a8..f05179d7 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -244,8 +244,11 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) end """ - (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple, - constraints::ChoiceMap) + (new_trace, weight, retdiff, discard) = update( + trace; + args=get_args(trace), + argdiffs=map((_) -> NoChange(), args), + constraints=EmptyChoiceMap()) Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s). @@ -272,8 +275,11 @@ that if the original `trace` was generated using non-default argument values, then for each optional argument that is omitted, the old value will be over-written by the default argument value in the updated trace. """ -function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap) - error("Not implemented") +function update(trace; + args=get_args(trace), + argdiffs=map((_) -> NoChange(), args), + constraints=EmptyChoiceMap()) + update(trace, args, argdiffs, constraints) end """ diff --git a/src/inference/elliptical_slice.jl b/src/inference/elliptical_slice.jl index 3dc2b72a..615d759d 100644 --- a/src/inference/elliptical_slice.jl +++ b/src/inference/elliptical_slice.jl @@ -12,8 +12,6 @@ Also takes the mean vector and covariance matrix of the prior. """ function elliptical_slice( trace, addr, mu, cov; check=false, observations=EmptyChoiceMap()) - args = get_args(trace) - argdiffs = map((_) -> NoChange(), args) # sample nu nu = mvnormal(zeros(length(mu)), cov) @@ -29,7 +27,7 @@ function elliptical_slice( f = trace[addr] .- mu new_f = f * cos(theta) + nu * sin(theta) - new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu))) + new_trace, weight = update(trace; constraints=choicemap((addr, new_f .+ mu))) while weight <= log(u) if theta < 0 theta_min = theta @@ -38,7 +36,7 @@ function elliptical_slice( end theta = uniform(theta_min, theta_max) new_f = f * cos(theta) + nu * sin(theta) - new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu))) + new_trace, weight = update(trace; constraints=choicemap((addr, new_f .+ mu))) end check && check_observations(get_choices(new_trace), observations) return new_trace diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 0f156669..bde4dda0 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -25,9 +25,7 @@ function hmc( trace::U, selection::Selection; L=10, eps=0.1, check=false, observations=EmptyChoiceMap()) where {T,U} prev_model_score = get_score(trace) - args = get_args(trace) retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing - argdiffs = map((_) -> NoChange(), args) # run leapfrog dynamics new_trace = trace @@ -46,7 +44,7 @@ function hmc( # get new gradient values_trie = from_array(values_trie, values) - (new_trace, _, _) = update(new_trace, args, argdiffs, values_trie) + (new_trace, _, _) = update(new_trace; constraints=values_trie) (_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) gradient = to_array(gradient_trie, Float64) diff --git a/src/inference/involution_dsl.jl b/src/inference/involution_dsl.jl index 2f08399b..7e27fcde 100644 --- a/src/inference/involution_dsl.jl +++ b/src/inference/involution_dsl.jl @@ -299,7 +299,7 @@ function apply_involution(involution::InvolutionDSLProgram, trace, u, proposal_a # update model trace (new_trace, model_weight, _, discard) = update( - trace, get_args(trace), map((_) -> NoChange(), get_args(trace)), first_pass_state.constraints) + trace; constraints=first_pass_state.constraints) # create input array and mappings input addresses that are needed for Jacobian # exclude addresses that were moved to another address diff --git a/src/inference/mala.jl b/src/inference/mala.jl index 033a45a7..8fc81849 100644 --- a/src/inference/mala.jl +++ b/src/inference/mala.jl @@ -11,8 +11,6 @@ Apply a Metropolis-Adjusted Langevin Algorithm (MALA) update. function mala( trace, selection::Selection, tau::Real; check=false, observations=EmptyChoiceMap()) - args = get_args(trace) - argdiffs = map((_) -> NoChange(), args) std = sqrt(2 * tau) retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing @@ -30,8 +28,7 @@ function mala( # evaluate model weight constraints = from_array(values_trie, proposed_values) - (new_trace, weight, _, discard) = update(trace, - args, argdiffs, constraints) + (new_trace, weight, _, discard) = update(trace; constraints=constraints) check && check_observations(get_choices(new_trace), observations) # backward proposal diff --git a/src/inference/map_optimize.jl b/src/inference/map_optimize.jl index 16d419e4..e01ce76f 100644 --- a/src/inference/map_optimize.jl +++ b/src/inference/map_optimize.jl @@ -8,8 +8,6 @@ Selected random choices must have support on the entire real line. """ function map_optimize(trace, selection::Selection; max_step_size=0.1, tau=0.5, min_step_size=1e-16, verbose=false) - args = get_args(trace) - argdiffs = map((_) -> NoChange(), args) retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing (_, values, gradient) = choice_gradients(trace, selection, retval_grad) @@ -21,7 +19,7 @@ function map_optimize(trace, selection::Selection; new_values_vec = values_vec + gradient_vec * step_size values = from_array(values, new_values_vec) # TODO discard and weight are not actually needed, there should be a more specialized variant - (new_trace, _, _, discard) = update(trace, args, argdiffs, values) + (new_trace, _, _, discard) = update(trace; constraints=values) new_score = get_score(new_trace) change = new_score - score if verbose diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 7856b9ea..11978141 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -41,12 +41,9 @@ If the proposal modifies addresses that determine the control flow in the model, function metropolis_hastings( trace, proposal::GenerativeFunction, proposal_args::Tuple; check=false, observations=EmptyChoiceMap()) - model_args = get_args(trace) - argdiffs = map((_) -> NoChange(), model_args) proposal_args_forward = (trace, proposal_args...,) (fwd_choices, fwd_weight, _) = propose(proposal, proposal_args_forward) - (new_trace, weight, _, discard) = update(trace, - model_args, argdiffs, fwd_choices) + (new_trace, weight, _, discard) = update(trace; constraints=fwd_choices) proposal_args_backward = (new_trace, proposal_args...,) (bwd_weight, _) = assess(proposal, proposal_args_backward, discard) alpha = weight - fwd_weight + bwd_weight diff --git a/src/inference/particle_filter.jl b/src/inference/particle_filter.jl index 1d761bf6..7e4a3757 100644 --- a/src/inference/particle_filter.jl +++ b/src/inference/particle_filter.jl @@ -142,7 +142,8 @@ function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, a for i=1:num_particles (prop_choices, prop_weight, _) = propose(proposal, (state.traces[i], proposal_args...)) constraints = merge(observations, prop_choices) - (state.new_traces[i], up_weight, _, disc) = update(state.traces[i], new_args, argdiffs, constraints) + (state.new_traces[i], up_weight, _, disc) = update( + state.traces[i]; args=new_args, argdiffs=argdiffs, constraints=constraints) @assert isempty(disc) state.log_weights[i] += up_weight - prop_weight end @@ -166,7 +167,7 @@ function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, a num_particles = length(state.traces) for i=1:num_particles (state.new_traces[i], increment, _, discard) = update( - state.traces[i], new_args, argdiffs, observations) + state.traces[i]; args=new_args, argdiffs=argdiffs, constraints=observations) if !isempty(discard) error("Choices were updated or deleted inside particle filter step: $discard") end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 23411697..db82e85e 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -115,7 +115,7 @@ function update(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, retdiff = UnknownChange() else (subtrace, weight, retdiff, subdiscard) = update( - trace.subtrace, kernel_args, argdiffs[1:end-1], submap) + trace.subtrace; args=kernel_args, argdiffs=argdiffs[1:end-1], constraints=submap) discard = CallAtChoiceMap(key, subdiscard) end new_trace = CallAtTrace(trace.gen_fn, subtrace, key) diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 48045ddb..e0f8ca27 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -117,7 +117,7 @@ function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, c end function regenerate(trace::CustomDetermGFTrace, args::Tuple, argdiffs::Tuple, selection::Selection) - update(trace, args, argdiffs, EmptyChoiceMap()) + update(trace; args=args, argdiffs=argdiffs, constraints=EmptyChoiceMap()) end function choice_gradients(trace::CustomDetermGFTrace, selection::Selection, retgrad) diff --git a/src/modeling_library/map/update.jl b/src/modeling_library/map/update.jl index fd8ffc77..62d1a74d 100644 --- a/src/modeling_library/map/update.jl +++ b/src/modeling_library/map/update.jl @@ -22,7 +22,7 @@ function process_retained!(gen_fn::Map{T,U}, args::Tuple, # get new subtrace with recursive call to update() prev_subtrace = state.subtraces[key] (subtrace, weight, retdiff, discard) = update( - prev_subtrace, kernel_args, kernel_argdiffs, submap) + prev_subtrace; args=kernel_args, argdiffs=kernel_argdiffs, constraints=submap) # retrieve retdiff if retdiff != NoChange() diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 71580073..fbbaf9a0 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -504,7 +504,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, # call update on production kernel prev_subtrace = production_traces[cur] (subtrace, subweight, subretdiff, subdiscard) = update( - prev_subtrace, input, (subargdiff,), subconstraints) + prev_subtrace; args=input, argdiffs=(subargdiff,), constraints=subconstraints) prev_num_children = get_num_children(production_traces[cur]) new_num_children = length(get_retval(subtrace).children) idx_to_prev_num_children[cur] = prev_num_children @@ -622,7 +622,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, # call update on aggregation kernel prev_subtrace = aggregation_traces[cur] (subtrace, subweight, subretdiff, subdiscard) = update( - prev_subtrace, input, subargdiffs, subconstraints) + prev_subtrace; args=input, argdiffs=subargdiffs, constraints=subconstraints) # update trace, weight, and score, and discard aggregation_traces = assoc(aggregation_traces, cur, subtrace) diff --git a/src/modeling_library/unfold/update.jl b/src/modeling_library/unfold/update.jl index d49cea30..4af8a5c6 100644 --- a/src/modeling_library/unfold/update.jl +++ b/src/modeling_library/unfold/update.jl @@ -25,7 +25,7 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple, # get new subtrace with recursive call to update() prev_subtrace = state.subtraces[key] (subtrace, weight, retdiff, discard) = update( - prev_subtrace, kernel_args, kernel_argdiffs, submap) + prev_subtrace; args=kernel_args, argdiffs=kernel_argdiffs, constraints=submap) # retrieve retdiff if retdiff != NoChange() diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index dc4fddf3..67f9ff9f 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -323,7 +323,7 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, push!(stmts, :($call_constraints = $qn_empty_choice_map)) end push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) = - $qn_update($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints))) + $qn_update($prev_subtrace; args=$(Expr(:tuple, arg_values...)), argdiffs=$(Expr(:tuple, arg_diffs...)), constraints=$call_constraints))) push!(stmts, :($weight += $call_weight)) push!(stmts, :($total_score_fieldname += $qn_get_score($subtrace) - $qn_get_score($prev_subtrace))) push!(stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection) - $qn_project($prev_subtrace, $qn_empty_selection))) @@ -471,7 +471,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ # convert the constraints to a static assignment if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema)) - return quote $qn_update(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end + return quote $qn_update(trace; args=args, argdiffs=argdiffs, constraints=$(QuoteNode(StaticChoiceMap))(constraints)) end end ir = get_ir(gen_fn_type) From ba4ab5b285fd8156cdd14d06614fccf465f0a395 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Fri, 10 Apr 2020 20:50:13 -0400 Subject: [PATCH 2/8] Revert GFI implementations to use the GFI version of update --- src/dynamic/update.jl | 5 ++--- src/gen_fn_interface.jl | 23 ++++++++++++++++------- src/modeling_library/call_at/call_at.jl | 2 +- src/modeling_library/custom_determ.jl | 2 +- src/modeling_library/map/update.jl | 2 +- src/modeling_library/recurse/recurse.jl | 4 ++-- src/modeling_library/unfold/update.jl | 2 +- src/static_ir/update.jl | 4 ++-- 8 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 60595d8a..24e023f2 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -87,9 +87,8 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, prev_call = get_call(state.prev_trace, key) prev_subtrace = prev_call.subtrace get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key) - (subtrace, weight, _, discard) = update(prev_subtrace; - args=args, argdiffs=map((_) -> UnknownChange(), args), - constraints=constraints) + (subtrace, weight, _, discard) = update(prev_subtrace, + args, map((_) -> UnknownChange(), args), constraints) else (subtrace, weight) = generate(gen_fn, args, constraints) end diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index f05179d7..3354504c 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -245,10 +245,7 @@ end """ (new_trace, weight, retdiff, discard) = update( - trace; - args=get_args(trace), - argdiffs=map((_) -> NoChange(), args), - constraints=EmptyChoiceMap()) + trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s). @@ -275,10 +272,22 @@ that if the original `trace` was generated using non-default argument values, then for each optional argument that is omitted, the old value will be over-written by the default argument value in the updated trace. """ +function update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) + error("Not implemented") +end + +""" + update(trace; + args::Tuple=get_args(trace), + argdiffs::Tuple=map((_) -> NoChange(), args), + constraints::ChoiceMap=EmptyChoiceMap()) + +Form of `update` with keyword arguments providing common defaults. +""" function update(trace; - args=get_args(trace), - argdiffs=map((_) -> NoChange(), args), - constraints=EmptyChoiceMap()) + args::Tuple=get_args(trace), + argdiffs::Tuple=map((_) -> NoChange(), args), + constraints::ChoiceMap=EmptyChoiceMap()) update(trace, args, argdiffs, constraints) end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index db82e85e..23411697 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -115,7 +115,7 @@ function update(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, retdiff = UnknownChange() else (subtrace, weight, retdiff, subdiscard) = update( - trace.subtrace; args=kernel_args, argdiffs=argdiffs[1:end-1], constraints=submap) + trace.subtrace, kernel_args, argdiffs[1:end-1], submap) discard = CallAtChoiceMap(key, subdiscard) end new_trace = CallAtTrace(trace.gen_fn, subtrace, key) diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index e0f8ca27..48045ddb 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -117,7 +117,7 @@ function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, c end function regenerate(trace::CustomDetermGFTrace, args::Tuple, argdiffs::Tuple, selection::Selection) - update(trace; args=args, argdiffs=argdiffs, constraints=EmptyChoiceMap()) + update(trace, args, argdiffs, EmptyChoiceMap()) end function choice_gradients(trace::CustomDetermGFTrace, selection::Selection, retgrad) diff --git a/src/modeling_library/map/update.jl b/src/modeling_library/map/update.jl index 62d1a74d..fd8ffc77 100644 --- a/src/modeling_library/map/update.jl +++ b/src/modeling_library/map/update.jl @@ -22,7 +22,7 @@ function process_retained!(gen_fn::Map{T,U}, args::Tuple, # get new subtrace with recursive call to update() prev_subtrace = state.subtraces[key] (subtrace, weight, retdiff, discard) = update( - prev_subtrace; args=kernel_args, argdiffs=kernel_argdiffs, constraints=submap) + prev_subtrace, kernel_args, kernel_argdiffs, submap) # retrieve retdiff if retdiff != NoChange() diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index fbbaf9a0..71580073 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -504,7 +504,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, # call update on production kernel prev_subtrace = production_traces[cur] (subtrace, subweight, subretdiff, subdiscard) = update( - prev_subtrace; args=input, argdiffs=(subargdiff,), constraints=subconstraints) + prev_subtrace, input, (subargdiff,), subconstraints) prev_num_children = get_num_children(production_traces[cur]) new_num_children = length(get_retval(subtrace).children) idx_to_prev_num_children[cur] = prev_num_children @@ -622,7 +622,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, # call update on aggregation kernel prev_subtrace = aggregation_traces[cur] (subtrace, subweight, subretdiff, subdiscard) = update( - prev_subtrace; args=input, argdiffs=subargdiffs, constraints=subconstraints) + prev_subtrace, input, subargdiffs, subconstraints) # update trace, weight, and score, and discard aggregation_traces = assoc(aggregation_traces, cur, subtrace) diff --git a/src/modeling_library/unfold/update.jl b/src/modeling_library/unfold/update.jl index 4af8a5c6..d49cea30 100644 --- a/src/modeling_library/unfold/update.jl +++ b/src/modeling_library/unfold/update.jl @@ -25,7 +25,7 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple, # get new subtrace with recursive call to update() prev_subtrace = state.subtraces[key] (subtrace, weight, retdiff, discard) = update( - prev_subtrace; args=kernel_args, argdiffs=kernel_argdiffs, constraints=submap) + prev_subtrace, kernel_args, kernel_argdiffs, submap) # retrieve retdiff if retdiff != NoChange() diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index 67f9ff9f..dc4fddf3 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -323,7 +323,7 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, push!(stmts, :($call_constraints = $qn_empty_choice_map)) end push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) = - $qn_update($prev_subtrace; args=$(Expr(:tuple, arg_values...)), argdiffs=$(Expr(:tuple, arg_diffs...)), constraints=$call_constraints))) + $qn_update($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints))) push!(stmts, :($weight += $call_weight)) push!(stmts, :($total_score_fieldname += $qn_get_score($subtrace) - $qn_get_score($prev_subtrace))) push!(stmts, :($total_noise_fieldname += $qn_project($subtrace, $qn_empty_selection) - $qn_project($prev_subtrace, $qn_empty_selection))) @@ -471,7 +471,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ # convert the constraints to a static assignment if it is not already one if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema)) - return quote $qn_update(trace; args=args, argdiffs=argdiffs, constraints=$(QuoteNode(StaticChoiceMap))(constraints)) end + return quote $qn_update(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end end ir = get_ir(gen_fn_type) From 5e2343e8ed8b7163e06a2ced8a48a99ff5a30e0d Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Fri, 10 Apr 2020 20:55:31 -0400 Subject: [PATCH 3/8] do the same for regenerate --- src/gen_fn_interface.jl | 16 ++++++++++++++++ src/inference/mh.jl | 4 +--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 3354504c..7683ba4b 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -322,6 +322,22 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) error("Not implemented") end +""" + regenerate(trace; + args::Tuple=get_args(trace), + argdiffs::Tuple=map((_) -> NoChange(), args), + selection::Selection=EmptySelection()) + +Form of `regenerate` with keyword arguments providing common defaults. +""" +function regenerate(trace; + args::Tuple=get_args(trace), + argdiffs::Tuple=map((_) -> NoChange(), args), + selection::Selection=EmptySelection()) + regenerate(trace, args, argdiffs, selection) +end + + """ arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 11978141..bba2b0a1 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -14,9 +14,7 @@ Perform a Metropolis-Hastings update that proposes new values for the selected a function metropolis_hastings( trace, selection::Selection; check=false, observations=EmptyChoiceMap()) - args = get_args(trace) - argdiffs = map((_) -> NoChange(), args) - (new_trace, weight) = regenerate(trace, args, argdiffs, selection) + (new_trace, weight) = regenerate(trace; selection=selection) check && check_observations(get_choices(new_trace), observations) if log(rand()) < weight # accept From 485a675fb4ba896500cbdb927ec6492b6db6d9c5 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sat, 18 Apr 2020 14:51:46 -0400 Subject: [PATCH 4/8] Make unknownchange the default for argdiffs, but add a form where there is no change to args --- src/gen_fn_interface.jl | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 7683ba4b..6c80ab4f 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -244,8 +244,11 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) end """ + (new_trace, weight, retdiff, discard) = update( - trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) + trace, args::Tuple; + argdiffs::Tuple=map((_) -> UnknownChange(), args), + constraints::ChoiceMap=EmptyChoiceMap()) Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s). @@ -277,20 +280,32 @@ function update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) end """ - update(trace; - args::Tuple=get_args(trace), - argdiffs::Tuple=map((_) -> NoChange(), args), + (new_trace, weight, retdiff, discard) = update( + trace, args::Tuple; + argdiffs::Tuple=map((_) -> UnknownChange(), args), constraints::ChoiceMap=EmptyChoiceMap()) -Form of `update` with keyword arguments providing common defaults. +Convenience form of `update` with keyword arguments for argdiffs and constraints. """ -function update(trace; - args::Tuple=get_args(trace), - argdiffs::Tuple=map((_) -> NoChange(), args), +function update(trace, args; + argdiffs::Tuple=map((_) -> UnknownChange(), args), constraints::ChoiceMap=EmptyChoiceMap()) update(trace, args, argdiffs, constraints) end +""" + (new_trace, weight, retdiff, discard) = update( + trace; constraints::ChoiceMap=EmptyChoiceMap()) + +Convenience form of `update` when there is no change to arguments. +""" +function update(trace; constraints::ChoiceMap=EmptyChoiceMap()) + args = get_args(trace) + argdiffs = map((_) -> NoChange(), args) + update(trace, args, argdiffs, constraints) +end + + """ (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) From 1884f57de4e3b1b72ee64045da7a7327f226f8f6 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sat, 18 Apr 2020 17:13:38 -0400 Subject: [PATCH 5/8] Add two-argument and three-argument forms of update and regenerate --- src/gen_fn_interface.jl | 44 ++++++++++++++++--------------- src/inference/elliptical_slice.jl | 4 +-- src/inference/hmc.jl | 2 +- src/inference/involution_dsl.jl | 2 +- src/inference/mala.jl | 2 +- src/inference/map_optimize.jl | 2 +- src/inference/mh.jl | 4 +-- src/inference/particle_filter.jl | 4 +-- 8 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 6c80ab4f..97e692c3 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -246,9 +246,7 @@ end """ (new_trace, weight, retdiff, discard) = update( - trace, args::Tuple; - argdiffs::Tuple=map((_) -> UnknownChange(), args), - constraints::ChoiceMap=EmptyChoiceMap()) + trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s). @@ -281,25 +279,22 @@ end """ (new_trace, weight, retdiff, discard) = update( - trace, args::Tuple; - argdiffs::Tuple=map((_) -> UnknownChange(), args), - constraints::ChoiceMap=EmptyChoiceMap()) + trace, constraints::ChoiceMap, args::Tuple; + argdiffs::Tuple=map((_) -> UnknownChange(), args)) Convenience form of `update` with keyword arguments for argdiffs and constraints. """ -function update(trace, args; - argdiffs::Tuple=map((_) -> UnknownChange(), args), - constraints::ChoiceMap=EmptyChoiceMap()) +function update(trace, constraints::ChoiceMap, args::Tuple; + argdiffs::Tuple=map((_) -> UnknownChange(), args)) update(trace, args, argdiffs, constraints) end """ - (new_trace, weight, retdiff, discard) = update( - trace; constraints::ChoiceMap=EmptyChoiceMap()) + (new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap) Convenience form of `update` when there is no change to arguments. """ -function update(trace; constraints::ChoiceMap=EmptyChoiceMap()) +function update(trace, constraints::ChoiceMap) args = get_args(trace) argdiffs = map((_) -> NoChange(), args) update(trace, args, argdiffs, constraints) @@ -338,17 +333,24 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) end """ - regenerate(trace; - args::Tuple=get_args(trace), - argdiffs::Tuple=map((_) -> NoChange(), args), - selection::Selection=EmptySelection()) + regenerate(trace, selection::Selection) + +Convenience form of `regenerate` when there is no change to arguments. +""" +function regenerate(trace, selection::Selection) + args = get_args(trace) + argdiffs = map((_) -> NoChange(), args) + regenerate(trace, args, argdiffs, selection) +end + +""" + regenerate(trace, selection::Selection, args::Tuple=get_args(trace); + argdiffs::Tuple=map((_) -> UnknownChange(), args), -Form of `regenerate` with keyword arguments providing common defaults. +Convenience form of `regenerate` with keyword arguments for argdiffs and constraints. """ -function regenerate(trace; - args::Tuple=get_args(trace), - argdiffs::Tuple=map((_) -> NoChange(), args), - selection::Selection=EmptySelection()) +function regenerate(trace, selection::Selection, args::Tuple; + argdiffs::Tuple=map((_) -> UnknownChange(), args)) regenerate(trace, args, argdiffs, selection) end diff --git a/src/inference/elliptical_slice.jl b/src/inference/elliptical_slice.jl index 615d759d..eca3b2e0 100644 --- a/src/inference/elliptical_slice.jl +++ b/src/inference/elliptical_slice.jl @@ -27,7 +27,7 @@ function elliptical_slice( f = trace[addr] .- mu new_f = f * cos(theta) + nu * sin(theta) - new_trace, weight = update(trace; constraints=choicemap((addr, new_f .+ mu))) + new_trace, weight = update(trace, choicemap((addr, new_f .+ mu))) while weight <= log(u) if theta < 0 theta_min = theta @@ -36,7 +36,7 @@ function elliptical_slice( end theta = uniform(theta_min, theta_max) new_f = f * cos(theta) + nu * sin(theta) - new_trace, weight = update(trace; constraints=choicemap((addr, new_f .+ mu))) + new_trace, weight = update(trace, choicemap((addr, new_f .+ mu))) end check && check_observations(get_choices(new_trace), observations) return new_trace diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index bde4dda0..3daad568 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -44,7 +44,7 @@ function hmc( # get new gradient values_trie = from_array(values_trie, values) - (new_trace, _, _) = update(new_trace; constraints=values_trie) + (new_trace, _, _) = update(new_trace, values_trie) (_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad) gradient = to_array(gradient_trie, Float64) diff --git a/src/inference/involution_dsl.jl b/src/inference/involution_dsl.jl index 7e27fcde..821f4973 100644 --- a/src/inference/involution_dsl.jl +++ b/src/inference/involution_dsl.jl @@ -299,7 +299,7 @@ function apply_involution(involution::InvolutionDSLProgram, trace, u, proposal_a # update model trace (new_trace, model_weight, _, discard) = update( - trace; constraints=first_pass_state.constraints) + trace, first_pass_state.constraints) # create input array and mappings input addresses that are needed for Jacobian # exclude addresses that were moved to another address diff --git a/src/inference/mala.jl b/src/inference/mala.jl index 8fc81849..1bdba1fd 100644 --- a/src/inference/mala.jl +++ b/src/inference/mala.jl @@ -28,7 +28,7 @@ function mala( # evaluate model weight constraints = from_array(values_trie, proposed_values) - (new_trace, weight, _, discard) = update(trace; constraints=constraints) + (new_trace, weight, _, discard) = update(trace, constraints) check && check_observations(get_choices(new_trace), observations) # backward proposal diff --git a/src/inference/map_optimize.jl b/src/inference/map_optimize.jl index e01ce76f..8a75c94f 100644 --- a/src/inference/map_optimize.jl +++ b/src/inference/map_optimize.jl @@ -19,7 +19,7 @@ function map_optimize(trace, selection::Selection; new_values_vec = values_vec + gradient_vec * step_size values = from_array(values, new_values_vec) # TODO discard and weight are not actually needed, there should be a more specialized variant - (new_trace, _, _, discard) = update(trace; constraints=values) + (new_trace, _, _, discard) = update(trace, values) new_score = get_score(new_trace) change = new_score - score if verbose diff --git a/src/inference/mh.jl b/src/inference/mh.jl index bba2b0a1..7e66f3e0 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -14,7 +14,7 @@ Perform a Metropolis-Hastings update that proposes new values for the selected a function metropolis_hastings( trace, selection::Selection; check=false, observations=EmptyChoiceMap()) - (new_trace, weight) = regenerate(trace; selection=selection) + (new_trace, weight) = regenerate(trace, selection) check && check_observations(get_choices(new_trace), observations) if log(rand()) < weight # accept @@ -41,7 +41,7 @@ function metropolis_hastings( check=false, observations=EmptyChoiceMap()) proposal_args_forward = (trace, proposal_args...,) (fwd_choices, fwd_weight, _) = propose(proposal, proposal_args_forward) - (new_trace, weight, _, discard) = update(trace; constraints=fwd_choices) + (new_trace, weight, _, discard) = update(trace, fwd_choices) proposal_args_backward = (new_trace, proposal_args...,) (bwd_weight, _) = assess(proposal, proposal_args_backward, discard) alpha = weight - fwd_weight + bwd_weight diff --git a/src/inference/particle_filter.jl b/src/inference/particle_filter.jl index 7e4a3757..382ab6f6 100644 --- a/src/inference/particle_filter.jl +++ b/src/inference/particle_filter.jl @@ -143,7 +143,7 @@ function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, a (prop_choices, prop_weight, _) = propose(proposal, (state.traces[i], proposal_args...)) constraints = merge(observations, prop_choices) (state.new_traces[i], up_weight, _, disc) = update( - state.traces[i]; args=new_args, argdiffs=argdiffs, constraints=constraints) + state.traces[i], new_args, argdiffs, constraints) @assert isempty(disc) state.log_weights[i] += up_weight - prop_weight end @@ -167,7 +167,7 @@ function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, a num_particles = length(state.traces) for i=1:num_particles (state.new_traces[i], increment, _, discard) = update( - state.traces[i]; args=new_args, argdiffs=argdiffs, constraints=observations) + state.traces[i], new_args, argdiffs, observations) if !isempty(discard) error("Choices were updated or deleted inside particle filter step: $discard") end From 59c961cc3a354ba0627795a515cf2c23e376cfaf Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sat, 18 Apr 2020 17:15:07 -0400 Subject: [PATCH 6/8] Fix docstrings --- src/gen_fn_interface.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 97e692c3..6436ee1b 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -282,7 +282,7 @@ end trace, constraints::ChoiceMap, args::Tuple; argdiffs::Tuple=map((_) -> UnknownChange(), args)) -Convenience form of `update` with keyword arguments for argdiffs and constraints. +Convenience form of `update` with keyword argument for argdiffs. """ function update(trace, constraints::ChoiceMap, args::Tuple; argdiffs::Tuple=map((_) -> UnknownChange(), args)) @@ -333,28 +333,29 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) end """ - regenerate(trace, selection::Selection) + regenerate(trace, selection::Selection, args::Tuple=get_args(trace); + argdiffs::Tuple=map((_) -> UnknownChange(), args), -Convenience form of `regenerate` when there is no change to arguments. +Convenience form of `regenerate` with keyword arguments for argdiffs. """ -function regenerate(trace, selection::Selection) - args = get_args(trace) - argdiffs = map((_) -> NoChange(), args) +function regenerate(trace, selection::Selection, args::Tuple; + argdiffs::Tuple=map((_) -> UnknownChange(), args)) regenerate(trace, args, argdiffs, selection) end """ - regenerate(trace, selection::Selection, args::Tuple=get_args(trace); - argdiffs::Tuple=map((_) -> UnknownChange(), args), + regenerate(trace, selection::Selection) -Convenience form of `regenerate` with keyword arguments for argdiffs and constraints. +Convenience form of `regenerate` when there is no change to arguments. """ -function regenerate(trace, selection::Selection, args::Tuple; - argdiffs::Tuple=map((_) -> UnknownChange(), args)) +function regenerate(trace, selection::Selection) + args = get_args(trace) + argdiffs = map((_) -> NoChange(), args) regenerate(trace, args, argdiffs, selection) end + """ arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.) From bc2fa2ff6052a52bdf750447c2f213997b6ac10b Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sat, 18 Apr 2020 17:17:00 -0400 Subject: [PATCH 7/8] Fix docstrings --- src/gen_fn_interface.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 6436ee1b..9f9a75b2 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -278,7 +278,7 @@ function update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) end """ - (new_trace, weight, retdiff, discard) = update( + update( trace, constraints::ChoiceMap, args::Tuple; argdiffs::Tuple=map((_) -> UnknownChange(), args)) @@ -290,7 +290,7 @@ function update(trace, constraints::ChoiceMap, args::Tuple; end """ - (new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap) + update(trace, constraints::ChoiceMap) Convenience form of `update` when there is no change to arguments. """ @@ -302,8 +302,8 @@ end """ - (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple, - selection::Selection) + (new_trace, weight, retdiff) = regenerate( + trace, args::Tuple, argdiffs::Tuple, selection::Selection) Update a trace by changing the arguments and/or randomly sampling new values for selected random choices using the internal proposal distribution family. @@ -333,7 +333,8 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) end """ - regenerate(trace, selection::Selection, args::Tuple=get_args(trace); + regenerate( + trace, selection::Selection, args::Tuple; argdiffs::Tuple=map((_) -> UnknownChange(), args), Convenience form of `regenerate` with keyword arguments for argdiffs. From 0da203d9f15c39d16935e87bf420e35ec69d06f6 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sat, 18 Apr 2020 17:23:06 -0400 Subject: [PATCH 8/8] remove unecessary argdiffs from dynamic update and regenerate impl --- src/dynamic/regenerate.jl | 3 +-- src/dynamic/update.jl | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 81ba8b3c..92cf6ad2 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -74,8 +74,7 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, prev_call = get_call(state.prev_trace, key) prev_subtrace = prev_call.subtrace get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key) - (subtrace, weight, _) = regenerate( - prev_subtrace, args, map((_) -> UnknownChange(), args), subselection) + (subtrace, weight, _) = regenerate(prev_subtrace, args, subselection) else (subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap()) end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 24e023f2..7621344e 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -87,8 +87,7 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, prev_call = get_call(state.prev_trace, key) prev_subtrace = prev_call.subtrace get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key) - (subtrace, weight, _, discard) = update(prev_subtrace, - args, map((_) -> UnknownChange(), args), constraints) + (subtrace, weight, _, discard) = update(prev_subtrace, args, constraints) else (subtrace, weight) = generate(gen_fn, args, constraints) end