diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl
index 81ba8b3c..92cf6ad2 100644
--- a/src/dynamic/regenerate.jl
+++ b/src/dynamic/regenerate.jl
@@ -74,8 +74,7 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U},
         prev_call = get_call(state.prev_trace, key)
         prev_subtrace = prev_call.subtrace
         get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key)
-        (subtrace, weight, _) = regenerate(
-            prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
+        (subtrace, weight, _) = regenerate(prev_subtrace, subselection, args)
     else
         (subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap())
     end
diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl
index 24e023f2..7621344e 100644
--- a/src/dynamic/update.jl
+++ b/src/dynamic/update.jl
@@ -87,8 +87,7 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U},
         prev_call = get_call(state.prev_trace, key)
         prev_subtrace = prev_call.subtrace
         get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key)
-        (subtrace, weight, _, discard) = update(prev_subtrace,
-            args, map((_) -> UnknownChange(), args), constraints)
+        (subtrace, weight, _, discard) = update(prev_subtrace, constraints, args)
     else
         (subtrace, weight) = generate(gen_fn, args, constraints)
     end
diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl
index 40b494a8..9f9a75b2 100644
--- a/src/gen_fn_interface.jl
+++ b/src/gen_fn_interface.jl
@@ -244,8 +244,8 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
 end
 
 """
-    (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple,
-        constraints::ChoiceMap)
+    (new_trace, weight, retdiff, discard) = update(
+        trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)
 
 Update a trace by changing the arguments and/or providing new values for some
 existing random choice(s) and values for some newly introduced random choice(s).
@@ -272,13 +272,37 @@ that if the original `trace` was generated using non-default argument values,
 then for each optional argument that is omitted, the old value will be
 over-written by the default argument value in the updated trace.
 """
-function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap)
+function update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)
     error("Not implemented")
 end
 
 """
-    (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
-        selection::Selection)
+    update(
+        trace, constraints::ChoiceMap, args::Tuple;
+        argdiffs::Tuple=map((_) -> UnknownChange(), args))
+
+Convenience form of `update` with keyword argument for argdiffs.
+"""
+function update(trace, constraints::ChoiceMap, args::Tuple;
+                argdiffs::Tuple=map((_) -> UnknownChange(), args))
+    update(trace, args, argdiffs, constraints)
+end
+
+"""
+    update(trace, constraints::ChoiceMap)
+
+Convenience form of `update` when there is no change to arguments.
+"""
+function update(trace, constraints::ChoiceMap)
+    args = get_args(trace)
+    argdiffs = map((_) -> NoChange(), args)
+    update(trace, args, argdiffs, constraints)
+end
+
+
+"""
+    (new_trace, weight, retdiff) = regenerate(
+        trace, args::Tuple, argdiffs::Tuple, selection::Selection)
 
 Update a trace by changing the arguments and/or randomly sampling new values
 for selected random choices using the internal proposal distribution family.
@@ -307,6 +331,30 @@ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection)
     error("Not implemented")
 end
 
+"""
+    regenerate(
+        trace, selection::Selection, args::Tuple;
+        argdiffs::Tuple=map((_) -> UnknownChange(), args))
+
+Convenience form of `regenerate` with keyword argument for argdiffs.
+"""
+function regenerate(trace, selection::Selection, args::Tuple;
+                    argdiffs::Tuple=map((_) -> UnknownChange(), args))
+    regenerate(trace, args, argdiffs, selection)
+end
+
+"""
+    regenerate(trace, selection::Selection)
+
+Convenience form of `regenerate` when there is no change to arguments.
+"""
+function regenerate(trace, selection::Selection)
+    args = get_args(trace)
+    argdiffs = map((_) -> NoChange(), args)
+    regenerate(trace, args, argdiffs, selection)
+end
+
+
 """
     arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)
 
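The convenience methods place the `ChoiceMap` or `Selection` before `args`, which keeps them dispatch-distinct from the four-argument core forms; the call sites in `src/dynamic/` above follow that order. As a quick illustration of the new forms (the toy `model` below is hypothetical, not part of this patch):

```julia
using Gen

@gen function model(n::Int)
    x ~ normal(0, 1)
    y ~ normal(x * n, 0.1)
end

trace, _ = generate(model, (1,), choicemap((:y, 0.5)))

# Same arguments: argdiffs default to NoChange() for each argument.
new_trace, w, _, discard = update(trace, choicemap((:x, 0.2)))

# Changed arguments: argdiffs default to UnknownChange() for each argument.
new_trace, w, _ = regenerate(trace, select(:x), (2,))
```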
diff --git a/src/inference/elliptical_slice.jl b/src/inference/elliptical_slice.jl
index 3dc2b72a..eca3b2e0 100644
--- a/src/inference/elliptical_slice.jl
+++ b/src/inference/elliptical_slice.jl
@@ -12,8 +12,6 @@ Also takes the mean vector and covariance matrix of the prior.
 """
 function elliptical_slice(
         trace, addr, mu, cov; check=false, observations=EmptyChoiceMap())
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)
 
     # sample nu
     nu = mvnormal(zeros(length(mu)), cov)
@@ -29,7 +27,7 @@
     f = trace[addr] .- mu
     new_f = f * cos(theta) + nu * sin(theta)
 
-    new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu)))
+    new_trace, weight = update(trace, choicemap((addr, new_f .+ mu)))
     while weight <= log(u)
         if theta < 0
             theta_min = theta
@@ -38,7 +36,7 @@
         end
         theta = uniform(theta_min, theta_max)
         new_f = f * cos(theta) + nu * sin(theta)
-        new_trace, weight = update(trace, args, argdiffs, choicemap((addr, new_f .+ mu)))
+        new_trace, weight = update(trace, choicemap((addr, new_f .+ mu)))
     end
     check && check_observations(get_choices(new_trace), observations)
     return new_trace
diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl
index 0f156669..3daad568 100644
--- a/src/inference/hmc.jl
+++ b/src/inference/hmc.jl
@@ -25,9 +25,7 @@ function hmc(
         trace::U, selection::Selection; L=10, eps=0.1,
         check=false, observations=EmptyChoiceMap()) where {T,U}
     prev_model_score = get_score(trace)
-    args = get_args(trace)
     retval_grad = accepts_output_grad(get_gen_fn(trace)) ?
         zero(get_retval(trace)) : nothing
-    argdiffs = map((_) -> NoChange(), args)
     # run leapfrog dynamics
     new_trace = trace
@@ -46,7 +44,7 @@
 
         # get new gradient
         values_trie = from_array(values_trie, values)
-        (new_trace, _, _) = update(new_trace, args, argdiffs, values_trie)
+        (new_trace, _, _) = update(new_trace, values_trie)
         (_, _, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
         gradient = to_array(gradient_trie, Float64)
 
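The kernel signatures above are unchanged; only their internals now go through the two-argument `update`. A minimal smoke-test sketch, assuming a toy `gp_model` (illustrative, not from this patch):

```julia
using Gen, LinearAlgebra

@gen function gp_model()
    f ~ mvnormal(zeros(3), Matrix{Float64}(I, 3, 3))
    y ~ mvnormal(f, 0.1 * Matrix{Float64}(I, 3, 3))
end

trace, _ = generate(gp_model, (), choicemap((:y, [0.5, -0.2, 0.1])))

# Elliptical slice on the prior-distributed choice; returns the new trace.
trace = elliptical_slice(trace, :f, zeros(3), Matrix{Float64}(I, 3, 3))

# HMC over the selected choices; returns (trace, accepted).
trace, accepted = hmc(trace, select(:f); L=10, eps=0.01)
```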
diff --git a/src/inference/involution_dsl.jl b/src/inference/involution_dsl.jl
index 2f08399b..821f4973 100644
--- a/src/inference/involution_dsl.jl
+++ b/src/inference/involution_dsl.jl
@@ -299,7 +299,7 @@ function apply_involution(involution::InvolutionDSLProgram, trace, u, proposal_a
 
     # update model trace
     (new_trace, model_weight, _, discard) = update(
-        trace, get_args(trace), map((_) -> NoChange(), get_args(trace)), first_pass_state.constraints)
+        trace, first_pass_state.constraints)
 
     # create input array and mappings input addresses that are needed for Jacobian
     # exclude addresses that were moved to another address
diff --git a/src/inference/mala.jl b/src/inference/mala.jl
index 033a45a7..1bdba1fd 100644
--- a/src/inference/mala.jl
+++ b/src/inference/mala.jl
@@ -11,8 +11,6 @@ Apply a Metropolis-Adjusted Langevin Algorithm (MALA) update.
 """
 function mala(
         trace, selection::Selection, tau::Real; check=false, observations=EmptyChoiceMap())
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)
     std = sqrt(2 * tau)
     retval_grad = accepts_output_grad(get_gen_fn(trace)) ?
         zero(get_retval(trace)) : nothing
@@ -30,8 +28,7 @@
 
     # evaluate model weight
     constraints = from_array(values_trie, proposed_values)
-    (new_trace, weight, _, discard) = update(trace,
-        args, argdiffs, constraints)
+    (new_trace, weight, _, discard) = update(trace, constraints)
     check && check_observations(get_choices(new_trace), observations)
 
     # backward proposal
diff --git a/src/inference/map_optimize.jl b/src/inference/map_optimize.jl
index 16d419e4..8a75c94f 100644
--- a/src/inference/map_optimize.jl
+++ b/src/inference/map_optimize.jl
@@ -8,8 +8,6 @@ Selected random choices must have support on the entire real line.
 """
 function map_optimize(trace, selection::Selection;
         max_step_size=0.1, tau=0.5, min_step_size=1e-16, verbose=false)
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)
     retval_grad = accepts_output_grad(get_gen_fn(trace)) ?
         zero(get_retval(trace)) : nothing
     (_, values, gradient) = choice_gradients(trace, selection, retval_grad)
@@ -21,7 +19,7 @@
         new_values_vec = values_vec + gradient_vec * step_size
         values = from_array(values, new_values_vec)
         # TODO discard and weight are not actually needed, there should be a more specialized variant
-        (new_trace, _, _, discard) = update(trace, args, argdiffs, values)
+        (new_trace, _, _, discard) = update(trace, values)
         new_score = get_score(new_trace)
         change = new_score - score
         if verbose
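Same pattern for the gradient-based kernels: behavior is unchanged, only the argdiffs boilerplate is gone. Continuing the `gp_model` sketch above (illustrative; assumes the selected choice supports gradients, as `mvnormal` does):

```julia
# MALA step with step size tau = 0.001; returns (trace, accepted).
trace, accepted = mala(trace, select(:f), 0.001)

# Deterministic gradient ascent on the selected choices; returns the new trace.
trace = map_optimize(trace, select(:f); max_step_size=0.01)
```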
diff --git a/src/inference/mh.jl b/src/inference/mh.jl
index 7856b9ea..7e66f3e0 100644
--- a/src/inference/mh.jl
+++ b/src/inference/mh.jl
@@ -14,9 +14,7 @@ Perform a Metropolis-Hastings update that proposes new values for the selected a
 function metropolis_hastings(
         trace, selection::Selection;
         check=false, observations=EmptyChoiceMap())
-    args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), args)
-    (new_trace, weight) = regenerate(trace, args, argdiffs, selection)
+    (new_trace, weight) = regenerate(trace, selection)
     check && check_observations(get_choices(new_trace), observations)
     if log(rand()) < weight
         # accept
@@ -41,12 +39,9 @@ If the proposal modifies addresses that determine the control flow in the model,
 function metropolis_hastings(
         trace, proposal::GenerativeFunction, proposal_args::Tuple;
         check=false, observations=EmptyChoiceMap())
-    model_args = get_args(trace)
-    argdiffs = map((_) -> NoChange(), model_args)
     proposal_args_forward = (trace, proposal_args...,)
     (fwd_choices, fwd_weight, _) = propose(proposal, proposal_args_forward)
-    (new_trace, weight, _, discard) = update(trace,
-        model_args, argdiffs, fwd_choices)
+    (new_trace, weight, _, discard) = update(trace, fwd_choices)
     proposal_args_backward = (new_trace, proposal_args...,)
     (bwd_weight, _) = assess(proposal, proposal_args_backward, discard)
     alpha = weight - fwd_weight + bwd_weight
diff --git a/src/inference/particle_filter.jl b/src/inference/particle_filter.jl
index 1d761bf6..382ab6f6 100644
--- a/src/inference/particle_filter.jl
+++ b/src/inference/particle_filter.jl
@@ -142,7 +142,8 @@ function particle_filter_step!(state::ParticleFilterState{U}, new_args::Tuple, a
     for i=1:num_particles
         (prop_choices, prop_weight, _) = propose(proposal, (state.traces[i], proposal_args...))
         constraints = merge(observations, prop_choices)
-        (state.new_traces[i], up_weight, _, disc) = update(state.traces[i], new_args, argdiffs, constraints)
+        (state.new_traces[i], up_weight, _, disc) = update(
+            state.traces[i], new_args, argdiffs, constraints)
         @assert isempty(disc)
         state.log_weights[i] += up_weight - prop_weight
     end
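Note that `particle_filter_step!` keeps the explicit four-argument `update`, since its arguments genuinely change at each step and callers supply real argdiffs. For the two `metropolis_hastings` variants, a hedged end-to-end sketch (the model and proposal below are illustrative, not part of this patch):

```julia
using Gen

@gen function model()
    x ~ normal(0, 1)
    y ~ normal(x, 0.5)
end

@gen function drift(trace)
    x ~ normal(trace[:x], 0.2)
end

trace, _ = generate(model, (), choicemap((:y, 1.0)))
for _ in 1:100
    trace, _ = metropolis_hastings(trace, select(:x))  # selection variant: regenerate(trace, selection)
    trace, _ = metropolis_hastings(trace, drift, ())   # proposal variant: update(trace, fwd_choices)
end
```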