From 1cc4767b3ea7cc7434ed84d86244c768a16e1384 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 12 May 2021 16:12:20 +0000 Subject: [PATCH 01/24] checkpoint on experimental performance improvements to gradient accumulation --- src/builtin_optimization.jl | 127 ++++++++++++++++++++++++++++++++++-- src/dynamic/backprop.jl | 6 +- src/dynamic/dynamic.jl | 4 +- src/static_ir/backprop.jl | 88 ++++++++++++++++--------- src/static_ir/static_ir.jl | 4 +- 5 files changed, 189 insertions(+), 40 deletions(-) diff --git a/src/builtin_optimization.jl b/src/builtin_optimization.jl index 3a897577f..d4f5ade80 100644 --- a/src/builtin_optimization.jl +++ b/src/builtin_optimization.jl @@ -1,3 +1,92 @@ +############################# + +# primitives for in-place gradient accumulation + +function in_place_add!(param::Array, increment::Array, scale_factor::Real) + # NOTE: it ignores the scale_factor, because it is not a parameter... + # scale factors only affect parameters + # TODO this is potentially very confusing! + @simd for i in 1:length(param) + param[i] += increment[i] + end + return param +end + +function in_place_add!(param::Array, increment::Array) + @inbounds @simd for i in 1:length(param) + param[i] += increment[i] + end + return param +end + +function in_place_add!(param::Real, increment::Real, scale_factor::Real) + return param + increment +end + +function in_place_add!(param::Real, increment::Real) + return param + increment +end + +mutable struct ThreadsafeAccumulator{T} + value::T + lock::ReentrantLock +end + +ThreadsafeAccumulator(value) = ThreadsafeAccumulator(value, ReentrantLock()) + +# TODO not threadsafe +function get_current_value(accum::ThreadsafeAccumulator) + return accum.value +end + +function in_place_add!(param::ThreadsafeAccumulator{Real}, increment::Real, scale_factor::Real) + lock(param.lock) + try + param.value = param.value + increment * scale_factor + finally + unlock(param.lock) + end + return param +end + +function in_place_add!(param::ThreadsafeAccumulator{Real}, increment::Real) + lock(param.lock) + try + param.value = param.value + increment + finally + unlock(param.lock) + end + return param +end + +function in_place_add!(param::ThreadsafeAccumulator{<:Array}, increment, scale_factor::Real) + lock(param.lock) + try + @simd for i in 1:length(param.value) + param.value[i] += increment[i] * scale_factor + end + finally + unlock(param.lock) + end + return param +end + +function in_place_add!(param::ThreadsafeAccumulator{<:Array}, increment) + lock(param.lock) + try + @simd for i in 1:length(param.value) + param.value[i] += increment[i] + end + finally + unlock(param.lock) + end + return param +end + +############################# + + + """ set_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) @@ -6,7 +95,7 @@ Set the value of a trainable parameter of the generative function. NOTE: Does not update the gradient accumulator value. """ function set_param!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - gf.params[name] = value + return gf.params[name] = value end """ @@ -15,36 +104,62 @@ end Get the current value of a trainable parameter of the generative function. 
""" function get_param(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - gf.params[name] + return gf.params[name] end """ value = get_param_grad(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) Get the current value of the gradient accumulator for a trainable parameter of the generative function. + +Not threadsafe. """ function get_param_grad(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - gf.params_grad[name] + try + val = gf.params_grad[name] # the accumulator + return get_current_value(val) + catch KeyError + error("parameter $name not found") + end + return val end """ zero_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) Reset the gradient accumlator for a trainable parameter of the generative function to all zeros. + +Not threadsafe. """ function zero_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - gf.params_grad[name] = zero(gf.params[name]) + gf.params_grad[name] = ThreadsafeAccumulator(zero(gf.params[name])) # TODO avoid allocation? + return gf.params_grad[name] end """ set_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value) Set the gradient accumlator for a trainable parameter of the generative function. + +Not threadsafe. """ function set_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value) - gf.params_grad[name] = grad_value + gf.params_grad[name] = ThreadsafeAccumulator(grad_value) + return grad_value +end + +# TODO document me; it is threadsafe.. +function increment_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, increment, scale_factor) + in_place_add!(gf.params_grad[name], increment, scale_factor) +end + +# TODO document me; it is threadsafe.. +function increment_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, increment) + in_place_add!(gf.params_grad[name], increment) end + + """ init_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) @@ -64,7 +179,7 @@ end get_params(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}) = keys(gf.params) -export set_param!, get_param, get_param_grad, zero_param_grad!, set_param_grad!, init_param! +export set_param!, get_param, get_param_grad, zero_param_grad!, set_param_grad!, init_param!, increment_param_grad! 
######################################### # gradient descent with fixed step size # diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 9e02c0657..0489ec3a6 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -196,12 +196,14 @@ function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_ # increment the gradient accumulators for trainable parameters in scope for (name, tracked) in state.tracked_params - gen_fn.params_grad[name] += deriv(tracked) * state.scale_factor + #increment = deriv(tracked) * state.scale_factor + increment_param_grad!(gen_fn, name, deriv(tracked), state.scale_factor) end # increment the gradient accumulators for trainable parameters in splice calls for ((spliced_gen_fn, name), tracked) in state.splice_tracked_params - spliced_gen_fn.params_grad[name] += deriv(tracked) * state.scale_factor + #increment = deriv(tracked) * state.scale_factor + increment_param_grad!(spliced_gen_fn, name, deriv(tracked), state.scale_factor) end # return gradients with respect to arguments with gradients, or nothing diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index d83055444..e136ed723 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -17,6 +17,7 @@ struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace} julia_function::Function has_argument_grads::Vector{Bool} accepts_output_grad::Bool + params_grad_lock::ReentrantLock end function DynamicDSLFunction(arg_types::Vector{Type}, @@ -30,7 +31,8 @@ function DynamicDSLFunction(arg_types::Vector{Type}, DynamicDSLFunction{T}(params_grad, params, arg_types, has_defaults, arg_defaults, julia_function, - has_argument_grads, accepts_output_grad) + has_argument_grads, accepts_output_grad, + ReentrantLock()) end function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction} diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 96f63da0c..b6b0087d6 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -71,8 +71,10 @@ end function back_pass!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK - for input_node in node.inputs - push!(back_marked, input_node) + for (input_node, has_grad) in zip(node.inputs, has_argument_grads(node.generative_function)) + if has_grad + push!(back_marked, input_node) + end end end @@ -85,7 +87,10 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNo if node in fwd_marked && node in back_marked # initialize gradient to zero - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + # NOTE: we are avoiding allocating a new gradient accumulator for this function + # instead, we are using the threadsafe gradient accumulator directly.. + #push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gen_fn))(trace).params_grad[$(QuoteNode(node.name))])) end end @@ -156,7 +161,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCa # for reference by other nodes during back_codegen! # could performance optimize this away subtrace_fieldname = get_subtrace_fieldname(node) - push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) + push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname))) # NOTE: we will still potentially run choice_gradients recursively on the generative function, # we just might not use its return value gradient. 
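The generated statements in the hunks that follow rely on `in_place_add!` returning its first argument (the array or accumulator mutated in place, or a fresh value for scalars), so the single pattern `x = in_place_add!(x, dx, scale_factor)` covers both plain gradients and the shared `ThreadsafeAccumulator`s. A small sketch of that convention, using the methods added in `builtin_optimization.jl` above; note that the three-argument method on a plain `Array` deliberately ignores `scale_factor`, per the NOTE at the top of this patch.

```julia
acc = ThreadsafeAccumulator(zeros(2))      # per-parameter gradient accumulator
acc = in_place_add!(acc, [1.0, 2.0], 0.5)  # locks, adds 0.5 .* increment, returns acc
x = 0.0
x = in_place_add!(x, 3.0)                  # scalars: simply returns the sum, 3.0
```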
@@ -172,15 +177,17 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: if node === ir.return_node && node in fwd_marked @assert node in back_marked push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing"))) - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + #push!(stmts, :($(gradient_var(node)) += retval_grad)) end if node in fwd_marked && node in back_marked - cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref, - $(QuoteNode(node.name)))) - push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref, - $(QuoteNode(node.name)), - $cur_param_grad + $(gradient_var(node))))) + #NOTE: unecessary, because we accumulated in-place already + #push!(stmts, :($(QuoteNode(increment_param_grad!))(trace.$static_ir_gen_fn_ref, + #$(QuoteNode(node.name)), + #$(gradient_var(node)), + #scale_factor))) end end @@ -189,7 +196,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + #push!(stmts, :($(gradient_var(node)) += retval_grad)) end end @@ -197,7 +206,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + #push!(stmts, :($(gradient_var(node)) += retval_grad)) end if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs) @@ -209,7 +220,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked arg_maybe_tracked = maybe_tracked_arg_var(node, i) - push!(stmts, :($(gradient_var(input_node)) += $(QuoteNode(deriv))($arg_maybe_tracked))) + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor))) + #push!(stmts, :($(gradient_var(input_node)) += $(QuoteNode(deriv))($arg_maybe_tracked))) end end end @@ -221,7 +234,7 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke # only evaluate the gradient of the logpdf if we need to if any(input_node in fwd_marked for input_node in node.inputs) || node in fwd_marked args = map((input_node) -> input_node.name, node.inputs) - push!(stmts, :($logpdf_grad = logpdf_grad($(node.dist), $(node.name), $(args...)))) + push!(stmts, :($logpdf_grad = $(QuoteNode(Gen.logpdf_grad))($(node.dist), $(node.name), $(args...)))) end # increment gradients of input nodes that are in fwd_marked @@ -231,14 +244,18 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke if !has_argument_grads(node.dist)[i] error("Distribution $(node.dist) does not have logpdf gradient for argument $i") end - push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) + push!(stmts, :($(gradient_var(input_node)) = 
$(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $logpdf_grad[$(QuoteNode(i+1))], scale_factor))) + #push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) end end # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + #push!(stmts, :($(gradient_var(node)) += retval_grad)) end end @@ -254,7 +271,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, if !has_output_grad(node.dist) error("Distribution $dist does not logpdf gradient for its output value") end - push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + #push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) end end @@ -270,7 +289,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + #push!(stmts, :($(gradient_var(node)) += retval_grad)) end if node in fwd_marked @@ -285,7 +306,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, push!(stmts, :($call_selection = EmptySelection())) end retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( + push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = $(QuoteNode(choice_gradients))( trace.$subtrace_fieldname, $call_selection, $retval_grad))) end @@ -293,7 +314,9 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor))) + #push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) end end @@ -306,21 +329,25 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + #push!(stmts, :($(gradient_var(node)) += retval_grad)) end if node in fwd_marked input_grads = gensym("call_input_grads") subtrace_fieldname = get_subtrace_fieldname(node) retval_grad = node in back_marked ? 
gradient_var(node) : :(nothing) - push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) + push!(stmts, :($input_grads = $(QuoteNode(accumulate_param_gradients!))(trace.$subtrace_fieldname, $retval_grad, scale_factor))) end # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked + for (i, (input_node, has_grad)) in enumerate(zip(node.inputs, has_argument_grads(node.generative_function))) + if input_node in fwd_marked && has_grad @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor))) + #push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) end end end @@ -417,7 +444,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # unpack arguments from the trace arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] - push!(stmts, :($(Expr(:tuple, arg_names...)) = get_args(trace))) + push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(get_args))(trace))) # forward code-generation pass (initialize gradients to zero, create needed references) for node in ir.nodes @@ -446,7 +473,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, end function codegen_accumulate_param_gradients!(trace_type::Type{T}, - retval_grad_type::Type) where {T<:StaticIRTrace} + retval_grad_type::Type, scale_factor_type) where {T<:StaticIRTrace} gen_fn_type = get_gen_fn_type(trace_type) ir = get_ir(gen_fn_type) @@ -472,10 +499,11 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, end stmts = Expr[] + push!(stmts, :(scale_factor = 1.0f0)) # unpack arguments from the trace arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] - push!(stmts, :($(Expr(:tuple, arg_names...)) = get_args(trace))) + push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(get_args))(trace))) # forward code-generation pass (initialize gradients to zero, create needed references) for node in ir.nodes @@ -506,7 +534,7 @@ end end) push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} - $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad) +@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad, scale_factor) where {T<:$(QuoteNode(StaticIRTrace))} + $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad, scale_factor) end end) diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index d5d7e237f..72abbca9c 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -50,6 +50,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati struct $gen_fn_type_name <: $(QuoteNode(StaticIRGenerativeFunction)){$return_type,$trace_type} params_grad::Dict{Symbol,Any} params::Dict{Symbol,Any} + params_grad_lock::ReentrantLock end (gen_fn::$gen_fn_type_name)(args...) 
= $(GlobalRef(Gen, :propose))(gen_fn, args)[3] $(GlobalRef(Gen, :get_ir))(::$gen_fn_type_name) = $(QuoteNode(ir)) @@ -61,7 +62,8 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) end - Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) + Expr(:block, trace_defns, gen_fn_defn, + Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()), :(ReentrantLock()))) end include("print_ir.jl") From e471f468444b64410326ba93d48d6993dacdf824 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Thu, 13 May 2021 14:46:00 +0000 Subject: [PATCH 02/24] fix newly introduced bug --- src/static_ir/backprop.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index b6b0087d6..417a91d6e 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -423,6 +423,8 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, return quote choice_gradients(trace, StaticSelection(selection), retval_grad) end end + push!(stmts, :(scale_factor = NaN)) + ir = get_ir(gen_fn_type) selected_choices = get_selected_choices(schema, ir) selected_calls = get_selected_calls(schema, ir) @@ -499,7 +501,6 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, end stmts = Expr[] - push!(stmts, :(scale_factor = 1.0f0)) # unpack arguments from the trace arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] From a5473f4e879b5d14f2700f0af32c47ca1bac6dd6 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Fri, 14 May 2021 00:02:27 -0400 Subject: [PATCH 03/24] checkpoint on gradients refactor --- src/gen_fn_interface.jl | 68 ++++---- src/optimization.jl | 337 ++++++++++++++++++++++++++++++++-------- 2 files changed, 314 insertions(+), 91 deletions(-) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index f3881f43f..67d997df1 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -1,3 +1,27 @@ +export GenerativeFunction +export has_argument_grads +export accepts_output_grad + + +export get_parameters + +export Trace +export get_args +export get_retval +export get_choices +export get_score +export get_gen_fn + +export simulate +export generate +export project +export propose +export assess +export update +export regenerate +export accumulate_param_gradients! +export choice_gradients + ########## # Traces # ########## @@ -85,12 +109,6 @@ Synonym for [`get_retval`](@ref). """ Base.getindex(trace::Trace) = get_retval(trace) -export get_args -export get_retval -export get_choices -export get_score -export get_gen_fn - ###################### # GenerativeFunction # ###################### @@ -105,6 +123,13 @@ abstract type GenerativeFunction{T,U <: Trace} end get_return_type(::GenerativeFunction{T,U}) where {T,U} = T get_trace_type(::GenerativeFunction{T,U}) where {T,U} = U +""" + parameters::Dict{ParameterStore,Vector} = get_parameters(gen_fn::GenerativeFunction) + +Returns the parameters used by the generative function (including all of its calls). +""" +function get_parameters end + """ bools::Tuple = has_argument_grads(gen_fn::Union{GenerativeFunction,Distribution}) @@ -135,7 +160,7 @@ Return an iterable over the trainable parameters of the generative function. 
get_params(::GenerativeFunction) = () """ - trace = simulate(gen_fn, args) + trace = simulate(gen_fn, args; parameter_context=Dict()) Execute the generative function and return the trace. @@ -146,17 +171,19 @@ If `gen_fn` has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the `args` tuple. The generated trace will have default values filled in. """ -function simulate(::GenerativeFunction, ::Tuple) +function simulate(::GenerativeFunction, ::Tuple; parameter_context=Dict()) error("Not implemented") end """ - (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple) + (trace::U, weight) = generate( + gen_fn::GenerativeFunction{T,U}, args::Tuple; parameter_context=Dict()) Return a trace of a generative function. - (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple, - constraints::ChoiceMap) + (trace::U, weight) = generate( + gen_fn::GenerativeFunction{T,U}, args::Tuple, + constraints::ChoiceMap; parameter_context=Dict()) Return a trace of a generative function that is consistent with the given constraints on the random choices. @@ -182,11 +209,11 @@ Example with constraint that address `:z` takes value `true`. (trace, weight) = generate(foo, (2, 4), choicemap((:z, true)) ``` """ -function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap) +function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap; parameter_context=Dict()) error("Not implemented") end -function generate(gen_fn::GenerativeFunction, args::Tuple) +function generate(gen_fn::GenerativeFunction, args::Tuple; parameter_context=Dict()) generate(gen_fn, args, EmptyChoiceMap()) end @@ -408,18 +435,3 @@ end function choice_gradients(trace) choice_gradients(trace, EmptySelection(), nothing) end - -export GenerativeFunction -export Trace -export has_argument_grads -export accepts_output_grad -export get_params -export simulate -export generate -export project -export propose -export assess -export update -export regenerate -export accumulate_param_gradients! -export choice_gradients diff --git a/src/optimization.jl b/src/optimization.jl index 8617f6c27..fa6947e7b 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -1,107 +1,318 @@ -""" - state = init_update_state(conf, gen_fn::GenerativeFunction, param_list::Vector) +import Parameters + +# TODO notes +# +# we should modify the semantics of the log probability contribution to the gradient +# so that everything is gradient descent instead of ascent. this will also fix +# the misnomer names +# +# add a default global JuliaParameterStore + +export in_place_add! + +export FixedStepGradientDescent +export DecayStepGradientDescent +export make_optimizer +export apply_update! + +export ParameterStore +export JuliaParameterStore +export JuliaParameterID + +export initialize_parameter! +export increment_gradient! +export reset_gradient! +export get_parameter_value + +################# +# in_place_add! # +################# + +function in_place_add! end + +function in_place_add!(value::Array, increment) + @simd for i in 1:length(param) + value[i] += increment[i] + end + return value +end + +# this exists so user can use the same function on scalars and arrays +function in_place_add!(param::Real, increment::Real) + return param + increment +end -Get the initial state for a parameter update to the given parameters of the given generative function. 
+############################ +# optimizer specifications # +############################ -`param_list` is a vector of references to parameters of `gen_fn`. -`conf` configures the update. """ -function init_update_state end + conf = FixedStepGradientDescent(step_size) +Configuration for stochastic gradient descent update with fixed step size. """ - apply_update!(state) +Parameters.@with_kw struct FixedStepGradientDescent + step_size::Float64 +end -Apply one parameter update, mutating the values of the trainable parameters, and possibly also the given state. """ -function apply_update! end + conf = GradientDescent(step_size_init, step_size_beta) +Configuration for stochastic gradient descent update with step size given by `(t::Int) -> step_size_init * (step_size_beta + 1) / (step_size_beta + t)` where `t` is the iteration number. """ - update = ParamUpdate(conf, param_lists...) +Parameters.@with_kw struct DecayStepGradientDescent + step_size_init::Float64 + step_size_beta::Float64 +end -Return an update configured by `conf` that applies to set of parameters defined by `param_lists`. -Each element in `param_lists` value is is pair of a generative function and a vector of its parameter references. +# TODO add ADAM update -**Example**. To construct an update that applies a gradient descent update to the parameters `:a` and `:b` of generative function `foo` and the parameter `:theta` of generative function `:bar`: +########################### +# thread-safe accumulator # +########################### -```julia -update = ParamUpdate(GradientDescent(0.001, 100), foo => [:a, :b], bar => [:theta]) -``` +struct Accumulator{T<:Union{Real,Array}} + value::T + lock::ReentrantLock +end ------------------------------------------------------------------------------------------- -Syntactic sugar for the constructor form above. +Accumulator(value) = Accumulator(value, ReentrantLock()) - update = ParamUpdate(conf, gen_fn::GenerativeFunction) +# NOTE: not thread-safe because it may return a reference to the Array +get_value(accum::Accumulator) = accum.value -Return an update configured by `conf` that applies to all trainable parameters owned by the given generative function. +function fill_with_zeros!(accum::Accumulator{T}) where {T <: Real} + lock(accum.lock) + try + accum.value = zero(T) + finally + unlock(accum.lock) + end + return accum +end -Note that trainable parameters not owned by the given generative function will not be updated, even if they are used during execution of the function. +function fill_with_zeros!(accum::Accumulator{Array{T}}) where {T} + lock(accum.lock) + try + fill!(zero(T), accum.arr) + finally + unlock(accum.lock) + end + return accum +end -**Example**. 
If generative function `foo` has parameters `:a` and `:b`, to construct an update that applies a gradient descent update to the parameters `:a` and `:b`: +function in_place_add!(accum::ThreadsafeAccumulator{Real}, increment::Real, scale_factor::Real) + lock(accum.lock) + try + accum.value = accum.value + increment * scale_factor + finally + unlock(accum.lock) + end + return accum +end -```julia -update = ParamUpdate(GradientDescent(0.001, 100), foo) -``` -""" -struct ParamUpdate - states::Dict{GenerativeFunction,Any} +function in_place_add!(accum::ThreadsafeAccumulator{Real}, increment::Real) + lock(accum.lock) + try + accum.value = accum.value + increment + finally + unlock(accum.lock) + end + return accum +end + +function in_place_add!(accum::ThreadsafeAccumulator{<:Array}, increment, scale_factor::Real) + lock(accum.lock) + try + @simd for i in 1:length(accum.value) + accum.value[i] += increment[i] * scale_factor + end + finally + unlock(accum.lock) + end + return accum +end + +function in_place_add!(accum::ThreadsafeAccumulator{<:Array}, increment) + lock(accum.lock) + try + @simd for i in 1:length(accum.value) + accum.value[i] += increment[i] + end + finally + unlock(accum.lock) + end + return accum +end + + + + +################################# +# ParameterStore and optimizers # +################################# + +abstract type ParameterStore end + +# TODO docstring, returns an optimizer that has an apply_update! method +function make_optimizer(conf, store::ParameterStore, parameter_ids) end + +# TODO docstring +function apply_update!(optimizer) end + +struct CompositeOptimizer conf::Any - function ParamUpdate(conf, param_lists...) - states = Dict{GenerativeFunction,Any}() - for (gen_fn, param_list) in param_lists - states[gen_fn] = init_update_state(conf, gen_fn, param_list) + optimizers::Dict{ParameterStore,Any} + function CompositeOptimizer(conf, parameters::Dict{ParameterStore,Vector}) + optimizers = Dict{ParameterStore,Any}() + for (store, parameter_ids) in parameters + optimizers[store] = make_optimizer(conf, store, parameter_ids) end new(states, conf) end - function ParamUpdate(conf, gen_fn::GenerativeFunction) - param_lists = Dict(gen_fn => collect(get_params(gen_fn))) - ParamUpdate(conf, param_lists...) - end end +function CompositeOptimizer(conf, gen_fn::GenerativeFunction) + return CompositeOptimizer(conf, get_parameters(gen_fn)) +end """ - apply!(update::ParamUpdate) + apply_update!(update::ParamUpdate) Perform one step of the update. """ -function apply!(update::ParamUpdate) - for (_, state) in update.states - apply_update!(state) +function apply_update!(composite_opt::CompositeOptimizer) + for opt in values(composite_opt.optimizers) + apply_update!(opt) end - nothing + return nothing end -""" - conf = FixedStepGradientDescent(step_size) -Configuration for stochastic gradient descent update with fixed step size. 
-""" -struct FixedStepGradientDescent - step_size::Float64 +######### +# Julia # +######### + +# TODO document +const JuliaParameterID = Tuple{GenerativeFunction,Symbol} + +# TODO document +struct JuliaParameterStore + values::Dict{JuliaParameterID,Any} + gradient_accumulators::Dict{JuliaParameterID,GradientAccumulator} end -""" - conf = GradientDescent(step_size_init, step_size_beta) +function JuliaParameterStore() + return JuliaParameterStore( + Dict{JuliaParameterID,Any}(), + Dict{JuliaParameterID,GradientAccumulator}()) +end -Configuration for stochastic gradient descent update with step size given by `(t::Int) -> step_size_init * (step_size_beta + 1) / (step_size_beta + t)` where `t` is the iteration number. -""" -struct GradientDescent - step_size_init::Float64 - step_size_beta::Float64 +# TODO document +const default_julia_parameter_store = JuliaParameterStore() + +# for looking up in a parameter context when tracing (simulate, generate) +# TODO make the parametr context another argument to simulate and generate +# once a trace is generated, it is bound to use a particular store +const JULIA_PARAMETER_STORE_KEY = :julia_parameter_store + +function get_julia_store(context::Dict{Symbol,Any}) + if haskey(context, JULIA_PARAMETER_STORE_KEY) + return context[JULIA_PARAMETER_STORE_KEY] + else + return default_julia_parameter_store end """ - conf = ADAM(learning_rate, beta1, beta2, epsilon) + initialize_parameter!(store::JuliaParameterStore, id::JuliaParameterID, value) + +Initialize the the value of a named trainable parameter of a generative function. + +Also generates the gradient accumulator for that parameter to `zero(value)`. -Configuration for ADAM update. +Example: +```julia +initialize_parameter!(foo, :theta, 0.6) +``` + +Not thread-safe. """ -struct ADAM - learning_rate::Float64 - beta1::Float64 - beta2::Float64 - epsilon::Float64 +function initialize_parameter!(store::JuliaParameterStore, id::JuliaParameterID, value) + store.values[id] = value + reset_gradient!(store, id) + return nothing +end + +# TODO docstring (not thread-safe) +function reset_gradient!(store::JuliaParameterStore, id::JuliaParameterID) + !haskey(store.values, id) && error("parameter not initialized: $id") + if haskey(store.gradient_accumulators, id) + fill_with_zeros!(store.gradient_accumulators[id]) + else + store.gradient_accumulators[id] = Accumulator(zero(store.values[id])) + end + return nothing +end + +# TODO docstring (thread-safe) +function increment_gradient!( + store::JuliaParameterStore, id::JuliaParameterID, + increment, scale_factor) + in_place_add!(store.gradient_accumulators[id], increment, scale_factor) + return nothing +end + +# TODO docstring (thread-safe) +function increment_gradient!( + store::JuliaParameterStore, id::JuliaParameterID, + increment) + in_place_add!(store.gradient_accumulators[id], increment) + return nothing +end + +# TODO docstring (not thread-safe) +function get_parameter_value(store::JuliaParameterStore, id::JuliaParameterID) + return store.values[id] +end + +# TODO docstring (not thread-safe) +function set_parameter_value!(store::JuliaParameterStore, id::JuliaParameterID, value) + store.values[id] = value + return nothing +end + +# TODO docstring (not thread-safe) +function get_gradient(store::JuliaParameterStore, id::JuliaParameterID) + return get_value(store.gradient_accumulators[id]) +end + +##################################################### +# Optimizer implementations for JuliaParameterStore # +##################################################### + +mutable 
struct FixedStepGradientDescentJulia + conf::FixedStepGradientDescent + store::JuliaParameterStore + parameters::Vector{JuliaParameterID} +end + +function make_optimizer( + conf::FixedStepGradientDescent, + store::JuliaParameterStore, + parameters::Vector{JuliaParameterID}) + return FixedStepGradientDescentJulia(conf, store, parameters) +end + +# TODO docstring (not thread-safe) +function apply_update!(opt::FixedStepGradientDescentJulia) + for parameter_id in opt.parameters + value = get_parameter_value(opt.store, parameter_id) + gradient = get_gradient(opt.store, id) + new_value = in_place_add!(value, gradient * opt.conf.step_size) + set_parameter_value!(store, parameter_id, new_value) + reset_gradient!(store, parameter_id) + end end -export ParameterSet, ParamUpdate, apply! -export FixedStepGradientDescent, GradientDescent, ADAM +# TODO implement other optimizers From fcf1eb2b54c4eb3fea53c986e395f4a63e9f606b Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Fri, 14 May 2021 00:14:25 -0400 Subject: [PATCH 04/24] checkpoint --- src/optimization.jl | 8 +++++++- src/static_ir/backprop.jl | 2 ++ src/static_ir/dag.jl | 12 ++++++++++++ src/static_ir/static_ir.jl | 6 +++--- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/optimization.jl b/src/optimization.jl index fa6947e7b..a3ea4b250 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -6,7 +6,13 @@ import Parameters # so that everything is gradient descent instead of ascent. this will also fix # the misnomer names # -# add a default global JuliaParameterStore +# combinators and call_at! and choice_at! all need to implement get_parameters.. +# +# make changes to src/static_ir/backprop.jl +# +# make changes to src/dynamic/dynamic.jl (use the JuliaParameterStore) +# +# make changes to src/dynamic/backprop.jl export in_place_add! diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 417a91d6e..595287d35 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -1,3 +1,5 @@ +# TODO this code needs to be simplified + struct BackpropTraceMode end struct BackpropParamsMode end diff --git a/src/static_ir/dag.jl b/src/static_ir/dag.jl index c82658892..190b40ba4 100644 --- a/src/static_ir/dag.jl +++ b/src/static_ir/dag.jl @@ -204,6 +204,18 @@ function set_accepts_output_grad!(builder::StaticIRBuilder, value::Bool) builder.accepts_output_grad = value end +function get_parameters(ir::StaticIR, gen_fn::GenerativeFunction, parameter_context) + parameters = Dict() + for call_node in ir.call_nodes + merge!(parameters, get_parameters(call_node.generative_function, parameter_context)) + end + julia_store = get_julia_store(parameter_context) + for param_node in ir.trainable_param_nodes + parameters[store] = (gen_fn, param_node.name) + end + return parameters +end + export StaticIR, StaticIRBuilder, build_ir export add_trainable_param_node! export add_argument_node! diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index 72abbca9c..8e996e9c7 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -48,9 +48,6 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati gen_fn_defn = quote struct $gen_fn_type_name <: $(QuoteNode(StaticIRGenerativeFunction)){$return_type,$trace_type} - params_grad::Dict{Symbol,Any} - params::Dict{Symbol,Any} - params_grad_lock::ReentrantLock end (gen_fn::$gen_fn_type_name)(args...) 
= $(GlobalRef(Gen, :propose))(gen_fn, args)[3] $(GlobalRef(Gen, :get_ir))(::$gen_fn_type_name) = $(QuoteNode(ir)) @@ -61,6 +58,9 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati $(GlobalRef(Gen, :get_gen_fn))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))) $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) + function $(GlobalRef(Gen, :get_parameters))(gen_fn::Type{$gen_fn_type_name}, context) + return $(GlobalRef(Gen, :get_parameters))($(QuoteNode(ir)), gen_fn, context) + end end Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()), :(ReentrantLock()))) From 6e85f3f00755ae05d696e41ba4a90a7cae5b7f99 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sun, 16 May 2021 14:20:32 -0400 Subject: [PATCH 05/24] checkpoint on changes to DML --- src/dynamic/backprop.jl | 112 ++++++++++++++++------------------------ src/dynamic/dynamic.jl | 28 +++++----- src/dynamic/generate.jl | 27 +++++----- src/dynamic/simulate.jl | 28 ++++++---- src/dynamic/trace.jl | 8 ++- src/gen_fn_interface.jl | 2 +- src/optimization.jl | 80 ++++++++++++++++++++++------ 7 files changed, 160 insertions(+), 125 deletions(-) diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 0489ec3a6..c38a4ac27 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -31,9 +31,9 @@ end end -################### +############################### # accumulate_param_gradients! # -################### +############################### mutable struct GFBackpropParamsState trace::DynamicDSLTrace @@ -41,33 +41,26 @@ mutable struct GFBackpropParamsState tape::InstructionTape visitor::AddressVisitor scale_factor::Float64 - - # only those tracked parameters that are in scope (not including splice calls) - tracked_params::Dict{Symbol,Any} - - # tracked parameters for all (nested) splice calls - splice_tracked_params::Dict{Tuple{GenerativeFunction,Symbol},Any} -end - -function track_params(tape, params) - tracked_params = Dict{Symbol,Any}() - for (name, value) in params - tracked_params[name] = track(value, tape) + active_gen_fn::GenerativeFunction + tracked_params::Dict{ParameterID,Any} + + function GFBackpropParamsState(trace::DynamicDSLTrace, tape, scale_factor) + tracked_params = Dict{ParameterID,Any}() + store = get_parameter_store(trace) + gen_fn = get_gen_fn(trace) + for (name, value) in get_local_parameters(store, gen_fn) + parameter_id = (gen_fn, name) + tracked_params[parameter_id] = track(value, tape) + end + score = track(0., tape) + new(trace, score, tape, AddressVisitor(), scale_factor, + gen_fn, tracked_params) end - tracked_params -end - -function GFBackpropParamsState(trace::DynamicDSLTrace, tape, params, scale_factor) - tracked_params = track_params(tape, params) - splice_tracked_params = Dict{Tuple{GenerativeFunction,Symbol},Any}() - score = track(0., tape) - GFBackpropParamsState(trace, score, tape, AddressVisitor(), scale_factor, - tracked_params, splice_tracked_params) end function read_param(state::GFBackpropParamsState, name::Symbol) - value = state.tracked_params[name] - value + parameter_id = (state.active_gen_fn, name) + return state.tracked_params[parameter_id] end function traceat(state::GFBackpropParamsState, dist::Distribution{T}, @@ -110,30 +103,20 @@ end function splice(state::GFBackpropParamsState, gen_fn::DynamicDSLFunction, 
args_maybe_tracked::Tuple) - - # save previous tracked parameter scope - prev_tracked_params = state.tracked_params - - # construct new tracked parameter scope - state.tracked_params = Dict{Symbol,Any}() - for name in keys(gen_fn.params) - if haskey(state.splice_tracked_params, (gen_fn, name)) - # parameter was already tracked in another splice call - state.tracked_params[name] = state.splice_tracked_params[(gen_fn, name)] - else + prev_gen_fn = state.active_gen_fn + state.active_gen_fn = gen_fn + store = get_parameter_store(state.trace) + for (name, value) in get_local_parameters(store, gen_fn) + parameter_id = (gen_fn, name) + if !haskey(state.tracked_params, parameter_id) # parameter was not already tracked - tracked = track(get_param(gen_fn, name), state.tape) - state.tracked_params[name] = tracked - state.splice_tracked_params[(gen_fn, name)] = tracked + tracked = track(value, state.tape) + state.tracked_params[parameter_id] = tracked end end - retval_maybe_tracked = exec(gen_fn, state, args_maybe_tracked) - - # restore previous tracked parameter scope - state.tracked_params = prev_tracked_params - - retval_maybe_tracked + state.active_gen_fn = prev_gen_fn + return retval_maybe_tracked end @noinline function ReverseDiff.special_reverse_exec!( @@ -185,7 +168,7 @@ end function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_factor=1.) gen_fn = trace.gen_fn tape = new_tape() - state = GFBackpropParamsState(trace, tape, gen_fn.params, scale_factor) + state = GFBackpropParamsState(trace, tape, scale_factor) args = get_args(trace) args_maybe_tracked = (map(maybe_track, args, gen_fn.has_argument_grads, fill(tape, length(args)))...,) @@ -195,15 +178,9 @@ function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_ reverse_pass!(tape) # increment the gradient accumulators for trainable parameters in scope - for (name, tracked) in state.tracked_params - #increment = deriv(tracked) * state.scale_factor - increment_param_grad!(gen_fn, name, deriv(tracked), state.scale_factor) - end - - # increment the gradient accumulators for trainable parameters in splice calls - for ((spliced_gen_fn, name), tracked) in state.splice_tracked_params - #increment = deriv(tracked) * state.scale_factor - increment_param_grad!(spliced_gen_fn, name, deriv(tracked), state.scale_factor) + for ((active_gen_fn, name), tracked) in state.tracked_params + parameter_id = (active_gen_fn, parameter_id) + increment_gradient!(store, parameter_id, deriv(tracked), state.scale_factor) end # return gradients with respect to arguments with gradients, or nothing @@ -213,30 +190,31 @@ function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_ end -################## +#################### # choice_gradients # -################## +#################### mutable struct GFBackpropTraceState trace::DynamicDSLTrace score::TrackedReal tape::InstructionTape visitor::AddressVisitor - params::Dict{Symbol,Any} selection::Selection tracked_choices::Trie{Any,Union{TrackedReal,TrackedArray}} value_choices::DynamicChoiceMap gradient_choices::DynamicChoiceMap + active_gen_fn::GenerativeFunction end -function GFBackpropTraceState(trace, selection, params, tape) +function GFBackpropTraceState(trace, selection, tape) score = track(0., tape) visitor = AddressVisitor() tracked_choices = Trie{Any,Union{TrackedReal,TrackedArray}}() value_choices = choicemap() gradient_choices = choicemap() - GFBackpropTraceState(trace, score, tape, visitor, params, - selection, tracked_choices, value_choices, 
gradient_choices) + GFBackpropTraceState(trace, score, tape, visitor, + selection, tracked_choices, value_choices, gradient_choices, + get_gen_fn(trace)) end function fill_submaps!( @@ -328,11 +306,11 @@ end function splice(state::GFBackpropTraceState, gen_fn::DynamicDSLFunction, args_maybe_tracked::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args_maybe_tracked) - state.params = prev_params - retval + prev_gen_fn = state.active_gen_fn + state.active_gen_fn = gen_fn + retval = exec(gen_fn, state, args) + state.active_gen_fn = prev_gen_fn + return retval end @noinline function ReverseDiff.special_reverse_exec!( @@ -374,7 +352,7 @@ end function choice_gradients(trace::DynamicDSLTrace, selection::Selection, retval_grad) gen_fn = trace.gen_fn tape = new_tape() - state = GFBackpropTraceState(trace, selection, gen_fn.params, tape) + state = GFBackpropTraceState(trace, selection, tape) args = get_args(trace) args_maybe_tracked = (map(maybe_track, args, gen_fn.has_argument_grads, fill(tape, length(args)))...,) diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index e136ed723..7f9e286de 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -9,15 +9,12 @@ Constructed using the `@gen` keyword. Most methods in the generative function interface involve a end-to-end execution of the function. """ struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace} - params_grad::Dict{Symbol,Any} - params::Dict{Symbol,Any} arg_types::Vector{Type} has_defaults::Bool arg_defaults::Vector{Union{Some{Any},Nothing}} julia_function::Function has_argument_grads::Vector{Bool} accepts_output_grad::Bool - params_grad_lock::ReentrantLock end function DynamicDSLFunction(arg_types::Vector{Type}, @@ -25,14 +22,11 @@ function DynamicDSLFunction(arg_types::Vector{Type}, julia_function::Function, has_argument_grads, ::Type{T}, accepts_output_grad::Bool) where {T} - params_grad = Dict{Symbol,Any}() - params = Dict{Symbol,Any}() has_defaults = any(arg -> arg != nothing, arg_defaults) - DynamicDSLFunction{T}(params_grad, params, arg_types, + return DynamicDSLFunction{T}(arg_types, has_defaults, arg_defaults, julia_function, - has_argument_grads, accepts_output_grad, - ReentrantLock()) + has_argument_grads, accepts_output_grad) end function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction} @@ -42,17 +36,21 @@ function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction} defaults = map(x -> something(x), defaults) args = Tuple(vcat(collect(args), defaults)) end - DynamicDSLTrace{T}(gen_fn, args) + return DynamicDSLTrace{T}(gen_fn, args) end accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad mutable struct GFUntracedState - params::Dict{Symbol,Any} + gen_fn::GenerativeFunction + parameter_store::JuliaParameterStore end +get_parameter_store(state::GFUntracedState) = state.parameter_store +get_parameter_id(state::GFUntracedState, name::Symbol) = (state.gen_fn, name) + function (gen_fn::DynamicDSLFunction)(args...) - state = GFUntracedState(gen_fn.params) + state = GFUntracedState(gen_fn, default_julia_parameter_store) gen_fn.julia_function(state, args...) 
end @@ -104,11 +102,9 @@ function dynamic_param_impl(expr::Expr) end function read_param(state, name::Symbol) - if haskey(state.params, name) - state.params[name] - else - throw(UndefVarError(name)) - end + parameter_id = get_parameter_id(state, name) + store = get_parameter_store(state) + return get_parameter_value(store, parameter_id) end ################## diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index a89e0c352..462bd385f 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -3,12 +3,13 @@ mutable struct GFGenerateState constraints::ChoiceMap weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict{Symbol,Any} end -function GFGenerateState(gen_fn, args, constraints, params) - trace = DynamicDSLTrace(gen_fn, args) - GFGenerateState(trace, constraints, 0., AddressVisitor(), params) +function GFGenerateState(gen_fn, args, constraints, parameter_context) + trace = DynamicDSLTrace(gen_fn, args, parameter_context) + GFGenerateState(trace, constraints, 0., AddressVisitor(), parameter_context) end function traceat(state::GFGenerateState, dist::Distribution{T}, @@ -55,7 +56,8 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, constraints = get_submap(state.constraints, key) # get subtrace - (subtrace, weight) = generate(gen_fn, args, constraints) + (subtrace, weight) = generate( + gen_fn, args, constraints; parameter_context=parameter_context) # add to the trace add_call!(state.trace, key, subtrace) @@ -71,16 +73,17 @@ end function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction, args::Tuple) - prev_params = state.params - state.params = gen_fn.params + prev_gen_fn = state.active_gen_fn + state.active_gen_fn = gen_fn retval = exec(gen_fn, state, args) - state.params = prev_params - retval + state.active_gen_fn = prev_gen_fn + return retval end -function generate(gen_fn::DynamicDSLFunction, args::Tuple, - constraints::ChoiceMap) - state = GFGenerateState(gen_fn, args, constraints, gen_fn.params) +function generate( + gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap; + parameter_context=default_parameter_context) + state = GFGenerateState(gen_fn, args, constraints, parameter_context) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) (state.trace, state.weight) diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 7db1a213a..75aa07754 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -1,14 +1,18 @@ mutable struct GFSimulateState trace::DynamicDSLTrace visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict{Symbol,Any} end -function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) - trace = DynamicDSLTrace(gen_fn, args) - GFSimulateState(trace, AddressVisitor(), params) +function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, parameter_context) + trace = DynamicDSLTrace(gen_fn, args, parameter_context) + GFSimulateState(trace, AddressVisitor(), parameter_context) end +get_parameter_store(state::GFSimulateState) = get_parameter_store(state.trace) +get_parameter_id(state::GFSimulateState, name::Symbol) = (state.active_gen_fn, name) + function traceat(state::GFSimulateState, dist::Distribution{T}, args, key) where {T} local retval::T @@ -36,7 +40,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # get subtrace - 
subtrace = simulate(gen_fn, args) + subtrace = simulate(gen_fn, args; parameter_context=state.parameter_context) # add to the trace add_call!(state.trace, key, subtrace) @@ -49,15 +53,17 @@ end function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction, args::Tuple) - prev_params = state.params - state.params = gen_fn.params + prev_gen_fn = state.active_gen_fn + state.active_gen_fn = gen_fn retval = exec(gen_fn, state, args) - state.params = prev_params - retval + state.active_gen_fn = prev_gen_fn + return retval end -function simulate(gen_fn::DynamicDSLFunction, args::Tuple) - state = GFSimulateState(gen_fn, args, gen_fn.params) +function simulate( + gen_fn::DynamicDSLFunction, args::Tuple; + parameter_context=default_parameter_context) + state = GFSimulateState(gen_fn, args, parameter_context) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) state.trace diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..6ef6e35cf 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -37,14 +37,18 @@ mutable struct DynamicDSLTrace{T} <: Trace score::Float64 noise::Float64 args::Tuple + parameter_store::JuliaParameterStore retval::Any - function DynamicDSLTrace{T}(gen_fn::T, args) where {T} + function DynamicDSLTrace{T}(gen_fn::T, args, parameter_context) where {T} trie = Trie{Any,ChoiceOrCallRecord}() + parameter_store = get_julia_store(parameter_context) # retval is not known yet - new(gen_fn, trie, true, 0, 0, args) + new(gen_fn, trie, true, 0, 0, args, parameter_store) end end +get_parameter_store(trace::DynamicDSLTrace) = trace.parameter_store + set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) function has_choice(trace::DynamicDSLTrace, addr) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 67d997df1..439068634 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -124,7 +124,7 @@ get_return_type(::GenerativeFunction{T,U}) where {T,U} = T get_trace_type(::GenerativeFunction{T,U}) where {T,U} = U """ - parameters::Dict{ParameterStore,Vector} = get_parameters(gen_fn::GenerativeFunction) + parameters::Dict{ParameterStore,Vector} = get_parameters(gen_fn::GenerativeFunction, parameter_context) Returns the parameters used by the generative function (including all of its calls). """ diff --git a/src/optimization.jl b/src/optimization.jl index a3ea4b250..c005a7801 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -6,12 +6,14 @@ import Parameters # so that everything is gradient descent instead of ascent. this will also fix # the misnomer names # -# combinators and call_at! and choice_at! all need to implement get_parameters.. +# combinators (map etc.) and call_at! and choice_at! all need to implement get_parameters.. # # make changes to src/static_ir/backprop.jl # # make changes to src/dynamic/dynamic.jl (use the JuliaParameterStore) # +# TODO GF untraced needs to reference a parameter store +# # make changes to src/dynamic/backprop.jl export in_place_add! 
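A sketch of the intended workflow behind this plumbing, assuming the `JULIA_PARAMETER_STORE_KEY` (`:julia_parameter_store`) lookup defined earlier in `src/optimization.jl` and the `parameter_context` keyword added to `simulate` above; `model` and `xs` are placeholder names.

```julia
store = JuliaParameterStore()
initialize_parameter!(store, (model, :theta), 0.6)     # a parameter id is (gen_fn, name)
ctx = Dict{Symbol,Any}(:julia_parameter_store => store)
trace = simulate(model, (xs,); parameter_context=ctx)  # the trace is bound to `store`
accumulate_param_gradients!(trace, nothing)            # gradients accumulate in `store`
```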
@@ -200,22 +202,24 @@ end # Julia # ######### -# TODO document const JuliaParameterID = Tuple{GenerativeFunction,Symbol} # TODO document struct JuliaParameterStore - values::Dict{JuliaParameterID,Any} - gradient_accumulators::Dict{JuliaParameterID,GradientAccumulator} + values::Dict{GenerativeFunction,Dict{Symbol,Any}} + gradient_accumulators::Dict{GenerativeFunction,Dict{Symbol,GradientAccumulator}} end function JuliaParameterStore() return JuliaParameterStore( - Dict{JuliaParameterID,Any}(), - Dict{JuliaParameterID,GradientAccumulator}()) + Dict{GenerativeFunction,Dict{Symbol,Any}}(), + Dict{GenerativeFunction,Dict{Symbol,GradientAccumulator}}()) end +get_local_parameters(store::JuliaParameterStore, gen_fn) = store.values[gen_fn] + # TODO document +const default_parameter_context = Dict{Symbol,Any}() const default_julia_parameter_store = JuliaParameterStore() # for looking up in a parameter context when tracing (simulate, generate) @@ -228,6 +232,7 @@ function get_julia_store(context::Dict{Symbol,Any}) return context[JULIA_PARAMETER_STORE_KEY] else return default_julia_parameter_store + end end """ @@ -245,18 +250,31 @@ initialize_parameter!(foo, :theta, 0.6) Not thread-safe. """ function initialize_parameter!(store::JuliaParameterStore, id::JuliaParameterID, value) - store.values[id] = value + (gen_fn, name) = id + if !haskey(store.values, gen_fn) + store.values[gen_fn] = Dict{Symbol,Any}() + end + store.values[gen_fn][name] = value reset_gradient!(store, id) return nothing end # TODO docstring (not thread-safe) function reset_gradient!(store::JuliaParameterStore, id::JuliaParameterID) - !haskey(store.values, id) && error("parameter not initialized: $id") - if haskey(store.gradient_accumulators, id) - fill_with_zeros!(store.gradient_accumulators[id]) + (gen_fn, name) = id + try + value = store.values[gen_fn][name] + catch KeyError + @error "parameter not initialized: $id" + rethrow() + end + if !haskey(store.gradient_accumulators, gen_fn) + store.gradient_accumulators[gen_fn] = Dict{Symbol,Any}() + end + if haskey(store.gradient_accumulators[gen_fn], name) + fill_with_zeros!(store.gradient_accumulators[gen_fn][name]) else - store.gradient_accumulators[id] = Accumulator(zero(store.values[id])) + store.gradient_accumulators[gen_fn][name] = Accumulator(zero(value)) end return nothing end @@ -265,7 +283,13 @@ end function increment_gradient!( store::JuliaParameterStore, id::JuliaParameterID, increment, scale_factor) - in_place_add!(store.gradient_accumulators[id], increment, scale_factor) + (gen_fn, name) = id + try + in_place_add!(store.gradient_accumulators[gen_fn][name], increment, scale_factor) + catch KeyError + @error "parameter not initialized: $id" + rethrow() + end return nothing end @@ -273,24 +297,48 @@ end function increment_gradient!( store::JuliaParameterStore, id::JuliaParameterID, increment) - in_place_add!(store.gradient_accumulators[id], increment) + (gen_fn, name) = id + try + in_place_add!(store.gradient_accumulators[gen_fn][name], increment) + catch KeyError + @error "parameter not initialized: $id" + rethrow() + end return nothing end # TODO docstring (not thread-safe) function get_parameter_value(store::JuliaParameterStore, id::JuliaParameterID) - return store.values[id] + (gen_fn, name) = id + try + return state.values[gen_fn][name] + catch KeyError + @error "parameter not initialized: $id" + rethrow() + end end # TODO docstring (not thread-safe) function set_parameter_value!(store::JuliaParameterStore, id::JuliaParameterID, value) - store.values[id] = value + 
(gen_fn, name) = id + try + store.values[gen_fn][name] = value + catch KeyError + @error "parameter not initialized: $id" + rethrow() + end return nothing end # TODO docstring (not thread-safe) function get_gradient(store::JuliaParameterStore, id::JuliaParameterID) - return get_value(store.gradient_accumulators[id]) + (gen_fn, name) = id + try + return get_value(store.gradient_accumulators[gen_fn][name]) + catch KeyError + @error "parameter not initialized: $id" + rethrow() + end end ##################################################### From 4b5d8048ea90e0f80099f61c002f7b3659399ac7 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sun, 16 May 2021 15:37:27 -0400 Subject: [PATCH 06/24] checkpoint on SML changes --- src/optimization.jl | 16 +++-- src/static_ir/backprop.jl | 131 +++++++++++++++++++------------------- src/static_ir/generate.jl | 9 ++- src/static_ir/simulate.jl | 11 ++-- src/static_ir/trace.jl | 60 ++++++++--------- src/static_ir/update.jl | 2 +- 6 files changed, 122 insertions(+), 107 deletions(-) diff --git a/src/optimization.jl b/src/optimization.jl index c005a7801..6a0cd67b2 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -293,21 +293,27 @@ function increment_gradient!( return nothing end -# TODO docstring (thread-safe) -function increment_gradient!( - store::JuliaParameterStore, id::JuliaParameterID, - increment) +function get_gradient_accumulator(store::JuliaParameterStore, id::JuliaParameterID) (gen_fn, name) = id try - in_place_add!(store.gradient_accumulators[gen_fn][name], increment) + return store.gradient_accumulators[gen_fn][name] catch KeyError @error "parameter not initialized: $id" rethrow() end +end + +# TODO docstring (thread-safe) +function increment_gradient!( + store::JuliaParameterStore, id::JuliaParameterID, + increment) + accumulator = get_gradient_accumulator(store, id) + in_place_add!(accumulator, increment) return nothing end # TODO docstring (not thread-safe) + function get_parameter_value(store::JuliaParameterStore, id::JuliaParameterID) (gen_fn, name) = id try diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 595287d35..ad5110fae 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -21,30 +21,32 @@ maybe_tracked_value_var(node::JuliaNode) = Symbol("$(maybe_tracked_value_prefix) const maybe_tracked_arg_prefix = gensym("maybe_tracked_arg") maybe_tracked_arg_var(node::JuliaNode, i::Int) = Symbol("$(maybe_tracked_arg_prefix)_$(node.name)_$i") -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode) - # TODO: only need to mark it if we are doing backprop params +function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropParamsMode) push!(fwd_marked, node) end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode) +function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropTraceMode) +end + +function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode, mode) if node.compute_grad push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode) +function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode, mode) if any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode) +function fwd_pass!(selected_choices, 
selected_calls, fwd_marked, node::RandomChoiceNode, mode) if node in selected_choices push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode) +function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode, mode) if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end @@ -73,6 +75,9 @@ end function back_pass!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK + # (we could ask whether the generative function is deterministic or not + # as a perforance optimization, because only stochsatic generative functions + # actually have a non-trivial logpdf) for (input_node, has_grad) in zip(node.inputs, has_argument_grads(node.generative_function)) if has_grad push!(back_marked, input_node) @@ -82,8 +87,7 @@ end function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) if node in back_marked - push!(stmts, :($(node.name) = $(QuoteNode(get_param))($(QuoteNode(get_gen_fn))(trace), - $(QuoteNode(node.name))))) + push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) end if node in fwd_marked && node in back_marked @@ -91,8 +95,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNo # initialize gradient to zero # NOTE: we are avoiding allocating a new gradient accumulator for this function # instead, we are using the threadsafe gradient accumulator directly.. - #push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gen_fn))(trace).params_grad[$(QuoteNode(node.name))])) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gradient_accumulator))(trace, $(QuoteNode(node.name))))) end end @@ -106,7 +109,7 @@ end function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) - if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs) + if (node in fwd_marked) && (node in back_marked) # tracked forward execution tape = tape_var(node) @@ -128,27 +131,20 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) # initialize gradient to zero push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - else - # regular forward execution. + elseif node in back_marked - # we need the value for initializing gradient to zero (to get the type - # and e.g. shape), and for reference by other nodes during - # back_codegen! we could be more selective about which JuliaNodes need - # to be evalutaed, that is a performance optimization for the future + # regular forward execution. args = map((input_node) -> input_node.name, node.inputs) push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...)))) end end function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) - # for reference by other nodes during back_codegen! 
- # could performance optimize this away - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - # every random choice is in back_marked, since it affects it logpdf, but # also possibly due to other downstream usage of the value @assert node in back_marked + push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) if node in fwd_marked # the only way we are fwd_marked is if this choice was selected @@ -160,14 +156,16 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) end function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - subtrace_fieldname = get_subtrace_fieldname(node) - push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname))) + + if node in back_marked + # for reference by other nodes during back_codegen! + subtrace_fieldname = get_subtrace_fieldname(node) + push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname))) + end # NOTE: we will still potentially run choice_gradients recursively on the generative function, # we just might not use its return value gradient. - if node in fwd_marked && node in back_marked + if (node in fwd_marked) && (node in back_marked) # we are fwd_marked if an input was fwd_marked, or if we were selected internally push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) end @@ -181,15 +179,6 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing"))) push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( $(gradient_var(node)), retval_grad, scale_factor))) - #push!(stmts, :($(gradient_var(node)) += retval_grad)) - end - - if node in fwd_marked && node in back_marked - #NOTE: unecessary, because we accumulated in-place already - #push!(stmts, :($(QuoteNode(increment_param_grad!))(trace.$static_ir_gen_fn_ref, - #$(QuoteNode(node.name)), - #$(gradient_var(node)), - #scale_factor))) end end @@ -199,8 +188,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: if node === ir.return_node && node in fwd_marked @assert node in back_marked push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad, scale_factor))) - #push!(stmts, :($(gradient_var(node)) += retval_grad)) + $(gradient_var(node)), retval_grad))) end end @@ -209,8 +197,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: if node === ir.return_node && node in fwd_marked @assert node in back_marked push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad, scale_factor))) - #push!(stmts, :($(gradient_var(node)) += retval_grad)) + $(gradient_var(node)), retval_grad))) end if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs) @@ -222,9 +209,13 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked arg_maybe_tracked = maybe_tracked_arg_var(node, i) - push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor))) - #push!(stmts, :($(gradient_var(input_node)) += $(QuoteNode(deriv))($arg_maybe_tracked))) + if isa(input_node, TrainableParameterNode) + 
push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor))) + else + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked)))) + end end end end @@ -246,9 +237,15 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke if !has_argument_grads(node.dist)[i] error("Distribution $(node.dist) does not have logpdf gradient for argument $i") end - push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(input_node)), $logpdf_grad[$(QuoteNode(i+1))], scale_factor))) - #push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) + input_node_grad = gradient_var(input_node) + increment = :($logpdf_grad[$(QuoteNode(i+1))]) + if isa(input_node, TrainableParameterNode) + push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment, scale_factor))) + else + push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment))) + end end end @@ -257,7 +254,6 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke @assert node in back_marked push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( $(gradient_var(node)), retval_grad, scale_factor))) - #push!(stmts, :($(gradient_var(node)) += retval_grad)) end end @@ -274,8 +270,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, error("Distribution $dist does not logpdf gradient for its output value") end push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad, scale_factor))) - #push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) + $(gradient_var(node)), retval_grad))) end end @@ -292,8 +287,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, if node === ir.return_node && node in fwd_marked @assert node in back_marked push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad, scale_factor))) - #push!(stmts, :($(gradient_var(node)) += retval_grad)) + $(gradient_var(node)), retval_grad))) end if node in fwd_marked @@ -316,9 +310,10 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked @assert input_node in back_marked # this ensured its gradient will have been initialized + input_node_grad = gradient_var(input_node) + increment = :($input_grads[$(QuoteNode(i))]) push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor))) - #push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + $input_node_grad, $increment))) end end @@ -332,8 +327,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, if node === ir.return_node && node in fwd_marked @assert node in back_marked push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad, scale_factor))) - #push!(stmts, :($(gradient_var(node)) += retval_grad)) + $(gradient_var(node)), retval_grad))) end if node in fwd_marked @@ -347,9 +341,15 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, for (i, (input_node, has_grad)) in enumerate(zip(node.inputs, has_argument_grads(node.generative_function))) 
if input_node in fwd_marked && has_grad @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(input_node)), $input_grads[$(QuoteNode(i))], scale_factor))) - #push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + input_node_grad = gradient_var(input_node) + increment = :($input_grads[$(QuoteNode(i))]) + if isa(input_node, TrainableParameterNode) + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment, scale_factor))) + else + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment))) + end end end end @@ -425,8 +425,6 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, return quote choice_gradients(trace, StaticSelection(selection), retval_grad) end end - push!(stmts, :(scale_factor = NaN)) - ir = get_ir(gen_fn_type) selected_choices = get_selected_choices(schema, ir) selected_calls = get_selected_calls(schema, ir) @@ -434,7 +432,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # forward marking pass fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropTraceMode()) end # backward marking pass @@ -489,13 +487,13 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, selected_calls = Set{GenerativeFunctionCallNode}( node for node in ir.nodes if isa(node, GenerativeFunctionCallNode)) - # forward marking pass + # forward marking pass (propagate forward from 'sources') fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropParamsMode()) end - # backward marking pass + # backward marking pass (propagate backwards from 'sinks') back_marked = Set{StaticIRNode}() push!(back_marked, ir.return_node) for node in reverse(ir.nodes) @@ -508,12 +506,15 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(get_args))(trace))) - # forward code-generation pass (initialize gradients to zero, create needed references) + # forward code-generation pass + # any node that is backward-marked creates a variable for its current value + # any node that is forward-marked and backwards marked initializes a gradient variable for node in ir.nodes fwd_codegen!(stmts, fwd_marked, back_marked, node) end - # backward code-generation pass (increment gradients) + # backward code-generation pass + # any node that is forward-marked and backwards marked increments its gradient variable for node in reverse(ir.nodes) back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropParamsMode()) end diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index 3557f19b1..d57f59788 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -6,7 +6,7 @@ end function process!(::StaticIRGenerateState, node, options) end function process!(state::StaticIRGenerateState, node::TrainableParameterNode, options) - push!(state.stmts, :($(node.name) = $(QuoteNode(get_param))(gen_fn, $(QuoteNode(node.name))))) + push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, 
$(QuoteNode(node.name))))) end function process!(state::StaticIRGenerateState, node::ArgumentNode, options) @@ -84,6 +84,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, push!(stmts, :($total_noise_fieldname = 0.)) push!(stmts, :($weight = 0.)) push!(stmts, :($num_nonempty_fieldname = 0)) + push!(stmts, :($parameter_store_fieldname = $(QuoteNode(get_julia_store))(parameter_context))) # unpack arguments arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] @@ -109,8 +110,10 @@ function codegen_generate(gen_fn_type::Type{T}, args, end push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :generate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), - args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap))) +@generated function $(GlobalRef(Gen, :generate))( + gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), + args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap)); + parameter_context=default_parameter_context) $(QuoteNode(codegen_generate))(gen_fn, args, constraints) end end) diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index 48c215a3c..ab2b769dc 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -5,7 +5,7 @@ end function process!(::StaticIRSimulateState, node, options) end function process!(state::StaticIRSimulateState, node::TrainableParameterNode, options) - push!(state.stmts, :($(node.name) = $(QuoteNode(get_param))(gen_fn, $(QuoteNode(node.name))))) + push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) end function process!(state::StaticIRSimulateState, node::ArgumentNode, options) @@ -47,7 +47,7 @@ function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode push!(state.stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end -function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenerativeFunction} +function codegen_simulate(gen_fn_type::Type{T}, args, parameter_context_type) where {T <: StaticIRGenerativeFunction} ir = get_ir(gen_fn_type) options = get_options(gen_fn_type) @@ -57,6 +57,7 @@ function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenera push!(stmts, :($total_score_fieldname = 0.)) push!(stmts, :($total_noise_fieldname = 0.)) push!(stmts, :($num_nonempty_fieldname = 0)) + push!(stmts, :($parameter_store_fieldname = $(QuoteNode(get_julia_store))(parameter_context))) # unpack arguments arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] @@ -83,7 +84,9 @@ function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenera end push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :simulate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple) - $(QuoteNode(codegen_simulate))(gen_fn, args) +@generated function $(GlobalRef(Gen, :simulate))( + gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple; + parameter_context=default_parameter_context) + $(QuoteNode(codegen_simulate))(gen_fn, args, parameter_context) end end) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index de2c84b30..9f520240a 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -37,6 +37,34 @@ static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() abstract type StaticIRTrace <: Trace end +# fixed fields shared by all StaticIRTraces +const num_nonempty_fieldname = :num_nonempty +const total_score_fieldname = :score +const total_noise_fieldname = :noise 
+const return_value_fieldname = :retval +const parameter_store_fieldname = :parameter_store + + +# other fields based on user-defined variable names are prefixed to avoid collisions +get_value_fieldname(node::ArgumentNode) = Symbol("#arg#_$(node.name)") +get_value_fieldname(node::RandomChoiceNode) = Symbol("#choice_value#_$(node.addr)") +get_value_fieldname(node::JuliaNode) = Symbol("#julia#_$(node.name)") +get_score_fieldname(node::RandomChoiceNode) = Symbol("#choice_score#_$(node.addr)") +get_subtrace_fieldname(node::GenerativeFunctionCallNode) = Symbol("#subtrace#_$(node.addr)") + + +# getters + +function get_parameter_value(trace::StaticIRTrace, name) + parameter_id = (get_gen_fn(trace), name) + return get_parameter_value(trace.parameter_store, parameter_id) +end + +function get_gradient_accumulator(trace::StaticIRTrace, name) + parameter_id = (get_gen_fn(trace), name) + return get_gradient_accumulator(trace.parameter_store, parameter_id) +end + @inline function static_get_subtrace(trace::StaticIRTrace, addr) error("Not implemented") end @@ -52,36 +80,8 @@ end return Gen.static_get_subtrace(trace, Val(first))[rest] end -const arg_prefix = gensym("arg") -const choice_value_prefix = gensym("choice_value") -const choice_score_prefix = gensym("choice_score") -const subtrace_prefix = gensym("subtrace") -const julia_prefix = gensym("julia_prefix") -function get_value_fieldname(node::ArgumentNode) - Symbol("$(arg_prefix)_$(node.name)") -end - -function get_value_fieldname(node::RandomChoiceNode) - Symbol("$(choice_value_prefix)_$(node.addr)") -end - -function get_value_fieldname(node::JuliaNode) - Symbol("$(julia_prefix)_$(node.name)") -end - -function get_score_fieldname(node::RandomChoiceNode) - Symbol("$(choice_score_prefix)_$(node.addr)") -end - -function get_subtrace_fieldname(node::GenerativeFunctionCallNode) - Symbol("$(subtrace_prefix)_$(node.addr)") -end -const num_nonempty_fieldname = gensym("num_nonempty") -const total_score_fieldname = gensym("score") -const total_noise_fieldname = gensym("noise") -const return_value_fieldname = gensym("retval") struct TraceField fieldname::Symbol @@ -115,6 +115,7 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio push!(fields, TraceField(total_noise_fieldname, QuoteNode(Float64))) push!(fields, TraceField(num_nonempty_fieldname, QuoteNode(Int))) push!(fields, TraceField(return_value_fieldname, ir.return_node.typ)) + push!(fields, parameter_store_fieldname, QuoteNode(JuliaParameterStore)()) return fields end @@ -125,7 +126,8 @@ function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options: fields = get_trace_fields(ir, options) field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields) Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)), - Expr(:block, field_exprs..., Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any)))) + Expr(:block, field_exprs..., + Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any)))) end function generate_isempty(trace_struct_name::Symbol) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index 3a491c7a8..c0fe486eb 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -144,7 +144,7 @@ end function process_codegen!(stmts, ::ForwardPassState, back::BackwardPassState, node::TrainableParameterNode, ::AbstractUpdateMode, options) if node in back.marked - push!(stmts, :($(node.name) = $(QuoteNode(get_param))($(QuoteNode(get_gen_fn))(trace), $(QuoteNode(node.name))))) + push!(stmts, :($(node.name) = 
$(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) end end From ad05f34b164c93133a50b4fd6bc45b054db1fdbe Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Sun, 16 May 2021 20:00:43 -0400 Subject: [PATCH 07/24] checkpoint on getting tests to pass again --- src/Gen.jl | 3 - src/builtin_optimization.jl | 235 -------------------------- src/dynamic/backprop.jl | 28 ++-- src/dynamic/dynamic.jl | 22 ++- src/dynamic/generate.jl | 32 ++-- src/dynamic/regenerate.jl | 34 ++-- src/dynamic/simulate.jl | 30 ++-- src/dynamic/trace.jl | 3 +- src/dynamic/update.jl | 39 +++-- src/gen_fn_interface.jl | 2 +- src/inference/train.jl | 10 +- src/inference/variational.jl | 48 +++--- src/optimization.jl | 307 +++++++++++++++++++++++++--------- src/static_ir/backprop.jl | 129 +++++++------- src/static_ir/generate.jl | 8 +- src/static_ir/simulate.jl | 8 +- src/static_ir/static_ir.jl | 11 +- src/static_ir/trace.jl | 15 +- src/static_ir/update.jl | 15 +- test/dsl/dynamic_dsl.jl | 44 +++-- test/inference/train.jl | 44 ++--- test/inference/variational.jl | 78 ++++----- test/optional_args.jl | 2 +- test/runtests.jl | 14 +- test/static_ir/static_ir.jl | 17 +- 25 files changed, 569 insertions(+), 609 deletions(-) diff --git a/src/Gen.jl b/src/Gen.jl index a3a40e38b..13a540400 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -64,9 +64,6 @@ include("dynamic/dynamic.jl") # static IR generative function include("static_ir/static_ir.jl") -# optimization for built-in generative functions (dynamic and static IR) -include("builtin_optimization.jl") - # DSLs for defining dynamic embedded and static IR generative functions # 'Dynamic DSL' and 'Static DSL' include("dsl/dsl.jl") diff --git a/src/builtin_optimization.jl b/src/builtin_optimization.jl index d4f5ade80..e69de29bb 100644 --- a/src/builtin_optimization.jl +++ b/src/builtin_optimization.jl @@ -1,235 +0,0 @@ -############################# - -# primitives for in-place gradient accumulation - -function in_place_add!(param::Array, increment::Array, scale_factor::Real) - # NOTE: it ignores the scale_factor, because it is not a parameter... - # scale factors only affect parameters - # TODO this is potentially very confusing! 
- @simd for i in 1:length(param) - param[i] += increment[i] - end - return param -end - -function in_place_add!(param::Array, increment::Array) - @inbounds @simd for i in 1:length(param) - param[i] += increment[i] - end - return param -end - -function in_place_add!(param::Real, increment::Real, scale_factor::Real) - return param + increment -end - -function in_place_add!(param::Real, increment::Real) - return param + increment -end - -mutable struct ThreadsafeAccumulator{T} - value::T - lock::ReentrantLock -end - -ThreadsafeAccumulator(value) = ThreadsafeAccumulator(value, ReentrantLock()) - -# TODO not threadsafe -function get_current_value(accum::ThreadsafeAccumulator) - return accum.value -end - -function in_place_add!(param::ThreadsafeAccumulator{Real}, increment::Real, scale_factor::Real) - lock(param.lock) - try - param.value = param.value + increment * scale_factor - finally - unlock(param.lock) - end - return param -end - -function in_place_add!(param::ThreadsafeAccumulator{Real}, increment::Real) - lock(param.lock) - try - param.value = param.value + increment - finally - unlock(param.lock) - end - return param -end - -function in_place_add!(param::ThreadsafeAccumulator{<:Array}, increment, scale_factor::Real) - lock(param.lock) - try - @simd for i in 1:length(param.value) - param.value[i] += increment[i] * scale_factor - end - finally - unlock(param.lock) - end - return param -end - -function in_place_add!(param::ThreadsafeAccumulator{<:Array}, increment) - lock(param.lock) - try - @simd for i in 1:length(param.value) - param.value[i] += increment[i] - end - finally - unlock(param.lock) - end - return param -end - -############################# - - - -""" - set_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - -Set the value of a trainable parameter of the generative function. - -NOTE: Does not update the gradient accumulator value. -""" -function set_param!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - return gf.params[name] = value -end - -""" - value = get_param(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - -Get the current value of a trainable parameter of the generative function. -""" -function get_param(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - return gf.params[name] -end - -""" - value = get_param_grad(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - -Get the current value of the gradient accumulator for a trainable parameter of the generative function. - -Not threadsafe. -""" -function get_param_grad(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - try - val = gf.params_grad[name] # the accumulator - return get_current_value(val) - catch KeyError - error("parameter $name not found") - end - return val -end - -""" - zero_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - -Reset the gradient accumlator for a trainable parameter of the generative function to all zeros. - -Not threadsafe. -""" -function zero_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - gf.params_grad[name] = ThreadsafeAccumulator(zero(gf.params[name])) # TODO avoid allocation? - return gf.params_grad[name] -end - -""" - set_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value) - -Set the gradient accumlator for a trainable parameter of the generative function. - -Not threadsafe. 
-""" -function set_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value) - gf.params_grad[name] = ThreadsafeAccumulator(grad_value) - return grad_value -end - -# TODO document me; it is threadsafe.. -function increment_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, increment, scale_factor) - in_place_add!(gf.params_grad[name], increment, scale_factor) -end - -# TODO document me; it is threadsafe.. -function increment_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, increment) - in_place_add!(gf.params_grad[name], increment) -end - - - -""" - init_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - -Initialize the the value of a named trainable parameter of a generative function. - -Also generates the gradient accumulator for that parameter to `zero(value)`. - -Example: -```julia -init_param!(foo, :theta, 0.6) -``` -""" -function init_param!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - set_param!(gf, name, value) - zero_param_grad!(gf, name) -end - -get_params(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}) = keys(gf.params) - -export set_param!, get_param, get_param_grad, zero_param_grad!, set_param_grad!, init_param!, increment_param_grad! - -######################################### -# gradient descent with fixed step size # -######################################### - -mutable struct FixedStepGradientDescentBuiltinDSLState - step_size::Float64 - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction} - param_list::Vector -end - -function init_update_state(conf::FixedStepGradientDescent, - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, param_list::Vector) - FixedStepGradientDescentBuiltinDSLState(conf.step_size, gen_fn, param_list) -end - -function apply_update!(state::FixedStepGradientDescentBuiltinDSLState) - for param_name in state.param_list - value = get_param(state.gen_fn, param_name) - grad = get_param_grad(state.gen_fn, param_name) - set_param!(state.gen_fn, param_name, value + grad * state.step_size) - zero_param_grad!(state.gen_fn, param_name) - end -end - -#################### -# gradient descent # -#################### - -mutable struct GradientDescentBuiltinDSLState - step_size_init::Float64 - step_size_beta::Float64 - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction} - param_list::Vector - t::Int -end - -function init_update_state(conf::GradientDescent, - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, param_list::Vector) - GradientDescentBuiltinDSLState(conf.step_size_init, conf.step_size_beta, - gen_fn, param_list, 1) -end - -function apply_update!(state::GradientDescentBuiltinDSLState) - step_size = state.step_size_init * (state.step_size_beta + 1) / (state.step_size_beta + state.t) - for param_name in state.param_list - value = get_param(state.gen_fn, param_name) - grad = get_param_grad(state.gen_fn, param_name) - set_param!(state.gen_fn, param_name, value + grad * step_size) - zero_param_grad!(state.gen_fn, param_name) - end - state.t += 1 -end diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index c38a4ac27..361f29c14 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -42,10 +42,10 @@ mutable struct GFBackpropParamsState visitor::AddressVisitor scale_factor::Float64 active_gen_fn::GenerativeFunction - tracked_params::Dict{ParameterID,Any} + 
tracked_params::Dict{Tuple{GenerativeFunction,Symbol},Any} function GFBackpropParamsState(trace::DynamicDSLTrace, tape, scale_factor) - tracked_params = Dict{ParameterID,Any}() + tracked_params = Dict{Tuple{GenerativeFunction,Symbol},Any}() store = get_parameter_store(trace) gen_fn = get_gen_fn(trace) for (name, value) in get_local_parameters(store, gen_fn) @@ -178,9 +178,10 @@ function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_ reverse_pass!(tape) # increment the gradient accumulators for trainable parameters in scope + store = get_parameter_store(trace) for ((active_gen_fn, name), tracked) in state.tracked_params - parameter_id = (active_gen_fn, parameter_id) - increment_gradient!(store, parameter_id, deriv(tracked), state.scale_factor) + parameter_id = (active_gen_fn, name) + increment_gradient!(parameter_id, deriv(tracked), state.scale_factor, store) end # return gradients with respect to arguments with gradients, or nothing @@ -217,6 +218,16 @@ function GFBackpropTraceState(trace, selection, tape) get_gen_fn(trace)) end +get_parameter_store(state::GFBackpropTraceState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFBackpropTraceState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFBackpropTraceState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFBackpropTraceState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn +end + function fill_submaps!( map::DynamicChoiceMap, tracked_trie::Trie{Any,Union{TrackedReal,TrackedArray}}, @@ -304,15 +315,6 @@ function traceat(state::GFBackpropTraceState, gen_fn::GenerativeFunction{T,U}, retval_maybe_tracked end -function splice(state::GFBackpropTraceState, gen_fn::DynamicDSLFunction, - args_maybe_tracked::Tuple) - prev_gen_fn = state.active_gen_fn - state.active_gen_fn = gen_fn - retval = exec(gen_fn, state, args) - state.active_gen_fn = prev_gen_fn - return retval -end - @noinline function ReverseDiff.special_reverse_exec!( instruction::ReverseDiff.SpecialInstruction{BackpropTraceRecord}) record = instruction.func diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index 7f9e286de..dda40f214 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -29,14 +29,22 @@ function DynamicDSLFunction(arg_types::Vector{Type}, has_argument_grads, accepts_output_grad) end -function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction} +function Base.show(io::IO, gen_fn::DynamicDSLFunction) + return "Gen DML generative function: $(gen_fn.julia_function)" +end + +function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction) + return "Gen DML generative function: $(gen_fn.julia_function)" +end + +function DynamicDSLTrace(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T<:DynamicDSLFunction} # pad args with default values, if available if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults) defaults = gen_fn.arg_defaults[length(args)+1:end] defaults = map(x -> something(x), defaults) args = Tuple(vcat(collect(args), defaults)) end - return DynamicDSLTrace{T}(gen_fn, args) + return DynamicDSLTrace{T}(gen_fn, args, parameter_store) end accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad @@ -58,6 +66,14 @@ function exec(gen_fn::DynamicDSLFunction, state, args::Tuple) gen_fn.julia_function(state, args...) 
end +function splice(state, gen_fn::DynamicDSLFunction, args::Tuple) + prev_gen_fn = get_active_gen_fn(state) + state.active_gen_fn = gen_fn + retval = exec(gen_fn, state, args) + set_active_gen_fn!(state, prev_gen_fn) + return retval +end + # whether there is a gradient of score with respect to each argument # it returns 'nothing' for those arguemnts that don't have a derivatice has_argument_grads(gen::DynamicDSLFunction) = gen.has_argument_grads @@ -104,7 +120,7 @@ end function read_param(state, name::Symbol) parameter_id = get_parameter_id(state, name) store = get_parameter_store(state) - return get_parameter_value(store, parameter_id) + return get_parameter_value(parameter_id, store) end ################## diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index 462bd385f..beed86c20 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -4,14 +4,26 @@ mutable struct GFGenerateState weight::Float64 visitor::AddressVisitor active_gen_fn::DynamicDSLFunction # mutated by splicing - parameter_context::Dict{Symbol,Any} + parameter_context::Dict + + function GFGenerateState(gen_fn, args, constraints, parameter_context) + parameter_store = get_julia_store(parameter_context) + trace = DynamicDSLTrace(gen_fn, args, parameter_store) + return new(trace, constraints, 0., AddressVisitor(), gen_fn, parameter_context) + end end -function GFGenerateState(gen_fn, args, constraints, parameter_context) - trace = DynamicDSLTrace(gen_fn, args, parameter_context) - GFGenerateState(trace, constraints, 0., AddressVisitor(), parameter_context) +get_parameter_store(state::GFGenerateState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFGenerateState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFGenerateState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFGenerateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end + function traceat(state::GFGenerateState, dist::Distribution{T}, args, key) where {T} local retval::T @@ -57,7 +69,8 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, # get subtrace (subtrace, weight) = generate( - gen_fn, args, constraints; parameter_context=parameter_context) + gen_fn, args, constraints; + parameter_context=state.parameter_context) # add to the trace add_call!(state.trace, key, subtrace) @@ -71,15 +84,6 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_gen_fn = state.active_gen_fn - state.active_gen_fn = gen_fn - retval = exec(gen_fn, state, args) - state.active_gen_fn = prev_gen_fn - return retval -end - function generate( gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap; parameter_context=default_parameter_context) diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 13d14d86f..371240f44 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -4,16 +4,27 @@ mutable struct GFRegenerateState selection::Selection weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + + function GFRegenerateState(gen_fn, args, prev_trace, selection) + visitor = AddressVisitor() + trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace)) + return new(prev_trace, trace, selection, + 0., visitor, gen_fn) + end end -function GFRegenerateState(gen_fn, args, prev_trace, - selection, params) - visitor = 
AddressVisitor() - GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection, - 0., visitor, params) +get_parameter_store(state::GFRegenerateState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFRegenerateState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFRegenerateState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFRegenerateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end + function traceat(state::GFRegenerateState, dist::Distribution{T}, args, key) where {T} local prev_retval::T @@ -92,15 +103,6 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFRegenerateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval -end - function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, visited::EmptySelection) noise = 0. @@ -133,7 +135,7 @@ end function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple, selection::Selection) gen_fn = trace.gen_fn - state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params) + state = GFRegenerateState(gen_fn, args, trace, selection) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) visited = state.visitor.visited diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 75aa07754..cbee0a165 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -2,17 +2,26 @@ mutable struct GFSimulateState trace::DynamicDSLTrace visitor::AddressVisitor active_gen_fn::DynamicDSLFunction # mutated by splicing - parameter_context::Dict{Symbol,Any} -end - -function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, parameter_context) - trace = DynamicDSLTrace(gen_fn, args, parameter_context) - GFSimulateState(trace, AddressVisitor(), parameter_context) + parameter_context::Dict + + function GFSimulateState( + gen_fn::GenerativeFunction, args::Tuple, parameter_context) + parameter_store = get_julia_store(parameter_context) + trace = DynamicDSLTrace(gen_fn, args, parameter_store) + return new(trace, AddressVisitor(), gen_fn, parameter_context) + end end get_parameter_store(state::GFSimulateState) = get_parameter_store(state.trace) + get_parameter_id(state::GFSimulateState, name::Symbol) = (state.active_gen_fn, name) +get_active_gen_fn(state::GFSimulateState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFSimulateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn +end + function traceat(state::GFSimulateState, dist::Distribution{T}, args, key) where {T} local retval::T @@ -51,15 +60,6 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_gen_fn = state.active_gen_fn - state.active_gen_fn = gen_fn - retval = exec(gen_fn, state, args) - state.active_gen_fn = prev_gen_fn - return retval -end - function simulate( gen_fn::DynamicDSLFunction, args::Tuple; parameter_context=default_parameter_context) diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 6ef6e35cf..4673ae1e8 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -39,9 +39,8 @@ mutable struct DynamicDSLTrace{T} <: Trace args::Tuple parameter_store::JuliaParameterStore retval::Any - function DynamicDSLTrace{T}(gen_fn::T, args, parameter_context) where {T} + function 
DynamicDSLTrace{T}(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T} trie = Trie{Any,ChoiceOrCallRecord}() - parameter_store = get_julia_store(parameter_context) # retval is not known yet new(gen_fn, trie, true, 0, 0, args, parameter_store) end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index bf4a42de2..7b49cd1c3 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -4,18 +4,29 @@ mutable struct GFUpdateState constraints::Any weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} discard::DynamicChoiceMap + active_gen_fn::DynamicDSLFunction # mutated by splicing + + function GFUpdateState(gen_fn, args, prev_trace, constraints) + visitor = AddressVisitor() + discard = choicemap() + trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace)) + return new(prev_trace, trace, constraints, + 0., visitor, discard, gen_fn) + end end -function GFUpdateState(gen_fn, args, prev_trace, constraints, params) - visitor = AddressVisitor() - discard = choicemap() - trace = DynamicDSLTrace(gen_fn, args) - GFUpdateState(prev_trace, trace, constraints, - 0., visitor, params, discard) +get_parameter_store(state::GFUpdateState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFUpdateState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFUpdateState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFUpdateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end + function traceat(state::GFUpdateState, dist::Distribution{T}, args::Tuple, key) where {T} @@ -110,13 +121,13 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFUpdateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_params = state.params - state.params = gen_fn.params +function splice( + state::GFUpdateState, gen_fn::DynamicDSLFunction, args::Tuple) + prev_gen_fn = state.active_gen_fn + state.active_gen_fn = gen_fn retval = exec(gen_fn, state, args) - state.params = prev_params - retval + state.active_gen_fn = prev_gen_fn + return retval end function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, @@ -187,7 +198,7 @@ end function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple, constraints::ChoiceMap) gen_fn = trace.gen_fn - state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params) + state = GFUpdateState(gen_fn, arg_values, trace, constraints) retval = exec(gen_fn, state, arg_values) set_retval!(state.trace, retval) visited = get_visited(state.visitor) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 439068634..6037f2521 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -246,7 +246,7 @@ t)\$, and return \$t\$ \\log \\frac{p(r, t; x)}{q(r; x, t)} ``` """ -function propose(gen_fn::GenerativeFunction, args::Tuple) +function propose(gen_fn::GenerativeFunction, args::Tuple; parameter_context=Dict()) trace = simulate(gen_fn, args) weight = get_score(trace) (get_choices(trace), weight, get_retval(trace)) diff --git a/src/inference/train.jl b/src/inference/train.jl index 9e64869cd..52735bc89 100644 --- a/src/inference/train.jl +++ b/src/inference/train.jl @@ -1,6 +1,6 @@ """ train!(gen_fn::GenerativeFunction, data_generator::Function, - update::ParamUpdate, + optimizer::CompositeOptimizer, num_epoch, epoch_size, num_minibatch, minibatch_size; verbose::Bool=false) Train the given generative function to maximize the expected conditional log @@ -22,7 +22,7 @@ taken under the 
marginal distribution on `inputs` determined by the data generator. """ function train!(gen_fn::GenerativeFunction, data_generator::Function, - update::ParamUpdate; + optimizer::CompositeOptimizer; num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1, evaluation_size=epoch_size, verbose=false, callback=(epoch, minibatch, minibatch_objective) -> nothing) @@ -55,7 +55,7 @@ function train!(gen_fn::GenerativeFunction, data_generator::Function, minibatch_objective += weight accumulate_param_gradients!(trace) end - apply!(update) + apply_update!(optimizer) minibatch_objective /= minibatch_size callback(epoch, minibatch, minibatch_objective) end @@ -87,7 +87,7 @@ end p::GenerativeFunction, p_args::Tuple, q::GenerativeFunction, get_q_args::Function) -Simulate a trace of p representing a training example, and use to update the gradients of the trainable parameters of q. +Simulate a trace of p representing a training example, and use to optimizer the gradients of the trainable parameters of q. Used for training q via maximum expected conditional likelihood. Random choices will be mapped from p to q based on their address. @@ -109,7 +109,7 @@ end p::GenerativeFunction, p_args::Tuple, q::GenerativeFunction, get_q_args::Function) -Simulate a batch of traces of p representing training samples, and use them to update the gradients of the trainable parameters of q. +Simulate a batch of traces of p representing training samples, and use them to optimizer the gradients of the trainable parameters of q. Like `lecture!` but q is batched, and must make random choices for training sample i under hierarchical address namespace i::Int (e.g. i => :z). get_q_args maps a vector of traces of p to an argument tuple of q. diff --git a/src/inference/variational.jl b/src/inference/variational.jl index 88736bf8b..610140ed8 100644 --- a/src/inference/variational.jl +++ b/src/inference/variational.jl @@ -88,25 +88,25 @@ function multi_sample_gradient_estimate!( (L, traces, weights_normalized) end -function _maybe_accumulate_param_grad!(trace, update::ParamUpdate, scale_factor::Real) +function _maybe_accumulate_param_grad!(trace, optimizer::CompositeOptimizer, scale_factor::Real) return accumulate_param_gradients!(trace, nothing, scale_factor) end -function _maybe_accumulate_param_grad!(trace, update::Nothing, scale_factor::Real) +function _maybe_accumulate_param_grad!(trace, optimizer::Nothing, scale_factor::Real) end """ (elbo_estimate, traces, elbo_history) = black_box_vi!( model::GenerativeFunction, model_args::Tuple, - [model_update::ParamUpdate,] + [model_optimizer::CompositeOptimizer,] observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate; + var_model_optimizer::CompositeOptimizer; options...) Fit the parameters of a variational model (`var_model`) to the posterior distribution implied by the given `model` and `observations` using stochastic -gradient methods. Users may optionally specify a `model_update` to jointly +gradient methods. Users may optionally specify a `model_optimizer` to jointly update the parameters of `model`. # Additional arguments: @@ -120,10 +120,10 @@ update the parameters of `model`. 
""" function black_box_vi!( model::GenerativeFunction, model_args::Tuple, - model_update::Union{ParamUpdate,Nothing}, + model_optimizer::Union{CompositeOptimizer,Nothing}, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate; + var_model_optimizer::CompositeOptimizer; iters=1000, samples_per_iter=100, verbose=false, callback=(iter, traces, elbo_estimate) -> nothing) @@ -144,7 +144,7 @@ function black_box_vi!( elbo_estimate += (log_weight / samples_per_iter) # accumulate the generative model gradients - _maybe_accumulate_param_grad!(model_trace, model_update, 1.0 / samples_per_iter) + _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / samples_per_iter) # record the traces var_traces[sample] = var_trace @@ -159,11 +159,11 @@ function black_box_vi!( callback(iter, var_traces, elbo_estimate) # update parameters of variational family - apply!(var_model_update) + apply_update!(var_model_optimizer) # update parameters of generative model - if !isnothing(model_update) - apply!(model_update) + if !isnothing(model_optimizer) + apply_update!(model_optimizer) end end @@ -173,24 +173,24 @@ end black_box_vi!(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate; options...) = + var_model_optimizer::CompositeOptimizer; options...) = black_box_vi!(model, model_args, nothing, observations, - var_model, var_model_args, var_model_update; options...) + var_model, var_model_args, var_model_optimizer; options...) """ (iwelbo_estimate, traces, iwelbo_history) = black_box_vimco!( model::GenerativeFunction, model_args::Tuple, - [model_update::ParamUpdate,] + [model_optimizer::CompositeOptimizer,] observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate, + var_model_optimizer::CompositeOptimizer, grad_est_samples::Int; options...) Fit the parameters of a variational model (`var_model`) to the posterior distribution implied by the given `model` and `observations` using stochastic gradient methods applied to the [Variational Inference with Monte Carlo Objectives](https://arxiv.org/abs/1602.06725) (VIMCO) lower bound on the -marginal likelihood. Users may optionally specify a `model_update` to jointly +marginal likelihood. Users may optionally specify a `model_optimizer` to jointly update the parameters of `model`. # Additional arguments: @@ -208,9 +208,9 @@ update the parameters of `model`. 
""" function black_box_vimco!( model::GenerativeFunction, model_args::Tuple, - model_update::Union{ParamUpdate,Nothing}, observations::ChoiceMap, + model_optimizer::Union{CompositeOptimizer,Nothing}, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate, grad_est_samples::Int; + var_model_optimizer::CompositeOptimizer, grad_est_samples::Int; iters=1000, samples_per_iter=100, geometric=true, verbose=false, callback=(iter, traces, elbo_estimate) -> nothing) @@ -238,7 +238,7 @@ function black_box_vimco!( for (var_trace, weight) in zip(original_var_traces, weights) constraints = merge(observations, get_choices(var_trace)) (model_trace, _) = generate(model, model_args, constraints) - _maybe_accumulate_param_grad!(model_trace, model_update, weight / samples_per_iter) + _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter) end end iwelbo_history[iter] = iwelbo_estimate @@ -250,11 +250,11 @@ function black_box_vimco!( callback(iter, resampled_var_traces, iwelbo_estimate) # update parameters of variational family - apply!(var_model_update) + apply_update!(var_model_optimizer) # update parameters of generative model - if !isnothing(model_update) - apply!(model_update) + if !isnothing(model_optimizer) + apply_update!(model_optimizer) end end @@ -265,10 +265,10 @@ end black_box_vimco!(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate, + var_model_optimizer::CompositeOptimizer, grad_est_samples::Int; options...) = black_box_vimco!(model, model_args, nothing, observations, - var_model, var_model_args, var_model_update, + var_model, var_model_args, var_model_optimizer, grad_est_samples; options...) export black_box_vi!, black_box_vimco! diff --git a/src/optimization.jl b/src/optimization.jl index 6a0cd67b2..35a49f82b 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -7,30 +7,26 @@ import Parameters # the misnomer names # # combinators (map etc.) and call_at! and choice_at! all need to implement get_parameters.. -# -# make changes to src/static_ir/backprop.jl -# -# make changes to src/dynamic/dynamic.jl (use the JuliaParameterStore) +# TODO add tests specifically for JuliaParameterStore etc. # # TODO GF untraced needs to reference a parameter store # # make changes to src/dynamic/backprop.jl +# make changes to other dynamic methods export in_place_add! export FixedStepGradientDescent export DecayStepGradientDescent -export make_optimizer +export init_optimizer export apply_update! -export ParameterStore export JuliaParameterStore -export JuliaParameterID - -export initialize_parameter! +export init_parameter! export increment_gradient! export reset_gradient! export get_parameter_value +export get_gradient ################# # in_place_add! 
# @@ -80,7 +76,7 @@ end # thread-safe accumulator # ########################### -struct Accumulator{T<:Union{Real,Array}} +mutable struct Accumulator{T<:Union{Real,Array}} value::T lock::ReentrantLock end @@ -110,7 +106,7 @@ function fill_with_zeros!(accum::Accumulator{Array{T}}) where {T} return accum end -function in_place_add!(accum::ThreadsafeAccumulator{Real}, increment::Real, scale_factor::Real) +function in_place_add!(accum::Accumulator{<:Real}, increment::Real, scale_factor::Real) lock(accum.lock) try accum.value = accum.value + increment * scale_factor @@ -120,7 +116,7 @@ function in_place_add!(accum::ThreadsafeAccumulator{Real}, increment::Real, scal return accum end -function in_place_add!(accum::ThreadsafeAccumulator{Real}, increment::Real) +function in_place_add!(accum::Accumulator{<:Real}, increment::Real) lock(accum.lock) try accum.value = accum.value + increment @@ -130,7 +126,7 @@ function in_place_add!(accum::ThreadsafeAccumulator{Real}, increment::Real) return accum end -function in_place_add!(accum::ThreadsafeAccumulator{<:Array}, increment, scale_factor::Real) +function in_place_add!(accum::Accumulator{<:Array}, increment, scale_factor::Real) lock(accum.lock) try @simd for i in 1:length(accum.value) @@ -142,7 +138,7 @@ function in_place_add!(accum::ThreadsafeAccumulator{<:Array}, increment, scale_f return accum end -function in_place_add!(accum::ThreadsafeAccumulator{<:Array}, increment) +function in_place_add!(accum::Accumulator{<:Array}, increment) lock(accum.lock) try @simd for i in 1:length(accum.value) @@ -157,38 +153,79 @@ end -################################# -# ParameterStore and optimizers # -################################# +################################### +# parameter stores and optimizers # +################################### + +# TODO create diagram and document the overal framework +# including parameter contexts and parameter stores,and the default beahviors abstract type ParameterStore end -# TODO docstring, returns an optimizer that has an apply_update! method -function make_optimizer(conf, store::ParameterStore, parameter_ids) end +""" + optimizer = init_optimizer( + conf, parameter_ids, + store=default_julia_parameter_store) + +Initialize an iterative gradient-based optimizer. + +The first argument defines the mathematical behavior of the update, the second argument defines the set of parameters to which the update should be applied at each iteration, and the third argument gives the location of the parameter values and their gradient accumulators. + +See [`apply_update!`](@ref). + +Not thread-safe. +""" +function init_optimizer(conf, parameter_ids, store=default_julia_parameter_store) + error("Not implemented") +end + +""" + apply_update!(optimizer) + +Apply one iteration of a gradient-based optimization update. + +See [`init_optimizer!`](@ref). + +Not thread-safe. +""" +function apply_update!(optimizer) + error("Not implemented") +end + +""" + + optimizer = CompositeOptimizer(conf, parameter_stores_to_ids::Dict{Any,Vector}) -# TODO docstring -function apply_update!(optimizer) end +Construct an optimizer that applies the given update to parameters in multiple parameter stores. +The first argument defines the mathematical behavior of the update; +the second argument defines the set of parameters to which the update should be applied at each iteration, +as a map from parameter stores to a vector of IDs of parameters within that parameter store. 
+
+    optimizer = CompositeOptimizer(conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context)
+
+Construct a composite optimizer that applies the given update to all parameters used by the given generative function, even when the parameters exist in multiple parameter stores.
+"""
 struct CompositeOptimizer
     conf::Any
-    optimizers::Dict{ParameterStore,Any}
-    function CompositeOptimizer(conf, parameters::Dict{ParameterStore,Vector})
-        optimizers = Dict{ParameterStore,Any}()
+    optimizers::Dict{Any,Any}
+    function CompositeOptimizer(conf, parameter_stores_to_ids::Dict{Any,Vector})
+        optimizers = Dict{Any,Any}()
-        for (store, parameter_ids) in parameters
-            optimizers[store] = make_optimizer(conf, store, parameter_ids)
+        for (store, parameter_ids) in parameter_stores_to_ids
+            optimizers[store] = init_optimizer(conf, parameter_ids, store)
         end
-        new(states, conf)
+        new(conf, optimizers)
     end
 end
 
-function CompositeOptimizer(conf, gen_fn::GenerativeFunction)
-    return CompositeOptimizer(conf, get_parameters(gen_fn))
+function CompositeOptimizer(conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context)
+    return CompositeOptimizer(conf, get_parameters(gen_fn, parameter_context))
 end
 
 """
-    apply_update!(update::ParamUpdate)
+    apply_update!(composite_opt::CompositeOptimizer)
 
-Perform one step of the update.
+Perform one step of an update, possibly mutating the values of parameters in multiple parameter stores.
 """
 function apply_update!(composite_opt::CompositeOptimizer)
     for opt in values(composite_opt.optimizers)
@@ -202,28 +239,39 @@ end
 # Julia #
 #########
 
-const JuliaParameterID = Tuple{GenerativeFunction,Symbol}
-
-# TODO document
 struct JuliaParameterStore
     values::Dict{GenerativeFunction,Dict{Symbol,Any}}
-    gradient_accumulators::Dict{GenerativeFunction,Dict{Symbol,GradientAccumulator}}
+    gradient_accumulators::Dict{GenerativeFunction,Dict{Symbol,Accumulator}}
 end
 
+"""
+
+    store = JuliaParameterStore()
+
+Construct a parameter store that stores the state of parameters in the memory of the Julia runtime as Julia values.
+
+There is a global Julia parameter store automatically created and named `Gen.default_julia_parameter_store`.
+
+Incrementing the gradients can be safely multi-threaded (see [`increment_gradient!`](@ref)).
+"""
 function JuliaParameterStore()
    return JuliaParameterStore(
        Dict{GenerativeFunction,Dict{Symbol,Any}}(),
-        Dict{GenerativeFunction,Dict{Symbol,GradientAccumulator}}())
+        Dict{GenerativeFunction,Dict{Symbol,Accumulator}}())
 end
 
-get_local_parameters(store::JuliaParameterStore, gen_fn) = store.values[gen_fn]
+function get_local_parameters(store::JuliaParameterStore, gen_fn)
+    if !haskey(store.values, gen_fn)
+        return Dict{Symbol,Any}()
+    else
+        return store.values[gen_fn]
+    end
+end
 
-# TODO document
 const default_parameter_context = Dict{Symbol,Any}()
 const default_julia_parameter_store = JuliaParameterStore()
 
 # for looking up in a parameter context when tracing (simulate, generate)
-# TODO make the parametr context another argument to simulate and generate
 # once a trace is generated, it is bound to use a particular store
 const JULIA_PARAMETER_STORE_KEY = :julia_parameter_store
 
@@ -236,7 +284,9 @@ function get_julia_store(context::Dict{Symbol,Any})
 end
 
 """
-    initialize_parameter!(store::JuliaParameterStore, id::JuliaParameterID, value)
+    init_parameter!(
+        id::Tuple{GenerativeFunction,Symbol}, value,
+        store::JuliaParameterStore=default_julia_parameter_store)
 
-Initialize the the value of a named trainable parameter of a generative function.
+Initialize the value of a named trainable parameter of a generative function.
@@ -244,28 +294,41 @@ Also generates the gradient accumulator for that parameter to `zero(value)`.
 Example:
 
 ```julia
-initialize_parameter!(foo, :theta, 0.6)
+init_parameter!((foo, :theta), 0.6)
 ```
 
 Not thread-safe.
 """
-function initialize_parameter!(store::JuliaParameterStore, id::JuliaParameterID, value)
+function init_parameter!(
+        id::Tuple{GenerativeFunction,Symbol}, value,
+        store::JuliaParameterStore=default_julia_parameter_store)
     (gen_fn, name) = id
     if !haskey(store.values, gen_fn)
        store.values[gen_fn] = Dict{Symbol,Any}()
    end
    store.values[gen_fn][name] = value
-    reset_gradient!(store, id)
+    reset_gradient!(id, store)
    return nothing
 end
 
-# TODO docstring (not thread-safe)
-function reset_gradient!(store::JuliaParameterStore, id::JuliaParameterID)
+"""
+    reset_gradient!(
+        id::Tuple{GenerativeFunction,Symbol},
+        store::JuliaParameterStore=default_julia_parameter_store)
+
+Reset the gradient accumulator for a trainable parameter.
+
+Not thread-safe.
+"""
+function reset_gradient!(
+        id::Tuple{GenerativeFunction,Symbol},
+        store::JuliaParameterStore=default_julia_parameter_store)
     (gen_fn, name) = id
+    local value::Any
     try
        value = store.values[gen_fn][name]
    catch KeyError
-        @error "parameter not initialized: $id"
+        @error "parameter $name of $gen_fn was not initialized"
        rethrow()
    end
    if !haskey(store.gradient_accumulators, gen_fn)
@@ -279,70 +342,129 @@ function reset_gradient!(store::JuliaParameterStore, id::JuliaParameterID)
    return nothing
 end
 
-# TODO docstring (thread-safe)
+"""
+    increment_gradient!(
+        id::Tuple{GenerativeFunction,Symbol}, increment, scale_factor::Real,
+        store::JuliaParameterStore=default_julia_parameter_store)
+
+Increment the gradient accumulator for a parameter.
+
+The increment is scaled by the given `scale_factor`.
+
+Thread-safe (multiple threads can increment the gradient of the same parameter concurrently).
+"""
 function increment_gradient!(
-        store::JuliaParameterStore, id::JuliaParameterID,
-        increment, scale_factor)
+        id::Tuple{GenerativeFunction,Symbol}, increment, scale_factor,
+        store::JuliaParameterStore=default_julia_parameter_store)
    (gen_fn, name) = id
    try
        in_place_add!(store.gradient_accumulators[gen_fn][name], increment, scale_factor)
    catch KeyError
-        @error "parameter not initialized: $id"
+        @error "parameter $name of $gen_fn was not initialized"
        rethrow()
    end
    return nothing
 end
 
-function get_gradient_accumulator(store::JuliaParameterStore, id::JuliaParameterID)
+"""
+    increment_gradient!(
+        id::Tuple{GenerativeFunction,Symbol}, increment,
+        store::JuliaParameterStore=default_julia_parameter_store)
+
+Increment the gradient accumulator for a parameter.
+
+Thread-safe (multiple threads can increment the gradient of the same parameter concurrently).
+"""
+function increment_gradient!(
+        id::Tuple{GenerativeFunction,Symbol}, increment,
+        store::JuliaParameterStore=default_julia_parameter_store)
+    accumulator = get_gradient_accumulator(id, store)
+    in_place_add!(accumulator, increment)
+    return nothing
+end
+
+
+"""
+    accum::Accumulator = get_gradient_accumulator(
+        id::Tuple{GenerativeFunction,Symbol},
+        store::JuliaParameterStore=default_julia_parameter_store)
+
+Return the gradient accumulator for a parameter.
+
+Not thread-safe.
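+
+For example (a sketch; it assumes the parameter `(foo, :theta)` was initialized with `init_parameter!` and that `increment` has the same shape as the parameter value):
+
+```julia
+# (foo, :theta) is assumed initialized; `increment` is a hypothetical gradient increment
+accum = get_gradient_accumulator((foo, :theta))
+in_place_add!(accum, increment, 1.0)
+```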
+""" +function get_gradient_accumulator( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) (gen_fn, name) = id try return store.gradient_accumulators[gen_fn][name] catch KeyError - @error "parameter not initialized: $id" + @error "parameter $name of $gen_fn was not initialized" rethrow() end end -# TODO docstring (thread-safe) -function increment_gradient!( - store::JuliaParameterStore, id::JuliaParameterID, - increment) - accumulator = get_gradient_accumulator(store, id) - in_place_add!(accumulator, increment) - return nothing -end +""" + value::Union{Real,Array} = get_parameter_value( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) -# TODO docstring (not thread-safe) +Get the current value of a parameter. -function get_parameter_value(store::JuliaParameterStore, id::JuliaParameterID) +Not thread-safe. +""" +function get_parameter_value( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) (gen_fn, name) = id try - return state.values[gen_fn][name] + return store.values[gen_fn][name] catch KeyError - @error "parameter not initialized: $id" + @error "parameter $name of $gen_fn was not initialized" rethrow() end end -# TODO docstring (not thread-safe) -function set_parameter_value!(store::JuliaParameterStore, id::JuliaParameterID, value) +""" + set_parameter_value!( + id::Tuple{GenerativeFunction,Symbol}, value::Union{Real,Array}, + store::JuliaParameterStore=default_julia_parameter_store) + +Set the value of a parameter. + +Not thread-safe. +""" +function set_parameter_value!( + id::Tuple{GenerativeFunction,Symbol}, value::Union{Real,Array}, + store::JuliaParameterStore=default_julia_parameter_store) (gen_fn, name) = id try store.values[gen_fn][name] = value catch KeyError - @error "parameter not initialized: $id" + @error "parameter $name of $gen_fn was not initialized" rethrow() end return nothing end -# TODO docstring (not thread-safe) -function get_gradient(store::JuliaParameterStore, id::JuliaParameterID) +""" + gradient::Union{Real,Array} = get_gradient( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) + +Get the current value of the gradient accumulator for a parameter. + +Not thread-safe. 
+""" +function get_gradient( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) (gen_fn, name) = id try return get_value(store.gradient_accumulators[gen_fn][name]) catch KeyError - @error "parameter not initialized: $id" + @error "parameter $name of $gen_fn was not initialized" rethrow() end end @@ -354,25 +476,52 @@ end mutable struct FixedStepGradientDescentJulia conf::FixedStepGradientDescent store::JuliaParameterStore - parameters::Vector{JuliaParameterID} + parameters::Vector end -function make_optimizer( +function init_optimizer( conf::FixedStepGradientDescent, - store::JuliaParameterStore, - parameters::Vector{JuliaParameterID}) + parameters::Vector, + store::JuliaParameterStore=default_julia_parameter_store) return FixedStepGradientDescentJulia(conf, store, parameters) end -# TODO docstring (not thread-safe) function apply_update!(opt::FixedStepGradientDescentJulia) - for parameter_id in opt.parameters - value = get_parameter_value(opt.store, parameter_id) - gradient = get_gradient(opt.store, id) + for parameter_id::Tuple{GenerativeFunction,Symbol} in opt.parameters + value = get_parameter_value(parameter_id, opt.store) + gradient = get_gradient(parameter_id, opt.store) new_value = in_place_add!(value, gradient * opt.conf.step_size) - set_parameter_value!(store, parameter_id, new_value) - reset_gradient!(store, parameter_id) + set_parameter_value!(parameter_id, new_value, opt.store) + reset_gradient!(parameter_id, opt.store) + end +end + +mutable struct DecayStepGradientDescentJulia + conf::DecayStepGradientDescent + store::JuliaParameterStore + parameters::Vector + t::Int +end + +function init_optimizer( + conf::DecayStepGradientDescent, + parameters::Vector, + store::JuliaParameterStore=default_julia_parameter_store) + return DecayStepGradientDescentJulia(conf, store, parameters, 1) +end + +function apply_update!(opt::DecayStepGradientDescentJulia) + step_size_init = opt.conf.step_size_init + step_size_beta = opt.conf.step_size_beta + step_size = step_size_init * (step_size_beta + 1) / (step_size_beta + opt.t) + for parameter_id::Tuple{GenerativeFunction,Symbol} in opt.parameters + value = get_parameter_value(parameter_id, opt.store) + gradient = get_gradient(parameter_id, opt.store) + new_value = in_place_add!(value, gradient * step_size) + set_parameter_value!(parameter_id, new_value, opt.store) + reset_gradient!(parameter_id, opt.store) end + opt.t += 1 end -# TODO implement other optimizers +# TODO implement other optimizers (ADAM, etc.) 
diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index ad5110fae..c255c7d9f 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -21,42 +21,43 @@ maybe_tracked_value_var(node::JuliaNode) = Symbol("$(maybe_tracked_value_prefix) const maybe_tracked_arg_prefix = gensym("maybe_tracked_arg") maybe_tracked_arg_var(node::JuliaNode, i::Int) = Symbol("$(maybe_tracked_arg_prefix)_$(node.name)_$i") -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropParamsMode) - push!(fwd_marked, node) -end - -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode, ::BackpropTraceMode) +function forward_marking!( + selected_choices, selected_calls, fwd_marked, + node::TrainableParameterNode, mode) + if mode == BackpropParamsMode() + push!(fwd_marked, node) + end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode, mode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode, mode) if node.compute_grad push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode, mode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::JuliaNode, mode) if any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode, mode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode, mode) if node in selected_choices push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode, mode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode, mode) if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end end -function back_pass!(back_marked, node::TrainableParameterNode) end +function backward_marking!(back_marked, node::TrainableParameterNode) end -function back_pass!(back_marked, node::ArgumentNode) end +function backward_marking!(back_marked, node::ArgumentNode) end -function back_pass!(back_marked, node::JuliaNode) +function backward_marking!(back_marked, node::JuliaNode) if node in back_marked for input_node in node.inputs push!(back_marked, input_node) @@ -64,7 +65,7 @@ function back_pass!(back_marked, node::JuliaNode) end end -function back_pass!(back_marked, node::RandomChoiceNode) +function backward_marking!(back_marked, node::RandomChoiceNode) # the logpdf of every random choice is a SINK for input_node in node.inputs push!(back_marked, input_node) @@ -73,7 +74,7 @@ function back_pass!(back_marked, node::RandomChoiceNode) push!(back_marked, node) end -function back_pass!(back_marked, node::GenerativeFunctionCallNode) +function backward_marking!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK # (we could ask whether the generative function is deterministic or not # as a perforance optimization, because only stochsatic generative functions @@ -85,7 +86,7 @@ function back_pass!(back_marked, node::GenerativeFunctionCallNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) if node in back_marked push!(stmts, :($(node.name) = 
$(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) end @@ -99,7 +100,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNo end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::ArgumentNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::ArgumentNode) if node in fwd_marked && node in back_marked # initialize gradient to zero @@ -107,7 +108,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::ArgumentNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) if (node in fwd_marked) && (node in back_marked) @@ -140,7 +141,7 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) # every random choice is in back_marked, since it affects it logpdf, but # also possibly due to other downstream usage of the value @assert node in back_marked @@ -155,10 +156,10 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) if node in back_marked - # for reference by other nodes during back_codegen! + # for reference by other nodes during backward_codegen! subtrace_fieldname = get_subtrace_fieldname(node) push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname))) end @@ -171,18 +172,20 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCa end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode) + if mode == BackpropParamsMode() - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing"))) - push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad, scale_factor))) + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing"))) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + end end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::ArgumentNode, mode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::ArgumentNode, mode) # handle case when it is the return node if node === ir.return_node && node in fwd_marked @@ -192,7 +195,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::JuliaNode, mode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::JuliaNode, mode) # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked @@ -210,6 +213,7 @@ function back_codegen!(stmts, 
ir, selected_calls, fwd_marked, back_marked, node: if input_node in fwd_marked arg_maybe_tracked = maybe_tracked_arg_var(node, i) if isa(input_node, TrainableParameterNode) + @assert mode == BackpropParamsMode() push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor))) else @@ -222,8 +226,11 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: end -function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, - node::RandomChoiceNode, logpdf_grad::Symbol) +function backward_codegen_random_choice_to_inputs!( + stmts, ir, fwd_marked, back_marked, + node::RandomChoiceNode, logpdf_grad::Symbol, + mode) + # only evaluate the gradient of the logpdf if we need to if any(input_node in fwd_marked for input_node in node.inputs) || node in fwd_marked args = map((input_node) -> input_node.name, node.inputs) @@ -239,7 +246,7 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke end input_node_grad = gradient_var(input_node) increment = :($logpdf_grad[$(QuoteNode(i+1))]) - if isa(input_node, TrainableParameterNode) + if isa(input_node, TrainableParameterNode) && mode == BackpropParamsMode() push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))( $input_node_grad, $increment, scale_factor))) else @@ -253,16 +260,16 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke if node === ir.return_node && node in fwd_marked @assert node in back_marked push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad, scale_factor))) + $(gradient_var(node)), retval_grad))) end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropTraceMode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, + node::RandomChoiceNode, mode::BackpropTraceMode) logpdf_grad = gensym("logpdf_grad") # backpropagate to the inputs - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + backward_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad, mode) # backpropagate to the value (if it was selected) if node in fwd_marked @@ -274,13 +281,13 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropParamsMode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, + node::RandomChoiceNode, mode::BackpropParamsMode) logpdf_grad = gensym("logpdf_grad") - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + backward_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad, mode) end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropTraceMode) # handle case when it is the return node @@ -320,7 +327,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, # NOTE: the value_trie and gradient_trie are dealt with later end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, 
mode::BackpropParamsMode) # handle case when it is the return node @@ -393,15 +400,15 @@ function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) push!(selected_choices, node) end end - selected_choices + return selected_choices end function get_selected_calls(::EmptyAddressSchema, ::StaticIR) - Set{GenerativeFunctionCallNode}() + return Set{GenerativeFunctionCallNode}() end function get_selected_calls(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}(ir.call_nodes) + return Set{GenerativeFunctionCallNode}(ir.call_nodes) end function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) @@ -412,7 +419,7 @@ function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) push!(selected_calls, node) end end - selected_calls + return selected_calls end function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, @@ -432,14 +439,14 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # forward marking pass fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropTraceMode()) + forward_marking!(selected_choices, selected_calls, fwd_marked, node, BackpropTraceMode()) end # backward marking pass back_marked = Set{StaticIRNode}() push!(back_marked, ir.return_node) for node in reverse(ir.nodes) - back_pass!(back_marked, node) + backward_marking!(back_marked, node) end stmts = Expr[] @@ -450,12 +457,12 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # forward code-generation pass (initialize gradients to zero, create needed references) for node in ir.nodes - fwd_codegen!(stmts, fwd_marked, back_marked, node) + forward_codegen!(stmts, fwd_marked, back_marked, node) end # backward code-generation pass (increment gradients) for node in reverse(ir.nodes) - back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode()) + backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode()) end # assemble value_trie and gradient_trie @@ -471,7 +478,7 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # return values push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie))) - Expr(:block, stmts...) + return Expr(:block, stmts...) 
end function codegen_accumulate_param_gradients!(trace_type::Type{T}, @@ -490,14 +497,14 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # forward marking pass (propagate forward from 'sources') fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node, BackpropParamsMode()) + forward_marking!(selected_choices, selected_calls, fwd_marked, node, BackpropParamsMode()) end # backward marking pass (propagate backwards from 'sinks') back_marked = Set{StaticIRNode}() push!(back_marked, ir.return_node) for node in reverse(ir.nodes) - back_pass!(back_marked, node) + backward_marking!(back_marked, node) end stmts = Expr[] @@ -510,13 +517,13 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # any node that is backward-marked creates a variable for its current value # any node that is forward-marked and backwards marked initializes a gradient variable for node in ir.nodes - fwd_codegen!(stmts, fwd_marked, back_marked, node) + forward_codegen!(stmts, fwd_marked, back_marked, node) end # backward code-generation pass # any node that is forward-marked and backwards marked increments its gradient variable for node in reverse(ir.nodes) - back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropParamsMode()) + backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropParamsMode()) end # gradients with respect to inputs @@ -526,19 +533,21 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # return values push!(stmts, :(return $input_grads)) - Expr(:block, stmts...) + return Expr(:block, stmts...) end push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :choice_gradients))(trace::T, selection::$(QuoteNode(Selection)), - retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} - $(QuoteNode(codegen_choice_gradients))(trace, selection, retval_grad) +@generated function $(GlobalRef(Gen, :choice_gradients))( + trace::T, selection::$(QuoteNode(Selection)), + retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} + return $(QuoteNode(codegen_choice_gradients))(trace, selection, retval_grad) end end) push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad, scale_factor) where {T<:$(QuoteNode(StaticIRTrace))} - $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad, scale_factor) +@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))( + trace::T, retval_grad, scale_factor) where {T<:$(QuoteNode(StaticIRTrace))} + return $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad, scale_factor) end end) diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index d57f59788..2d4ecbd91 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -6,7 +6,9 @@ end function process!(::StaticIRGenerateState, node, options) end function process!(state::StaticIRGenerateState, node::TrainableParameterNode, options) - push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) + push!(state.stmts, :($(node.name) = $(QuoteNode(get_parameter_value))( + (gen_fn, $(QuoteNode(node.name))), + $parameter_store_fieldname))) end function process!(state::StaticIRGenerateState, node::ArgumentNode, options) @@ -100,7 +102,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, push!(stmts, :($return_value_fieldname = $(ir.return_node.name))) # construct trace - push!(stmts, :($static_ir_gen_fn_ref 
= gen_fn)) + push!(stmts, :($static_ir_gen_fn_fieldname = gen_fn)) push!(stmts, :($trace = $(QuoteNode(trace_type))($(fieldnames(trace_type)...)))) # return trace and weight @@ -113,7 +115,7 @@ push!(generated_functions, quote @generated function $(GlobalRef(Gen, :generate))( gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap)); - parameter_context=default_parameter_context) + parameter_context=$(QuoteNode(default_parameter_context))) $(QuoteNode(codegen_generate))(gen_fn, args, constraints) end end) diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index ab2b769dc..5f6106386 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -5,7 +5,9 @@ end function process!(::StaticIRSimulateState, node, options) end function process!(state::StaticIRSimulateState, node::TrainableParameterNode, options) - push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) + push!(state.stmts, :($(node.name) = $(QuoteNode(get_parameter_value))( + (gen_fn, $(QuoteNode(node.name))), + $parameter_store_fieldname))) end function process!(state::StaticIRSimulateState, node::ArgumentNode, options) @@ -74,7 +76,7 @@ function codegen_simulate(gen_fn_type::Type{T}, args, parameter_context_type) wh # construct trace trace_type = get_trace_type(gen_fn_type) - push!(stmts, :($static_ir_gen_fn_ref = gen_fn)) + push!(stmts, :($static_ir_gen_fn_fieldname = gen_fn)) push!(stmts, :($trace = $(QuoteNode(trace_type))($(fieldnames(trace_type)...)))) # return trace @@ -86,7 +88,7 @@ end push!(generated_functions, quote @generated function $(GlobalRef(Gen, :simulate))( gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple; - parameter_context=default_parameter_context) + parameter_context=$(QuoteNode(default_parameter_context))) $(QuoteNode(codegen_simulate))(gen_fn, args, parameter_context) end end) diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index 8e996e9c7..bbb24d7dd 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -55,15 +55,18 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati $(GlobalRef(Gen, :get_trace_type))(::Type{$gen_fn_type_name}) = $trace_struct_name $(GlobalRef(Gen, :has_argument_grads))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads)) $(GlobalRef(Gen, :accepts_output_grad))(::$gen_fn_type_name) = $(QuoteNode(accepts_output_grad)) - $(GlobalRef(Gen, :get_gen_fn))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))) + $(GlobalRef(Gen, :get_gen_fn))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_fieldname))) $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) - function $(GlobalRef(Gen, :get_parameters))(gen_fn::Type{$gen_fn_type_name}, context) + function $(GlobalRef(Gen, :get_parameters))(gen_fn::$gen_fn_type_name, context) return $(GlobalRef(Gen, :get_parameters))($(QuoteNode(ir)), gen_fn, context) end + function Base.show(io::IO, ::MIME"text/plain", gen_fn::$gen_fn_type_name) + return "Gen SML generative function: $name)" + end + end - Expr(:block, trace_defns, gen_fn_defn, - Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()), :(ReentrantLock()))) + Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name)) end include("print_ir.jl") diff --git a/src/static_ir/trace.jl 
b/src/static_ir/trace.jl index 9f520240a..4e5fe0539 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -43,7 +43,7 @@ const total_score_fieldname = :score const total_noise_fieldname = :noise const return_value_fieldname = :retval const parameter_store_fieldname = :parameter_store - +const static_ir_gen_fn_fieldname = :gen_fn # other fields based on user-defined variable names are prefixed to avoid collisions get_value_fieldname(node::ArgumentNode) = Symbol("#arg#_$(node.name)") @@ -52,17 +52,16 @@ get_value_fieldname(node::JuliaNode) = Symbol("#julia#_$(node.name)") get_score_fieldname(node::RandomChoiceNode) = Symbol("#choice_score#_$(node.addr)") get_subtrace_fieldname(node::GenerativeFunctionCallNode) = Symbol("#subtrace#_$(node.addr)") - # getters function get_parameter_value(trace::StaticIRTrace, name) parameter_id = (get_gen_fn(trace), name) - return get_parameter_value(trace.parameter_store, parameter_id) + return get_parameter_value(parameter_id, trace.parameter_store) end function get_gradient_accumulator(trace::StaticIRTrace, name) parameter_id = (get_gen_fn(trace), name) - return get_gradient_accumulator(trace.parameter_store, parameter_id) + return get_gradient_accumulator(parameter_id, trace.parameter_store) end @inline function static_get_subtrace(trace::StaticIRTrace, addr) @@ -115,19 +114,17 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio push!(fields, TraceField(total_noise_fieldname, QuoteNode(Float64))) push!(fields, TraceField(num_nonempty_fieldname, QuoteNode(Int))) push!(fields, TraceField(return_value_fieldname, ir.return_node.typ)) - push!(fields, parameter_store_fieldname, QuoteNode(JuliaParameterStore)()) + push!(fields, TraceField(parameter_store_fieldname, QuoteNode(JuliaParameterStore))) + push!(fields, TraceField(static_ir_gen_fn_fieldname, QuoteNode(Any))) return fields end -const static_ir_gen_fn_ref = gensym("gen_fn") - function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options::StaticIRGenerativeFunctionOptions) mutable = false fields = get_trace_fields(ir, options) field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields) Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)), - Expr(:block, field_exprs..., - Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any)))) + Expr(:block, field_exprs...)) end function generate_isempty(trace_struct_name::Symbol) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index c0fe486eb..a47ece6a7 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -420,15 +420,16 @@ end function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) if options.track_diffs - # note that the generative function is the last field - constructor_args = map((name) -> Expr(:call, QuoteNode(strip_diff), name), - fieldnames(trace_type)[1:end-1]) - push!(stmts, :($trace = $(QuoteNode(trace_type))($(constructor_args...), - $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref)))))) + # NOTE: applying 'strip_diff' to all of the fields, including fields that are not diffed, is a hack + constructor_args = map( + (name) -> Expr(:call, QuoteNode(strip_diff), name), + fieldnames(trace_type)) else - push!(stmts, :($static_ir_gen_fn_ref = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))))) - push!(stmts, :($trace = $(QuoteNode(trace_type))($(fieldnames(trace_type)...)))) + constructor_args = fieldnames(trace_type) end + push!(stmts, :($static_ir_gen_fn_fieldname = $(Expr(:(.), :trace, 
QuoteNode(static_ir_gen_fn_fieldname))))) + push!(stmts, :($parameter_store_fieldname = $(Expr(:(.), :trace, QuoteNode(parameter_store_fieldname))))) + push!(stmts, :($trace = $(QuoteNode(trace_type))($(constructor_args...)))) end function generate_discard!(stmts::Vector{Expr}, diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 25e3a50f4..6cea33299 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -295,8 +295,8 @@ end return @trace(normal(c, 1), :out) + (theta2 * 3) end - init_param!(bar, :theta1, 0.) - init_param!(foo, :theta2, 0.) + init_parameter!((bar, :theta1), 0.0) + init_parameter!((foo, :theta2), 0.0) function f(mu_a, a, b, z, out) lpdf = 0. @@ -358,8 +358,8 @@ end @test isapprox(mu_a_grad, finite_diff(f, (mu_a, a, b, z, out), 1, dx)) # check parameter gradient - theta1_grad = get_param_grad(bar, :theta1) - theta2_grad = get_param_grad(foo, :theta2) + theta1_grad = get_gradient((bar, :theta1)) + theta2_grad = get_gradient((foo, :theta2)) @test isapprox(theta1_grad, logpdf_grad(normal, z, a, 1)[2]) @test isapprox(theta2_grad, 3 * 2) @@ -372,7 +372,7 @@ end return theta end - init_param!(baz, :theta, 0.) + init_parameter!((baz, :theta), 0.0) @gen (grad) function foo() return @trace(baz()) @@ -381,7 +381,7 @@ end (trace, _) = generate(foo, ()) retval_grad = 2. accumulate_param_gradients!(trace, retval_grad) - @test isapprox(get_param_grad(baz, :theta), retval_grad) + @test isapprox(get_gradient((baz, :theta)), retval_grad) end @testset "gradient descent with fixed step size" begin @@ -389,14 +389,19 @@ end @param theta::Float64 return theta end - init_param!(foo, :theta, 0.) + + init_parameter!((foo, :theta), 0.0) + (trace, ) = generate(foo, ()) - accumulate_param_gradients!(trace, 1.) + accumulate_param_gradients!(trace, 1.0) + @test isapprox(get_gradient((foo, :theta)), 1.0) conf = FixedStepGradientDescent(0.001) - state = Gen.init_update_state(conf, foo, [:theta]) - Gen.apply_update!(state) - @test isapprox(get_param(foo, :theta), 0.001) - @test isapprox(get_param_grad(foo, :theta), 0.) + optimizer = init_optimizer(conf, [(foo, :theta)]) + apply_update!(optimizer) + println(get_parameter_value((foo, :theta))) + println(get_gradient((foo, :theta))) + @test isapprox(get_parameter_value((foo, :theta)), 0.001) + @test isapprox(get_gradient((foo, :theta)), 0.0) end @testset "gradient descent with shrinking step size" begin @@ -404,14 +409,17 @@ end @param theta::Float64 return theta end - init_param!(foo, :theta, 0.) + + init_parameter!((foo, :theta), 0.0) + (trace, ) = generate(foo, ()) accumulate_param_gradients!(trace, 1.) - conf = GradientDescent(0.001, 1000) - state = Gen.init_update_state(conf, foo, [:theta]) - Gen.apply_update!(state) - @test isapprox(get_param(foo, :theta), 0.001) - @test isapprox(get_param_grad(foo, :theta), 0.) + @test isapprox(get_gradient((foo, :theta)), 1.0) + conf = DecayStepGradientDescent(0.001, 1000) + optimizer = init_optimizer(conf, [(foo, :theta)]) + apply_update!(optimizer) + @test isapprox(get_parameter_value((foo, :theta)), 0.001) + @test isapprox(get_gradient((foo, :theta)), 0.0) end @testset "multi-component addresses" begin diff --git a/test/inference/train.jl b/test/inference/train.jl index 9e3aee7a7..995a1b81a 100644 --- a/test/inference/train.jl +++ b/test/inference/train.jl @@ -64,11 +64,11 @@ # theta2 -> inf (prob_y -> 1) # theta3 -> -inf (prob_y -> 0) - init_param!(student, :theta1, 0.) - init_param!(student, :theta2, 0.) - init_param!(student, :theta3, 0.) - init_param!(student, :theta4, 0.) 
- init_param!(student, :theta5, 0.) + init_parameter!((student, :theta1), 0.0) + init_parameter!((student, :theta2), 0.0) + init_parameter!((student, :theta3), 0.0) + init_parameter!((student, :theta4), 0.0) + init_parameter!((student, :theta5), 0.0) # check gradients using finite differences on a simulated batch minibatch_size = 100 @@ -80,9 +80,9 @@ accumulate_param_gradients!(student_trace, nothing) end for name in [:theta1, :theta2, :theta3, :theta4, :theta5] - actual = get_param_grad(student, name) + actual = get_gradient((student, name)) dx = 1e-6 - value = get_param(student, name) + value = get_parameter_value((student, name)) # evaluate total log density at value + dx set_param!(student, name, value + dx) @@ -107,19 +107,19 @@ end # use stochastic gradient descent - update = ParamUpdate(GradientDescent(0.01, 1000000), student) - train!(student, data_generator, update, + optimizer = CompositeOptimizer(GradientDescent(0.01, 1000000), student) + train!(student, data_generator, optimizer, num_epoch=2000, epoch_size=50, num_minibatch=1, minibatch_size=50, verbose=false) # p(x | z=0) = p(x | z=1) = 0.5 - @test isapprox(get_param(student, :theta1), 0., atol=0.2) + @test isapprox(get_parameter_value((student, :theta1)), 0.0, atol=0.2) # y | z, x = xor(x, z) - @test get_param(student, :theta2) < -5 - @test get_param(student, :theta3) > 5 - @test get_param(student, :theta4) > 5 - @test get_param(student, :theta5) < -5 + @test get_parameter_value((student, :theta2)) < -5 + @test get_parameter_value((student, :theta3)) > 5 + @test get_parameter_value((student, :theta4)) > 5 + @test get_parameter_value((student, :theta5)) < -5 end @@ -143,13 +143,13 @@ end end # train simple q using lecture! to compute gradients - init_param!(q, :theta, 0.) - init_param!(q, :log_std, 0.) - update = ParamUpdate(FixedStepGradientDescent(1e-4), q) + init_parameter!((q, :theta), 0.0) + init_parameter!((q, :log_std), 0.0) + optimizer = CompositeOptimizer(FixedStepGradientDescent(1e-4), q) score = Inf for iter=1:100 score = sum([lecture!(p, (), q, tr -> (tr[:x],)) for _=1:1000]) / 1000 - apply!(update) + apply_update!(optimizer) end score = sum([lecture!(p, (), q, tr -> (tr[:x],)) for _=1:10000]) / 10000 @test isapprox(score, -0.21, atol=5e-2) @@ -165,13 +165,13 @@ end end # train simple q using lecture_batched! to compute gradients - init_param!(q_batched, :theta, 0.) - init_param!(q_batched, :log_std, 0.) - update = ParamUpdate(FixedStepGradientDescent(0.001), q_batched) + init_parameter!(q_batched(, :theta, 0).0) + init_parameter!((q_batched, :log_std), 0.0) + optimizer = CompositeOptimizer(FixedStepGradientDescent(0.001), q_batched) score = Inf for iter=1:100 score = lecture_batched!(p, (), q_batched, trs -> (map(tr -> tr[:x], trs),), 1000) - apply!(update) + apply_update!(optimizer) end score = sum([lecture!(p, (), q, tr -> (tr[:x],)) for _=1:10000]) / 10000 @test isapprox(score, -0.21, atol=5e-2) diff --git a/test/inference/variational.jl b/test/inference/variational.jl index babcfb481..1084b4331 100644 --- a/test/inference/variational.jl +++ b/test/inference/variational.jl @@ -17,38 +17,38 @@ end # to regular black box variational inference - init_param!(approx, :slope_mu, 0.) - init_param!(approx, :slope_log_std, 0.) - init_param!(approx, :intercept_mu, 0.) - init_param!(approx, :intercept_log_std, 0.) 
+ init_param!((approx, :slope_mu), 0.0) + init_param!((approx, :slope_log_std), 0.0) + init_param!((approx, :intercept_mu), 0.0) + init_param!((approx, :intercept_log_std), 0.0) observations = choicemap() - update = ParamUpdate(GradientDescent(1, 100000), approx) - update = ParamUpdate(GradientDescent(1., 1000), approx) - black_box_vi!(model, (), observations, approx, (), update; + optimizer = CompositeOptimizer(GradientDescent(1, 100000), approx) + optimizer = CompositeOptimizer(GradientDescent(1., 1000), approx) + black_box_vi!(model, (), observations, approx, (), optimizer; iters=2000, samples_per_iter=100, verbose=false) - slope_mu = get_param(approx, :slope_mu) - slope_log_std = get_param(approx, :slope_log_std) - intercept_mu = get_param(approx, :intercept_mu) - intercept_log_std = get_param(approx, :intercept_log_std) + slope_mu = get_parameter_value((approx, :slope_mu)) + slope_log_std = get_parameter_value((approx, :slope_log_std)) + intercept_mu = get_parameter_value((approx, :intercept_mu)) + intercept_log_std = get_parameter_value((approx, :intercept_log_std)) @test isapprox(slope_mu, -1., atol=0.001) @test isapprox(slope_log_std, 0.5, atol=0.001) @test isapprox(intercept_mu, 1., atol=0.001) @test isapprox(intercept_log_std, 2.0, atol=0.001) # smoke test for black box variational inference with Monte Carlo objectives - init_param!(approx, :slope_mu, 0.) - init_param!(approx, :slope_log_std, 0.) - init_param!(approx, :intercept_mu, 0.) - init_param!(approx, :intercept_log_std, 0.) - black_box_vimco!(model, (), observations, approx, (), update, 20; + init_param!((approx, :slope_mu), 0.0) + init_param!((approx, :slope_log_std), 0.0) + init_param!((approx, :intercept_mu), 0.0) + init_param!((approx, :intercept_log_std), 0.0) + black_box_vimco!(model, (), observations, approx, (), optimizer, 20; iters=50, samples_per_iter=100, verbose=false, geometric=false) - init_param!(approx, :slope_mu, 0.) - init_param!(approx, :slope_log_std, 0.) - init_param!(approx, :intercept_mu, 0.) - init_param!(approx, :intercept_log_std, 0.) 
- black_box_vimco!(model, (), observations, approx, (), update, 20; + init_param!((approx, :slope_mu), 0.0) + init_param!((approx, :slope_log_std), 0.0) + init_param!((approx, :intercept_mu), 0.0) + init_param!((approx, :intercept_log_std), 0.0) + black_box_vimco!(model, (), observations, approx, (), optimizer, 20; iters=50, samples_per_iter=100, verbose=false, geometric=true) end @@ -97,7 +97,7 @@ end {(:z, i)} ~ normal(posterior_means[i], sqrt(1.0 / posterior_precisions)) end end - init_param!(model, :theta, opt_theta) + init_param!((model, :theta), opt_theta) approx_trace = simulate(optimum_approx, ()) (model_trace, _) = generate(model, (), merge(get_choices(approx_trace), observations)) # note that p(z1..zn, x1..xn) / p(z1..zn | x1..xn) = p(x1...xn) - for all z1..zn @@ -105,31 +105,31 @@ end println("true optimum log_marginal_likelihood: $log_marginal_likelihood") # using BBVI with score function estimator - init_param!(model, :theta, 0.0) - init_param!(approx, :mu_coeffs, zeros(2)) - init_param!(approx, :log_std, 0.0) - approx_update = ParamUpdate(FixedStepGradientDescent(0.0001), approx) - model_update = ParamUpdate(FixedStepGradientDescent(0.002), model) + init_param!((model, :theta), 0.0) + init_param!((approx, :mu_coeffs), zeros(2)) + init_param!((approx, :log_std), 0.0) + approx_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.0001), approx) + model_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.002), model) @time (_, _, elbo_history, _) = - black_box_vi!(model, (), model_update, observations, - approx, (xs,), approx_update; + black_box_vi!(model, (), model_optimizer, observations, + approx, (xs,), approx_optimizer; iters=3000, samples_per_iter=20, verbose=false) - @test isapprox(get_param(model, :theta), opt_theta, atol=1e-1) - println("final theta: $(get_param(model, :theta))") + @test isapprox(get_parameter_value((model, :theta)), opt_theta, atol=1e-1) + println("final theta: $(get_parameter_value((model, :theta)))") println("final elbo estimate: $(elbo_history[end])") @test isapprox(elbo_history[end], log_marginal_likelihood, rtol=0.1) # using VIMCO - init_param!(model, :theta, 0.0) - init_param!(approx, :mu_coeffs, zeros(2)) - init_param!(approx, :log_std, 0.0) - approx_update = ParamUpdate(FixedStepGradientDescent(0.001), approx) - model_update = ParamUpdate(FixedStepGradientDescent(0.01), model) + init_param!((model, :theta), 0.0) + init_param!((approx, :mu_coeffs), zeros(2)) + init_param!((approx, :log_std), 0.0) + approx_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.001), approx) + model_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.01), model) @time (_, _, elbo_history, _) = - black_box_vimco!(model, (), model_update, observations, - approx, (xs,), approx_update, 10; + black_box_vimco!(model, (), model_optimizer, observations, + approx, (xs,), approx_optimizer, 10; iters=1000, samples_per_iter=10, verbose=false) - println("final theta: $(get_param(model, :theta))") + println("final theta: $(get_parameter_value((model, :theta)))") println("final elbo estimate: $(elbo_history[end])") @test isapprox(elbo_history[end], log_marginal_likelihood, rtol=0.1) end diff --git a/test/optional_args.jl b/test/optional_args.jl index fd6c4ea71..f4883d14d 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -10,7 +10,7 @@ using Gen end # initialize theta to zero for non-gradient tests - init_param!(foo, :theta, 0.) 
+ init_parameter!((foo, :theta), 0.0) # test directly calling with varying args @test foo(1) == (1, 2, 3, 6) diff --git a/test/runtests.jl b/test/runtests.jl index 1317b19ee..6cfd343b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,16 +76,14 @@ end const dx = 1e-6 -include("autodiff.jl") -include("diff.jl") -include("selection.jl") -include("assignment.jl") -include("gen_fn_interface.jl") +#include("autodiff.jl") +#include("diff.jl") +#include("selection.jl") +#include("assignment.jl") +#include("gen_fn_interface.jl") include("dsl/dsl.jl") -include("optional_args.jl") +#include("optional_args.jl") include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") include("modeling_library/modeling_library.jl") - - diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index f773b6853..63eadcd2b 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -49,7 +49,7 @@ foo = eval(generate_generative_function(ir, :foo, track_diffs=false, cache_julia @test occursin("== Static IR ==", repr("text/plain", ir)) theta_val = rand() -set_param!(foo, :theta, theta_val) +init_parameter!((foo, :theta), theta_val) #@gen (static, nojuliacache) function const_fn() #return 1 @@ -338,10 +338,8 @@ end z = 4. out = 5. - init_param!(foo, :theta, theta) - - @test get_param(foo, :theta) == theta - @test get_param_grad(foo, :theta) == 0. + # initialize the trainable parameter + init_parameter!((foo, :theta), theta) # get the initial trace constraints = choicemap() @@ -376,11 +374,6 @@ end @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) @test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx)) - # reset the trainable parameter gradient - zero_param_grad!(foo, :theta) - @test get_param(foo, :theta) == theta - @test get_param_grad(foo, :theta) == 0. - # compute gradients with accumulate_param_gradients! retval_grad = 2. 
(mu_a_grad,) = accumulate_param_gradients!(trace, retval_grad) @@ -389,7 +382,9 @@ end @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) # check trainable parameter gradient - @test isapprox(get_param_grad(foo, :theta), finite_diff(f, (mu_a, theta, a, b, z, out), 2, dx)) + @test isapprox( + get_gradient((foo, :theta)), + finite_diff(f, (mu_a, theta, a, b, z, out), 2, dx)) end From d6f636cef6798a00d69a083185744d6dc4454817 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Mon, 17 May 2021 12:18:16 -0400 Subject: [PATCH 08/24] fixes --- src/dynamic/assess.jl | 37 ++++++----- src/dynamic/dynamic.jl | 4 ++ src/dynamic/propose.jl | 34 ++++++---- src/dynamic/simulate.jl | 2 +- src/optimization.jl | 1 + src/static_ir/backprop.jl | 2 +- src/static_ir/static_ir.jl | 4 +- test/dsl/dynamic_dsl.jl | 2 - test/inference/train.jl | 9 +-- test/modeling_library/map.jl | 2 +- test/modeling_library/unfold.jl | 2 +- test/runtests.jl | 10 +-- test/static_ir/static_ir.jl | 111 +------------------------------- 13 files changed, 64 insertions(+), 156 deletions(-) diff --git a/src/dynamic/assess.jl b/src/dynamic/assess.jl index c583d5079..81e49cce3 100644 --- a/src/dynamic/assess.jl +++ b/src/dynamic/assess.jl @@ -2,11 +2,22 @@ mutable struct GFAssessState choices::ChoiceMap weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict + + function GFAssessState(gen_fn, choices, parameter_context) + new(choices, 0.0, AddressVisitor(), gen_fn, parameter_context) + end end -function GFAssessState(choices, params::Dict{Symbol,Any}) - GFAssessState(choices, 0., AddressVisitor(), params) +get_parameter_store(state::GFAssessState) = get_julia_store(state.parameter_context) + +get_parameter_id(state::GFAssessState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFAssessState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFAssessState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFAssessState, dist::Distribution{T}, @@ -22,7 +33,7 @@ function traceat(state::GFAssessState, dist::Distribution{T}, # update weight state.weight += logpdf(dist, retval, args...) 
- retval + return retval end function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, @@ -41,19 +52,13 @@ function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, # update score state.weight += weight - retval -end - -function splice(state::GFAssessState, gen_fn::DynamicDSLFunction, args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval + return retval end -function assess(gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap) - state = GFAssessState(choices, gen_fn.params) +function assess( + gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap; + parameter_context=default_parameter_context) + state = GFAssessState(gen_fn, choices, parameter_context) retval = exec(gen_fn, state, args) unvisited = get_unvisited(get_visited(state.visitor), choices) @@ -61,5 +66,5 @@ function assess(gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap) error("Assess did not visit the following constraint addresses:\n$unvisited") end - (state.weight, retval) + return (state.weight, retval) end diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index dda40f214..b379826d0 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -37,6 +37,10 @@ function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction) return "Gen DML generative function: $(gen_fn.julia_function)" end +function get_parameters(gen_fn::DynamicDSLFunction, parameter_context) + # TODO for this, we need to walk the code... (and throw errors when the +end + function DynamicDSLTrace(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T<:DynamicDSLFunction} # pad args with default values, if available if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults) diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index 7c630cc19..3cece0e3a 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -2,11 +2,23 @@ mutable struct GFProposeState choices::DynamicChoiceMap weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict + + function GFProposeState( + gen_fn::GenerativeFunction, parameter_context) + return new(choicemap(), 0.0, AddressVisitor(), gen_fn, parameter_context) + end end -function GFProposeState(params::Dict{Symbol,Any}) - GFProposeState(choicemap(), 0., AddressVisitor(), params) +get_parameter_store(state::GFProposeState) = get_julia_store(state.parameter_context) + +get_parameter_id(state::GFProposeState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFProposeState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFProposeState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFProposeState, dist::Distribution{T}, @@ -47,16 +59,10 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval -end - -function propose(gen_fn::DynamicDSLFunction, args::Tuple) - state = GFProposeState(gen_fn.params) +function propose( + gen_fn::DynamicDSLFunction, args::Tuple; + parameter_context=default_parameter_context) + state = GFProposeState(gen_fn, parameter_context) retval = exec(gen_fn, state, args) - (state.choices, state.weight, 
retval) + return (state.choices, state.weight, retval) end diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index cbee0a165..8a4337092 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -66,5 +66,5 @@ function simulate( state = GFSimulateState(gen_fn, args, parameter_context) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) - state.trace + return state.trace end diff --git a/src/optimization.jl b/src/optimization.jl index 35a49f82b..b0de2f535 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -20,6 +20,7 @@ export FixedStepGradientDescent export DecayStepGradientDescent export init_optimizer export apply_update! +export CompositeOptimizer export JuliaParameterStore export init_parameter! diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index c255c7d9f..b42613a9e 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -277,7 +277,7 @@ function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, error("Distribution $dist does not logpdf gradient for its output value") end push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( - $(gradient_var(node)), retval_grad))) + $(gradient_var(node)), $logpdf_grad[1]))) end end diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index bbb24d7dd..3e72a777b 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -46,6 +46,8 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati has_argument_grads = tuple(map((node) -> node.compute_grad, ir.arg_nodes)...) accepts_output_grad = ir.accepts_output_grad + show_str = "Gen SML generative function: $name" + gen_fn_defn = quote struct $gen_fn_type_name <: $(QuoteNode(StaticIRGenerativeFunction)){$return_type,$trace_type} end @@ -62,7 +64,7 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati return $(GlobalRef(Gen, :get_parameters))($(QuoteNode(ir)), gen_fn, context) end function Base.show(io::IO, ::MIME"text/plain", gen_fn::$gen_fn_type_name) - return "Gen SML generative function: $name)" + return $(QuoteNode(show_str)) end end diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 6cea33299..0b7be3332 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -398,8 +398,6 @@ end conf = FixedStepGradientDescent(0.001) optimizer = init_optimizer(conf, [(foo, :theta)]) apply_update!(optimizer) - println(get_parameter_value((foo, :theta))) - println(get_gradient((foo, :theta))) @test isapprox(get_parameter_value((foo, :theta)), 0.001) @test isapprox(get_gradient((foo, :theta)), 0.0) end diff --git a/test/inference/train.jl b/test/inference/train.jl index 995a1b81a..f62641afa 100644 --- a/test/inference/train.jl +++ b/test/inference/train.jl @@ -85,7 +85,8 @@ value = get_parameter_value((student, name)) # evaluate total log density at value + dx - set_param!(student, name, value + dx) + init_parameter!((student, name), value + dx) + lpdf_pos = 0. for i=1:minibatch_size (incr, _) = assess(student, inputs[i], constraints[i]) @@ -93,7 +94,7 @@ end # evaluate total log density at value - dx - set_param!(student, name, value - dx) + init_parameter!((student, name), value - dx) lpdf_neg = 0. 
for i=1:minibatch_size (incr, _) = assess(student, inputs[i], constraints[i]) @@ -103,11 +104,11 @@ expected = (lpdf_pos - lpdf_neg) / (2 * dx) @test isapprox(actual, expected, atol=1e-4) - set_param!(student, name, value) + init_parameter!((student, name), value) end # use stochastic gradient descent - optimizer = CompositeOptimizer(GradientDescent(0.01, 1000000), student) + optimizer = CompositeOptimizer(DecayStepGradientDescent(0.01, 1000000), student) train!(student, data_generator, optimizer, num_epoch=2000, epoch_size=50, num_minibatch=1, minibatch_size=50, verbose=false) diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index 3c1f820fe..3dd197d52 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -6,7 +6,7 @@ return z end - set_param!(foo, :std, 1.) + init_parameter!((foo, :std), 1.0) bar = Map(foo) xs = [1.0, 2.0, 3.0, 4.0] diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index 3d8c1b54a..7e6da6907 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -491,7 +491,7 @@ foo = Unfold(kernel) std = 1. - set_param!(kernel, :std, std) + init_parameter!((kernel, :std), std) x_init = 0.1 alpha = 0.2 diff --git a/test/runtests.jl b/test/runtests.jl index 6cfd343b5..15a62fb6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,11 +76,11 @@ end const dx = 1e-6 -#include("autodiff.jl") -#include("diff.jl") -#include("selection.jl") -#include("assignment.jl") -#include("gen_fn_interface.jl") +include("autodiff.jl") +include("diff.jl") +include("selection.jl") +include("assignment.jl") +include("gen_fn_interface.jl") include("dsl/dsl.jl") #include("optional_args.jl") include("static_ir/static_ir.jl") diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 63eadcd2b..61a14a58c 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -276,117 +276,8 @@ end end -@testset "backprop" begin - - #@gen (static) function bar(mu_z::Float64) - #z = @trace(normal(mu_z, 1), :z) - #return z + mu_z - #end - - # bar - builder = StaticIRBuilder() - mu_z = add_argument_node!(builder, name=:mu_z, typ=:Float64, compute_grad=true) - one = add_constant_node!(builder, 1.) - z = add_addr_node!(builder, normal, inputs=[mu_z, one], addr=:z, name=:z) - retval = add_julia_node!(builder, (z, mu_z) -> z + mu_z, inputs=[z, mu_z], name=:retval) - set_return_node!(builder, retval) - ir = build_ir(builder) - bar = eval(generate_generative_function(ir, :bar, track_diffs=false, cache_julia_nodes=false)) - - #@gen (static) function foo(mu_a::Float64) - #param theta::Float64 - #a = @trace(normal(mu_a, 1), :a) - #b = @trace(normal(a, 1), :b) - #bar = @trace(bar(a), :bar) - #c = a * b * bar * theta - #out = @trace(normal(c, 1), :out) - #return out - #end - - # foo - builder = StaticIRBuilder() - mu_a = add_argument_node!(builder, name=:mu_a, typ=:Float64, compute_grad=true) - theta = add_trainable_param_node!(builder, :theta, typ=QuoteNode(Float64)) - one = add_constant_node!(builder, 1.) 
- a = add_addr_node!(builder, normal, inputs=[mu_a, one], addr=:a, name=:a) - b = add_addr_node!(builder, normal, inputs=[a, one], addr=:b, name=:b) - bar_val = add_addr_node!(builder, bar, inputs=[a], addr=:bar, name=:bar_val) - c = add_julia_node!(builder, (a, b, bar, theta) -> (a * b * bar * theta), - inputs=[a, b, bar_val, theta], name=:c) - retval = add_addr_node!(builder, normal, inputs=[c, one], addr=:out, name=:out) - set_return_node!(builder, retval) - ir = build_ir(builder) - foo = eval(generate_generative_function(ir, :foo, track_diffs=false, cache_julia_nodes=false)) - - Gen.load_generated_functions() - - function f(mu_a, theta, a, b, z, out) - lpdf = 0. - mu_z = a - lpdf += logpdf(normal, z, mu_z, 1) - lpdf += logpdf(normal, a, mu_a, 1) - lpdf += logpdf(normal, b, a, 1) - c = a * b * (z + mu_z) * theta - lpdf += logpdf(normal, out, c, 1) - return lpdf + 2 * out - end - - mu_a = 1. - theta = -0.5 - a = 2. - b = 3. - z = 4. - out = 5. - - # initialize the trainable parameter - init_parameter!((foo, :theta), theta) - - # get the initial trace - constraints = choicemap() - constraints[:a] = a - constraints[:b] = b - constraints[:out] = out - constraints[:bar => :z] = z - (trace, _) = generate(foo, (mu_a,), constraints) - - # compute gradients with choice_gradients - selection = select(:bar => :z, :a, :out) - selection = StaticSelection(selection) - retval_grad = 2. - ((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad) - - # check input gradient - @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) - - # check value trie - @test get_value(value_trie, :a) == a - @test get_value(value_trie, :out) == out - @test get_value(value_trie, :bar => :z) == z - @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 - - # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 - @test !has_value(gradient_trie, :b) # was not selected - @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) - @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) - @test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx)) - - # compute gradients with accumulate_param_gradients! - retval_grad = 2. 
- (mu_a_grad,) = accumulate_param_gradients!(trace, retval_grad) - - # check input gradient - @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) - - # check trainable parameter gradient - @test isapprox( - get_gradient((foo, :theta)), - finite_diff(f, (mu_a, theta, a, b, z, out), 2, dx)) +include("gradients.jl") -end # functions to test tracked diffs From 3fc5a0ec1fceb7e6f9bd9bf779faa69aa4274439 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Mon, 17 May 2021 19:48:39 -0400 Subject: [PATCH 09/24] tests passing --- src/dynamic/assess.jl | 4 +- src/dynamic/backprop.jl | 5 +- src/dynamic/dynamic.jl | 63 ++++++++++++++++++++----- src/dynamic/generate.jl | 19 ++++---- src/dynamic/propose.jl | 8 ++-- src/dynamic/regenerate.jl | 9 ++-- src/dynamic/simulate.jl | 11 +++-- src/dynamic/trace.jl | 17 ++++++- src/dynamic/update.jl | 10 ++-- src/gen_fn_interface.jl | 62 +++++++++++++++--------- src/modeling_library/call_at/call_at.jl | 11 +++-- src/modeling_library/custom_determ.jl | 15 ++++-- src/modeling_library/map/assess.jl | 16 ++++--- src/modeling_library/map/generate.jl | 18 ++++--- src/modeling_library/map/map.jl | 1 + src/modeling_library/map/propose.jl | 15 +++--- src/modeling_library/map/simulate.jl | 9 ++-- src/modeling_library/recurse/recurse.jl | 29 ++++++++---- src/modeling_library/switch/assess.jl | 33 ++++++++----- src/modeling_library/switch/generate.jl | 35 +++++++++----- src/modeling_library/switch/propose.jl | 29 +++++++----- src/modeling_library/switch/simulate.jl | 25 ++++++---- src/modeling_library/switch/switch.jl | 8 ++++ src/modeling_library/unfold/assess.jl | 16 ++++--- src/modeling_library/unfold/generate.jl | 18 ++++--- src/modeling_library/unfold/propose.jl | 13 +++-- src/modeling_library/unfold/simulate.jl | 12 +++-- src/modeling_library/unfold/unfold.jl | 3 ++ src/optimization.jl | 24 ++++------ src/static_ir/generate.jl | 4 +- src/static_ir/simulate.jl | 8 ++-- test/dsl/dynamic_dsl.jl | 7 ++- test/inference/train.jl | 5 +- test/inference/variational.jl | 45 +++++++++--------- test/modeling_library/map.jl | 5 +- test/modeling_library/switch.jl | 14 +++--- test/modeling_library/unfold.jl | 5 +- test/optional_args.jl | 1 + test/runtests.jl | 2 +- 39 files changed, 404 insertions(+), 230 deletions(-) diff --git a/src/dynamic/assess.jl b/src/dynamic/assess.jl index 81e49cce3..ffb7a9378 100644 --- a/src/dynamic/assess.jl +++ b/src/dynamic/assess.jl @@ -56,8 +56,8 @@ function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, end function assess( - gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap; - parameter_context=default_parameter_context) + gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) state = GFAssessState(gen_fn, choices, parameter_context) retval = exec(gen_fn, state, args) diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 361f29c14..77e22e1ae 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -58,8 +58,11 @@ mutable struct GFBackpropParamsState end end -function read_param(state::GFBackpropParamsState, name::Symbol) +function read_param!(state::GFBackpropParamsState, name::Symbol) parameter_id = (state.active_gen_fn, name) + if !(parameter_id in state.trace.registered_julia_parameters) + throw(ArgumentError("parameter $parameter_id was not registered using register_parameters!")) + end return state.tracked_params[parameter_id] end diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index b379826d0..3cf48963f 100644 --- 
a/src/dynamic/dynamic.jl
+++ b/src/dynamic/dynamic.jl
@@ -1,3 +1,5 @@
+export register_parameters!
+
 include("trace.jl")
 
 """
@@ -8,13 +10,14 @@ A generative function based on a shallowly embedding modeling language based on
 Constructed using the `@gen` keyword.
 Most methods in the generative function interface involve a end-to-end execution of the function.
 """
-struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
+mutable struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
     arg_types::Vector{Type}
     has_defaults::Bool
     arg_defaults::Vector{Union{Some{Any},Nothing}}
     julia_function::Function
     has_argument_grads::Vector{Bool}
     accepts_output_grad::Bool
+    parameters::Union{Vector,Function}
 end
 
 function DynamicDSLFunction(arg_types::Vector{Type},
@@ -26,29 +29,65 @@ function DynamicDSLFunction(arg_types::Vector{Type},
     return DynamicDSLFunction{T}(arg_types,
         has_defaults, arg_defaults,
         julia_function,
-        has_argument_grads, accepts_output_grad)
+        has_argument_grads, accepts_output_grad, [])
 end
 
-function Base.show(io::IO, gen_fn::DynamicDSLFunction)
-    return "Gen DML generative function: $(gen_fn.julia_function)"
+function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
+    if isa(gen_fn.parameters, Vector)
+        julia_store = get_julia_store(parameter_context)
+        parameter_stores_to_ids = Dict{Any,Vector}()
+        parameter_ids = Tuple{GenerativeFunction,Symbol}[]
+        for param in gen_fn.parameters
+            if isa(param, Tuple{GenerativeFunction,Symbol})
+                push!(parameter_ids, param)
+            elseif isa(param, Symbol)
+                push!(parameter_ids, (gen_fn, param))
+            else
+                throw(ArgumentError("Invalid parameter declaration for DML generative function $gen_fn: $param"))
+            end
+        end
+        parameter_stores_to_ids[julia_store] = parameter_ids
+        return parameter_stores_to_ids
+    elseif isa(gen_fn.parameters, Function)
+        return gen_fn.parameters(parameter_context)
+    end
 end
 
-function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
-    return "Gen DML generative function: $(gen_fn.julia_function)"
+"""
+    register_parameters!(gen_fn::DynamicDSLFunction, parameters)
+
+Register the trainable parameters that are used by a DML generative function.
+
+This includes all parameters used within any calls made by the generative function.
+
+There are two variants:
+
+`parameters` is either (i) a `Vector` whose elements are `Symbol`s (naming parameters of `gen_fn` itself) and/or `(other_gen_fn, name::Symbol)` tuples (naming parameters of generative functions that `gen_fn` calls), or (ii) a `Function` that maps a parameter context to a `Dict` from parameter stores to `Vector`s of parameter IDs.
+"""
+function register_parameters!(gen_fn::DynamicDSLFunction, parameters)
+    gen_fn.parameters = parameters
+    return nothing
 end
 
-function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
-    # TODO for this, we need to walk the code...
(and throw errors when the +function Base.show(io::IO, gen_fn::DynamicDSLFunction) + print(io, "Gen DML generative function: $(gen_fn.julia_function)") +end + +function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction) + print(io, "Gen DML generative function: $(gen_fn.julia_function)") end -function DynamicDSLTrace(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T<:DynamicDSLFunction} +function DynamicDSLTrace( + gen_fn::T, args, parameter_store::JuliaParameterStore, + parameter_context, registered_julia_parameters) where {T<:DynamicDSLFunction} # pad args with default values, if available if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults) defaults = gen_fn.arg_defaults[length(args)+1:end] defaults = map(x -> something(x), defaults) args = Tuple(vcat(collect(args), defaults)) end - return DynamicDSLTrace{T}(gen_fn, args, parameter_store) + return DynamicDSLTrace{T}( + gen_fn, args, parameter_store, parameter_context, registered_julia_parameters) end accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad @@ -118,10 +157,10 @@ end function dynamic_param_impl(expr::Expr) @assert expr.head == :genparam "Not a Gen param expression." name = expr.args[1] - Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param), state, QuoteNode(name))) + Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param!), state, QuoteNode(name))) end -function read_param(state, name::Symbol) +function read_param!(state, name::Symbol) parameter_id = get_parameter_id(state, name) store = get_parameter_store(state) return get_parameter_value(parameter_id, store) diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index beed86c20..6098a9012 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -8,7 +8,10 @@ mutable struct GFGenerateState function GFGenerateState(gen_fn, args, constraints, parameter_context) parameter_store = get_julia_store(parameter_context) - trace = DynamicDSLTrace(gen_fn, args, parameter_store) + registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}( + get_parameters(gen_fn, parameter_context)[parameter_store]) + trace = DynamicDSLTrace( + gen_fn, args, parameter_store, parameter_context, registered_julia_parameters) return new(trace, constraints, 0., AddressVisitor(), gen_fn, parameter_context) end end @@ -23,7 +26,6 @@ function set_active_gen_fn!(state::GFGenerateState, gen_fn::GenerativeFunction) state.active_gen_fn = gen_fn end - function traceat(state::GFGenerateState, dist::Distribution{T}, args, key) where {T} local retval::T @@ -53,7 +55,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, state.weight += score end - retval + return retval end function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, @@ -69,8 +71,7 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, # get subtrace (subtrace, weight) = generate( - gen_fn, args, constraints; - parameter_context=state.parameter_context) + gen_fn, args, constraints, state.parameter_context) # add to the trace add_call!(state.trace, key, subtrace) @@ -81,14 +82,14 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, # get return value retval = get_retval(subtrace) - retval + return retval end function generate( - gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap; - parameter_context=default_parameter_context) + gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap, + parameter_context::Dict) state = GFGenerateState(gen_fn, 
args, constraints, parameter_context) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) - (state.trace, state.weight) + return (state.trace, state.weight) end diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index 3cece0e3a..2642229f7 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -37,7 +37,7 @@ function traceat(state::GFProposeState, dist::Distribution{T}, # update weight state.weight += logpdf(dist, retval, args...) - retval + return retval end function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, @@ -56,12 +56,10 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, # update weight state.weight += weight - retval + return retval end -function propose( - gen_fn::DynamicDSLFunction, args::Tuple; - parameter_context=default_parameter_context) +function propose(gen_fn::DynamicDSLFunction, args::Tuple, parameter_context::Dict) state = GFProposeState(gen_fn, parameter_context) retval = exec(gen_fn, state, args) return (state.choices, state.weight, retval) diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 371240f44..2c35251e3 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -8,9 +8,8 @@ mutable struct GFRegenerateState function GFRegenerateState(gen_fn, args, prev_trace, selection) visitor = AddressVisitor() - trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace)) - return new(prev_trace, trace, selection, - 0., visitor, gen_fn) + trace = initialize_from(prev_trace, args) + return new(prev_trace, trace, selection, 0.0, visitor, gen_fn) end end @@ -24,7 +23,6 @@ function set_active_gen_fn!(state::GFRegenerateState, gen_fn::GenerativeFunction state.active_gen_fn = gen_fn end - function traceat(state::GFRegenerateState, dist::Distribution{T}, args, key) where {T} local prev_retval::T @@ -88,7 +86,8 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, (subtrace, weight, _) = regenerate( prev_subtrace, args, map((_) -> UnknownChange(), args), subselection) else - (subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap()) + (subtrace, weight) = generate( + gen_fn, args, EmptyChoiceMap(), state.trace.parameter_context) end # update weight diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 8a4337092..c366176da 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -7,7 +7,10 @@ mutable struct GFSimulateState function GFSimulateState( gen_fn::GenerativeFunction, args::Tuple, parameter_context) parameter_store = get_julia_store(parameter_context) - trace = DynamicDSLTrace(gen_fn, args, parameter_store) + registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}( + get_parameters(gen_fn, parameter_context)[parameter_store]) + trace = DynamicDSLTrace( + gen_fn, args, parameter_store, parameter_context, registered_julia_parameters) return new(trace, AddressVisitor(), gen_fn, parameter_context) end end @@ -49,7 +52,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # get subtrace - subtrace = simulate(gen_fn, args; parameter_context=state.parameter_context) + subtrace = simulate(gen_fn, args, state.parameter_context) # add to the trace add_call!(state.trace, key, subtrace) @@ -60,9 +63,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, retval end -function simulate( - gen_fn::DynamicDSLFunction, args::Tuple; - parameter_context=default_parameter_context) +function simulate(gen_fn::DynamicDSLFunction, 
args::Tuple, parameter_context::Dict) state = GFSimulateState(gen_fn, args, parameter_context) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 4673ae1e8..ea13135e6 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -38,14 +38,27 @@ mutable struct DynamicDSLTrace{T} <: Trace noise::Float64 args::Tuple parameter_store::JuliaParameterStore + parameter_context::Dict + registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}} # for runtime cross-check retval::Any - function DynamicDSLTrace{T}(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T} + function DynamicDSLTrace{T}( + gen_fn::T, args, parameter_store::JuliaParameterStore, parameter_context, + registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}}) where {T} trie = Trie{Any,ChoiceOrCallRecord}() # retval is not known yet - new(gen_fn, trie, true, 0, 0, args, parameter_store) + new( + gen_fn, trie, true, 0, 0, args, parameter_store, + parameter_context, registered_julia_parameters) end end +function initialize_from(other::DynamicDSLTrace, args) + gen_fn = get_gen_fn(other) + return DynamicDSLTrace( + gen_fn, args, other.parameter_store, other.parameter_context, + other.registered_julia_parameters) +end + get_parameter_store(trace::DynamicDSLTrace) = trace.parameter_store set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 7b49cd1c3..6085fe712 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -10,9 +10,9 @@ mutable struct GFUpdateState function GFUpdateState(gen_fn, args, prev_trace, constraints) visitor = AddressVisitor() discard = choicemap() - trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace)) - return new(prev_trace, trace, constraints, - 0., visitor, discard, gen_fn) + parameter_store = get_parameter_store(prev_trace) + trace = initialize_from(prev_trace, args) + return new(prev_trace, trace, constraints, 0.0, visitor, discard, gen_fn) end end @@ -26,7 +26,6 @@ function set_active_gen_fn!(state::GFUpdateState, gen_fn::GenerativeFunction) state.active_gen_fn = gen_fn end - function traceat(state::GFUpdateState, dist::Distribution{T}, args::Tuple, key) where {T} @@ -101,7 +100,8 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, (subtrace, weight, _, discard) = update(prev_subtrace, args, map((_) -> UnknownChange(), args), constraints) else - (subtrace, weight) = generate(gen_fn, args, constraints) + (subtrace, weight) = generate( + gen_fn, args, constraints, state.trace.parameter_context) end # update the weight diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 6037f2521..0728e4285 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -160,7 +160,7 @@ Return an iterable over the trainable parameters of the generative function. get_params(::GenerativeFunction) = () """ - trace = simulate(gen_fn, args; parameter_context=Dict()) + trace = simulate(gen_fn, args, parameter_context=Dict()) Execute the generative function and return the trace. @@ -171,22 +171,21 @@ If `gen_fn` has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the `args` tuple. The generated trace will have default values filled in. 
""" -function simulate(::GenerativeFunction, ::Tuple; parameter_context=Dict()) +function simulate(::GenerativeFunction, ::Tuple, parameter_context::Dict) error("Not implemented") end -""" - (trace::U, weight) = generate( - gen_fn::GenerativeFunction{T,U}, args::Tuple; parameter_context=Dict()) - -Return a trace of a generative function. +function simulate(gen_fn::GenerativeFunction, args::Tuple) + return simulate(gen_fn, args, Dict()) +end +""" (trace::U, weight) = generate( gen_fn::GenerativeFunction{T,U}, args::Tuple, - constraints::ChoiceMap; parameter_context=Dict()) + constraints=EmptyChoiceMap(), parameter_context=Dict()) Return a trace of a generative function that is consistent with the given -constraints on the random choices. +constraints on the random choices, if any. Given arguments \$x\$ (`args`) and assignment \$u\$ (`constraints`) (which is empty for the first form), sample \$t \\sim q(\\cdot; u, x)\$ and \$r \\sim q(\\cdot; x, t)\$, and return the trace \$(x, t, r)\$ (`trace`). @@ -209,14 +208,19 @@ Example with constraint that address `:z` takes value `true`. (trace, weight) = generate(foo, (2, 4), choicemap((:z, true)) ``` """ -function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap; parameter_context=Dict()) +function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap, parameter_context::Dict) error("Not implemented") end -function generate(gen_fn::GenerativeFunction, args::Tuple; parameter_context=Dict()) - generate(gen_fn, args, EmptyChoiceMap()) +function generate(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) + return generate(gen_fn, args, choices, Dict()) +end + +function generate(gen_fn::GenerativeFunction, args::Tuple) + return generate(gen_fn, args, EmptyChoiceMap(), Dict()) end + """ weight = project(trace::U, selection::Selection) @@ -234,8 +238,10 @@ function project(trace, selection::Selection) error("Not implemented") end + """ - (choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple) + (choices, weight, retval) = propose( + gen_fn::GenerativeFunction, args::Tuple, parameter_context=Dict()) Sample an assignment and compute the probability of proposing that assignment. @@ -246,14 +252,18 @@ t)\$, and return \$t\$ \\log \\frac{p(r, t; x)}{q(r; x, t)} ``` """ -function propose(gen_fn::GenerativeFunction, args::Tuple; parameter_context=Dict()) - trace = simulate(gen_fn, args) +function propose(gen_fn::GenerativeFunction, args::Tuple, parameter_context::Dict) + trace = simulate(gen_fn, args, parameter_context) weight = get_score(trace) - (get_choices(trace), weight, get_retval(trace)) + return (get_choices(trace), weight, get_retval(trace)) end +propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, Dict()) + """ - (weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) + (weight, retval) = assess( + gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap, + parameter_context=Dict()) Return the probability of proposing an assignment @@ -265,14 +275,20 @@ return the weight (`weight`): ``` It is an error if \$p(t; x) = 0\$. 
""" +function assess( + gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap, parameter_context::Dict) + (trace, weight) = generate(gen_fn, args, choices, parameter_context) + return (weight, get_retval(trace)) +end + function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) - (trace, weight) = generate(gen_fn, args, choices) - (weight, get_retval(trace)) + return assess(gen_fn, args, choices, Dict()) end + """ - (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple, - constraints::ChoiceMap) + (new_trace, weight, retdiff, discard) = update( + trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s). @@ -317,8 +333,8 @@ function update(trace, constraints::ChoiceMap) end """ - (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple, - selection::Selection) + (new_trace, weight, retdiff) = regenerate( + trace, args::Tuple, argdiffs::Tuple, selection::Selection) Update a trace by changing the arguments and/or randomly sampling new values for selected random choices using the internal proposal distribution family. diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 0c5b997bc..ac2393f47 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -64,6 +64,8 @@ function accepts_output_grad(gen_fn::CallAtCombinator) accepts_output_grad(gen_fn.kernel) end +get_parameters(gen_fn::CallAtCombinator, context) = get_parameters(gen_fn.kernel, context) + unpack_call_at_args(args) = (args[end], args[1:end-1]) @@ -89,13 +91,14 @@ function simulate(gen_fn::CallAtCombinator, args::Tuple) CallAtTrace(gen_fn, subtrace, key) end -function generate(gen_fn::CallAtCombinator{T,U,K}, args::Tuple, - choices::ChoiceMap) where {T,U,K} +function generate( + gen_fn::CallAtCombinator{T,U,K}, args::Tuple, + choices::ChoiceMap, parameter_context::Dict) where {T,U,K} (key, kernel_args) = unpack_call_at_args(args) submap = get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap, parameter_context) trace = CallAtTrace(gen_fn, subtrace, key) - (trace, weight) + return (trace, weight) end function project(trace::CallAtTrace, selection::Selection) diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 24d6d90f2..f5fadf641 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -93,12 +93,16 @@ function accumulate_param_gradients_determ!( gradient_with_state(gen_fn, state, args, retgrad) end -function simulate(gen_fn::CustomDetermGF{T,S}, args::Tuple) where {T,S} +function simulate( + gen_fn::CustomDetermGF{T,S}, args::Tuple, + parameter_context::Dict) where {T,S} retval, state = apply_with_state(gen_fn, args) CustomDetermGFTrace{T,S}(retval, state, args, gen_fn) end -function generate(gen_fn::CustomDetermGF{T,S}, args::Tuple, choices::ChoiceMap) where {T,S} +function generate( + gen_fn::CustomDetermGF{T,S}, args::Tuple, + choices::ChoiceMap, parameter_context::Dict) where {T,S} if !isempty(choices) error("Deterministic generative function makes no random choices") end @@ -107,7 +111,9 @@ function generate(gen_fn::CustomDetermGF{T,S}, args::Tuple, choices::ChoiceMap) trace, 0. 
end -function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) where {T,S} +function update( + trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, + choices::ChoiceMap) where {T,S} if !isempty(choices) error("Deterministic generative function makes no random choices") end @@ -129,7 +135,8 @@ function accumulate_param_gradients!(trace::CustomDetermGFTrace, retgrad, scale_ accumulate_param_gradients_determ!(trace.gen_fn, trace.state, trace.args, retgrad, scale_factor) end -export CustomDetermGF, CustomDetermGFTrace, apply_with_state, update_with_state, gradient_with_state, accumulate_param_gradients_determ! +export CustomDetermGF, CustomDetermGFTrace, apply_with_state, update_with_state +export gradient_with_state, accumulate_param_gradients_determ! #################### # CustomGradientGF # diff --git a/src/modeling_library/map/assess.jl b/src/modeling_library/map/assess.jl index 74538d9d4..1d4e62a16 100644 --- a/src/modeling_library/map/assess.jl +++ b/src/modeling_library/map/assess.jl @@ -3,20 +3,24 @@ mutable struct MapAssessState{T} retvals::Vector{T} end -function process_new!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, - key::Int, state::MapAssessState{T}) where {T,U} +function process_new!( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + key::Int, state::MapAssessState{T}, + parameter_context) where {T,U} kernel_args = get_args_for_key(args, key) submap = get_submap(choices, key) - (weight, retval) = assess(gen_fn.kernel, kernel_args, submap) + (weight, retval) = assess(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.retvals[key] = retval end -function assess(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function assess( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = length(args[1]) state = MapAssessState{T}(0., Vector{T}(undef,len)) for key=1:len - process_new!(gen_fn, args, choices, key, state) + process_new!(gen_fn, args, choices, key, state, parameter_context) end - (state.weight, PersistentVector{T}(state.retvals)) + return (state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/map/generate.jl b/src/modeling_library/map/generate.jl index 8a5c2e8da..cc886a4a6 100644 --- a/src/modeling_library/map/generate.jl +++ b/src/modeling_library/map/generate.jl @@ -7,31 +7,35 @@ mutable struct MapGenerateState{T,U} num_nonempty::Int end -function process!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, - key::Int, state::MapGenerateState{T,U}) where {T,U} +function process!( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + key::Int, state::MapGenerateState{T,U}, + parameter_context) where {T,U} local subtrace::U local retval::T kernel_args = get_args_for_key(args, key) submap = get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) state.score += get_score(subtrace) state.subtraces[key] = subtrace retval = get_retval(subtrace) - state.retval[key] = retval + return state.retval[key] = retval end -function generate(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function generate( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = length(args[1]) state = MapGenerateState{T,U}(0., 0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0) # TODO check for keys that aren't valid constraints for key=1:len - process!(gen_fn, args, choices, key, state) + process!(gen_fn, args, choices, key, state, parameter_context) end trace = VectorTrace{MapType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), args, state.score, state.noise, len, state.num_nonempty) - (trace, state.weight) + return (trace, state.weight) end diff --git a/src/modeling_library/map/map.jl b/src/modeling_library/map/map.jl index 1bb695ff7..6335dc7ec 100644 --- a/src/modeling_library/map/map.jl +++ b/src/modeling_library/map/map.jl @@ -26,6 +26,7 @@ export Map has_argument_grads(map_gf::Map) = has_argument_grads(map_gf.kernel) accepts_output_grad(map_gf::Map) = accepts_output_grad(map_gf.kernel) +get_parameters(map_gf::Map, parameter_context) = get_parameters(map_gf.kernel, parameter_context) function (gen_fn::Map)(args...) (_, _, retval) = propose(gen_fn, args) diff --git a/src/modeling_library/map/propose.jl b/src/modeling_library/map/propose.jl index c0ca14330..f6929a13a 100644 --- a/src/modeling_library/map/propose.jl +++ b/src/modeling_library/map/propose.jl @@ -4,22 +4,25 @@ mutable struct MapProposeState{T} retvals::Vector{T} end -function process_new!(gen_fn::Map{T,U}, args::Tuple, key::Int, - state::MapProposeState{T}) where {T,U} +function process_new!( + gen_fn::Map{T,U}, args::Tuple, key::Int, + state::MapProposeState{T}, + parameter_context) where {T,U} local subtrace::U kernel_args = get_args_for_key(args, key) - (submap, weight, retval) = propose(gen_fn.kernel, kernel_args) + (submap, weight, retval) = propose(gen_fn.kernel, kernel_args, parameter_context) set_submap!(state.choices, key, submap) state.weight += weight state.retvals[key] = retval end -function propose(gen_fn::Map{T,U}, args::Tuple) where {T,U} +function propose( + gen_fn::Map{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = length(args[1]) choices = choicemap() state = MapProposeState{T}(choices, 0., Vector{T}(undef,len)) for key=1:len - process_new!(gen_fn, args, key, state) + process_new!(gen_fn, args, key, state, parameter_context) end - (state.choices, state.weight, PersistentVector{T}(state.retvals)) + return (state.choices, state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/map/simulate.jl b/src/modeling_library/map/simulate.jl index 216e140c8..610ce0db1 100644 --- a/src/modeling_library/map/simulate.jl +++ b/src/modeling_library/map/simulate.jl @@ -7,11 +7,12 @@ mutable struct MapSimulateState{T,U} end function process!(gen_fn::Map{T,U}, args::Tuple, - key::Int, state::MapSimulateState{T,U}) where {T,U} + key::Int, state::MapSimulateState{T,U}, + parameter_context) where {T,U} local subtrace::U local retval::T kernel_args = get_args_for_key(args, key) - subtrace = simulate(gen_fn.kernel, kernel_args) + subtrace = simulate(gen_fn.kernel, kernel_args, parameter_context) state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) state.score += get_score(subtrace) @@ -20,11 +21,11 @@ function process!(gen_fn::Map{T,U}, args::Tuple, state.retval[key] = retval end -function simulate(gen_fn::Map{T,U}, args::Tuple) where {T,U} +function simulate(gen_fn::Map{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = length(args[1]) state = MapSimulateState{T,U}(0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0) for key=1:len - process!(gen_fn, args, key, state) + process!(gen_fn, args, key, state, parameter_context) end VectorTrace{MapType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 7b4d1d23d..15339fb7b 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -131,6 +131,13 @@ end # TODO accepts_output_grad(::Recurse) = false +function get_parameters(gen_fn::Recurse, parameter_context) + parameter_stores_to_ids = Dict{Any,Vector}() + merge!(parameter_stores_to_ids, get_parameters(gen_fn.production_kernel, parameter_context)) + merge!(parameter_stores_to_ids, get_parameters(gen_fn.aggregation_kernel, parameter_context)) + return parameter_stores_to_ids +end + function (gen_fn::Recurse)(args...) (_, _, retval) = propose(gen_fn, args) retval @@ -197,7 +204,9 @@ end # simulate # ############ -function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T,U,V,W,X,Y} +function simulate( + gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, + parameter_context::Dict) where {S,T,U,V,W,X,Y} (root_production_input::U, root_idx::Int) = args production_traces = PersistentHashMap{Int,S}() aggregation_traces = PersistentHashMap{Int,T}() @@ -213,7 +222,7 @@ function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T cur = first(prod_to_visit) delete!(prod_to_visit, cur) input = get_production_input(gen_fn, cur, production_traces, root_idx, root_production_input) - subtrace = simulate(gen_fn.production_kern, (input,)) + subtrace = simulate(gen_fn.production_kern, (input,), parameter_context) score += get_score(subtrace) production_traces = assoc(production_traces, cur, subtrace) children_inputs::Vector{U} = get_retval(subtrace).children @@ -232,7 +241,7 @@ function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T local subtrace::T local input::Tuple{V,Vector{W}} input = get_aggregation_input(gen_fn, cur, production_traces, aggregation_traces) - subtrace = simulate(gen_fn.aggregation_kern, input) + subtrace = simulate(gen_fn.aggregation_kern, input, parameter_context) score += get_score(subtrace) aggregation_traces = assoc(aggregation_traces, cur, subtrace) if !isempty(get_choices(subtrace)) @@ -249,8 +258,10 @@ end # generate # ############ -function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, - constraints::ChoiceMap) where {S,T,U,V,W,X,Y} +function generate( + gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, + constraints::ChoiceMap, + parameter_context::Dict) where {S,T,U,V,W,X,Y} (root_production_input::U, root_idx::Int) = args production_traces = PersistentHashMap{Int,S}() aggregation_traces = PersistentHashMap{Int,T}() @@ -268,7 +279,7 @@ function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, delete!(prod_to_visit, cur) input = get_production_input(gen_fn, cur, production_traces, root_idx, root_production_input) subconstraints = get_production_constraints(constraints, cur) - (subtrace, subweight) = generate(gen_fn.production_kern, 
(input,), subconstraints) + (subtrace, subweight) = generate(gen_fn.production_kern, (input,), subconstraints, parameter_context) score += get_score(subtrace) production_traces = assoc(production_traces, cur, subtrace) weight += subweight @@ -289,7 +300,7 @@ function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, local input::Tuple{V,Vector{W}} input = get_aggregation_input(gen_fn, cur, production_traces, aggregation_traces) subconstraints = get_aggregation_constraints(constraints, cur) - (subtrace, subweight) = generate(gen_fn.aggregation_kern, input, subconstraints) + (subtrace, subweight) = generate(gen_fn.aggregation_kern, input, subconstraints, parameter_context) score += get_score(subtrace) aggregation_traces = assoc(aggregation_traces, cur, subtrace) weight += subweight @@ -567,7 +578,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, else # the node does not exist already (and none of its children exist either) - (subtrace, ) = generate(gen_fn.production_kern, input, subconstraints) + (subtrace, ) = generate(gen_fn.production_kern, input, subconstraints, parameter_context) # update trace, weight, and score production_traces = assoc(production_traces, cur, subtrace) @@ -649,7 +660,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, # if the node does not exist (but its children do, since we created them already) else - (subtrace, _) = generate(gen_fn.aggregation_kern, input, subconstraints) + (subtrace, _) = generate(gen_fn.aggregation_kern, input, subconstraints, parameter_context) # update trace, weight, and score aggregation_traces = assoc(aggregation_traces, cur, subtrace) diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index 4371eb8a4..187b48d00 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -4,23 +4,32 @@ mutable struct SwitchAssessState{T} SwitchAssessState{T}(weight::Float64) where T = new{T}(weight) end -function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - choices::ChoiceMap, - state::SwitchAssessState{T}) where {C, N, K, T} - (weight, retval) = assess(getindex(gen_fn.branches, index), args, choices) +function process!( + gen_fn::Switch{C, N, K, T}, + index::Int, args::Tuple, + choices::ChoiceMap, + state::SwitchAssessState{T}, + parameter_context) where {C, N, K, T} + (weight, retval) = assess( + getindex(gen_fn.branches, index), args, choices, parameter_context) state.weight = weight state.retval = retval end -@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) +@inline function process!( + gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, + choices::ChoiceMap, state::SwitchAssessState{T}, + parameter_context) where {C, N, K, T} + return process!( + gen_fn, getindex(gen_fn.cases, index), args, choices, state, parameter_context) +end -function assess(gen_fn::Switch{C, N, K, T}, - args::Tuple, - choices::ChoiceMap) where {C, N, K, T} +function assess( + gen_fn::Switch{C, N, K, T}, args::Tuple, + choices::ChoiceMap, + parameter_context::Dict) where {C, N, K, T} index = args[1] state = SwitchAssessState{T}(0.0) - process!(gen_fn, index, args[2 : end], choices, state) - return state.weight, state.retval + process!(gen_fn, index, args[2 : end], choices, state, parameter_context) + return (state.weight, state.retval) end diff --git a/src/modeling_library/switch/generate.jl 
b/src/modeling_library/switch/generate.jl index bd03f632e..eeed2c5d4 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -8,27 +8,38 @@ mutable struct SwitchGenerateState{T} SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) end -function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - choices::ChoiceMap, - state::SwitchGenerateState{T}) where {C, N, K, T} +function process!( + gen_fn::Switch{C, N, K, T}, + index::Int, args::Tuple, + choices::ChoiceMap, + state::SwitchGenerateState{T}, + parameter_context) where {C, N, K, T} - (subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices) + (subtrace, weight) = generate( + getindex(gen_fn.branches, index), args, choices, parameter_context) state.index = index state.subtrace = subtrace state.weight += weight state.retval = get_retval(subtrace) end -@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) +@inline function process!( + gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, + choices::ChoiceMap, state::SwitchGenerateState{T}, + parameter_context) where {C, N, K, T} + return process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state, parameter_context) +end -function generate(gen_fn::Switch{C, N, K, T}, - args::Tuple, - choices::ChoiceMap) where {C, N, K, T} +function generate( + gen_fn::Switch{C, N, K, T}, + args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {C, N, K, T} index = args[1] state = SwitchGenerateState{T}(0.0, 0.0, 0.0) - process!(gen_fn, index, args[2 : end], choices, state) - return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight + process!(gen_fn, index, args[2 : end], choices, state, parameter_context) + trace = SwitchTrace{T}( + gen_fn, state.index, state.subtrace, state.retval, + args[2 : end], state.score, state.noise) + return (trace, state.weight) end diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index b4df1d97f..abd38a04e 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -5,25 +5,32 @@ mutable struct SwitchProposeState{T} SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) end -function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - state::SwitchProposeState{T}) where {C, N, K, T} - - (submap, weight, retval) = propose(getindex(gen_fn.branches, index), args) +function process!( + gen_fn::Switch{C, N, K, T}, + index::Int, args::Tuple, + state::SwitchProposeState{T}, + parameter_context) where {C, N, K, T} + (submap, weight, retval) = propose( + getindex(gen_fn.branches, index), args, parameter_context) state.choices = submap state.weight += weight state.retval = retval end -@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) +@inline function process!( + gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, + state::SwitchProposeState{T}, + parameter_context) where {C, N, K, T} + return process!(gen_fn, getindex(gen_fn.cases, index), args, state, parameter_context) +end -function propose(gen_fn::Switch{C, N, K, T}, - args::Tuple) where {C, N, K, T} +function 
propose(
+    gen_fn::Switch{C, N, K, T}, args::Tuple,
+    parameter_context::Dict) where {C, N, K, T}
     index = args[1]
     choices = choicemap()
     state = SwitchProposeState{T}(choices, 0.0)
-    process!(gen_fn, index, args[2:end], state)
-    return state.choices, state.weight, state.retval
+    process!(gen_fn, index, args[2:end], state, parameter_context)
+    return (state.choices, state.weight, state.retval)
 end
diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl
index fc4b3b02a..52683b78f 100644
--- a/src/modeling_library/switch/simulate.jl
+++ b/src/modeling_library/switch/simulate.jl
@@ -7,12 +7,13 @@ mutable struct SwitchSimulateState{T}
     SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise)
 end
 
-function process!(gen_fn::Switch{C, N, K, T},
-                  index::Int,
-                  args::Tuple,
-                  state::SwitchSimulateState{T}) where {C, N, K, T}
+function process!(
+    gen_fn::Switch{C, N, K, T},
+    index::Int, args::Tuple,
+    state::SwitchSimulateState{T},
+    parameter_context) where {C, N, K, T}
     local retval::T
-    subtrace = simulate(getindex(gen_fn.branches, index), args)
+    subtrace = simulate(getindex(gen_fn.branches, index), args, parameter_context)
     state.index = index
     state.noise += project(subtrace, EmptySelection())
     state.subtrace = subtrace
@@ -20,13 +21,19 @@ function process!(gen_fn::Switch{C, N, K, T},
     state.retval = get_retval(subtrace)
 end
 
-@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)
+@inline function process!(
+    gen_fn::Switch{C, N, K, T}, index::C,
+    args::Tuple, state::SwitchSimulateState{T},
+    parameter_context) where {C, N, K, T}
+    return process!(gen_fn, getindex(gen_fn.cases, index), args, state, parameter_context)
+end
 
-function simulate(gen_fn::Switch{C, N, K, T},
-                  args::Tuple) where {C, N, K, T}
+function simulate(
+    gen_fn::Switch{C, N, K, T},
+    args::Tuple, parameter_context::Dict) where {C, N, K, T}
     index = args[1]
     state = SwitchSimulateState{T}(0.0, 0.0)
-    process!(gen_fn, index, args[2 : end], state)
+    process!(gen_fn, index, args[2 : end], state, parameter_context)
     return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval,
         args[2 : end], state.score, state.noise)
 end
diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl
index 821143448..1dfc71fe2 100644
--- a/src/modeling_library/switch/switch.jl
+++ b/src/modeling_library/switch/switch.jl
@@ -19,6 +19,14 @@ has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_f
 end
 accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches)
 
+function get_parameters(gen_fn::Switch, parameter_context)
+    parameter_stores_to_ids = Dict{Any,Vector}()
+    for branch_gen_fn in gen_fn.branches
+        merge!(parameter_stores_to_ids, get_parameters(branch_gen_fn, parameter_context))
+    end
+    return parameter_stores_to_ids
+end
+
 function (gen_fn::Switch)(index::Int, args...)
(_, _, retval) = propose(gen_fn, (index, args...)) retval diff --git a/src/modeling_library/unfold/assess.jl b/src/modeling_library/unfold/assess.jl index 4199f77da..727805af2 100644 --- a/src/modeling_library/unfold/assess.jl +++ b/src/modeling_library/unfold/assess.jl @@ -4,24 +4,28 @@ mutable struct UnfoldAssessState{T} state::T end -function process_new!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, - key::Int, state::UnfoldAssessState{T}) where {T,U} +function process_new!( + gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, + key::Int, state::UnfoldAssessState{T}, + parameter_context) where {T,U} local new_state::T kernel_args = (key, state.state, params...) submap = get_submap(choices, key) - (weight, new_state) = assess(gen_fn.kernel, kernel_args, submap) + (weight, new_state) = assess(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.retvals[key] = new_state state.state = new_state end -function assess(gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function assess( + gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] state = UnfoldAssessState{T}(0., Vector{T}(undef,len), init_state) for key=1:len - process_new!(gen_fn, params, choices, key, state) + process_new!(gen_fn, params, choices, key, state, parameter_context) end - (state.weight, PersistentVector{T}(state.retvals)) + return (state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/unfold/generate.jl b/src/modeling_library/unfold/generate.jl index 3ef9a78b9..d20ab1844 100644 --- a/src/modeling_library/unfold/generate.jl +++ b/src/modeling_library/unfold/generate.jl @@ -8,13 +8,15 @@ mutable struct UnfoldGenerateState{T,U} state::T end -function process!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, - key::Int, state::UnfoldGenerateState{T,U}) where {T,U} +function process!( + gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, + key::Int, state::UnfoldGenerateState{T,U}, + parameter_context) where {T,U} local subtrace::U local new_state::T kernel_args = (key, state.state, params...) submap = get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) @@ -22,20 +24,22 @@ function process!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, state.subtraces[key] = subtrace new_state = get_retval(subtrace) state.state = new_state - state.retval[key] = new_state + return state.retval[key] = new_state end -function generate(gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function generate( + gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] state = UnfoldGenerateState{T,U}(0., 0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0, init_state) for key=1:len - process!(gen_fn, params, choices, key, state) + process!(gen_fn, params, choices, key, state, parameter_context) end trace = VectorTrace{UnfoldType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), args, state.score, state.noise, len, state.num_nonempty) - (trace, state.weight) + return (trace, state.weight) end diff --git a/src/modeling_library/unfold/propose.jl b/src/modeling_library/unfold/propose.jl index 8863fbd4a..411467afd 100644 --- a/src/modeling_library/unfold/propose.jl +++ b/src/modeling_library/unfold/propose.jl @@ -5,8 +5,10 @@ mutable struct UnfoldProposeState{T} state::T end -function process_new!(gen_fn::Unfold{T,U}, params::Tuple, key::Int, - state::UnfoldProposeState{T}) where {T,U} +function process_new!( + gen_fn::Unfold{T,U}, params::Tuple, key::Int, + state::UnfoldProposeState{T}, + parameter_context) where {T,U} local new_state::T kernel_args = (key, state.state, params...) (submap, weight, new_state) = propose(gen_fn.kernel, kernel_args) @@ -16,14 +18,15 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, key::Int, state.state = new_state end -function propose(gen_fn::Unfold{T,U}, args::Tuple) where {T,U} +function propose( + gen_fn::Unfold{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] choices = choicemap() state = UnfoldProposeState{T}(choices, 0., Vector{T}(undef,len), init_state) for key=1:len - process_new!(gen_fn, params, key, state) + process_new!(gen_fn, params, key, state, parameter_context) end - (state.choices, state.weight, PersistentVector{T}(state.retvals)) + return (state.choices, state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/unfold/simulate.jl b/src/modeling_library/unfold/simulate.jl index e161e64a2..6eea21da5 100644 --- a/src/modeling_library/unfold/simulate.jl +++ b/src/modeling_library/unfold/simulate.jl @@ -7,12 +7,14 @@ mutable struct UnfoldSimulateState{T,U} state::T end -function process!(gen_fn::Unfold{T,U}, params::Tuple, - key::Int, state::UnfoldSimulateState{T,U}) where {T,U} +function process!( + gen_fn::Unfold{T,U}, params::Tuple, + key::Int, state::UnfoldSimulateState{T,U}, + parameter_context) where {T,U} local subtrace::U local new_state::T kernel_args = (key, state.state, params...) - subtrace = simulate(gen_fn.kernel, kernel_args) + subtrace = simulate(gen_fn.kernel, kernel_args, parameter_context) state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) state.score += get_score(subtrace) @@ -22,14 +24,14 @@ function process!(gen_fn::Unfold{T,U}, params::Tuple, state.retval[key] = new_state end -function simulate(gen_fn::Unfold{T,U}, args::Tuple) where {T,U} +function simulate(gen_fn::Unfold{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] state = UnfoldSimulateState{T,U}(0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0, init_state) for key=1:len - process!(gen_fn, params, key, state) + process!(gen_fn, params, key, state, parameter_context) end VectorTrace{UnfoldType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), diff --git a/src/modeling_library/unfold/unfold.jl b/src/modeling_library/unfold/unfold.jl index 44238e3b7..45f240cd5 100644 --- a/src/modeling_library/unfold/unfold.jl +++ b/src/modeling_library/unfold/unfold.jl @@ -42,6 +42,9 @@ end # TODO accepts_output_grad(gen_fn::Unfold) = false +get_parameters(gen_fn::Unfold, parameter_context) = get_parameters(gen_fn.kernel, parameter_context) + + function (gen_fn::Unfold)(args...) (_, _, retval) = propose(gen_fn, args) retval diff --git a/src/optimization.jl b/src/optimization.jl index b0de2f535..107c7e201 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -1,18 +1,12 @@ import Parameters -# TODO notes -# # we should modify the semantics of the log probability contribution to the gradient # so that everything is gradient descent instead of ascent. this will also fix # the misnomer names # -# combinators (map etc.) and call_at! and choice_at! all need to implement get_parameters.. # TODO add tests specifically for JuliaParameterStore etc. # -# TODO GF untraced needs to reference a parameter store -# -# make changes to src/dynamic/backprop.jl -# make changes to other dynamic methods +# TODO in all update and regenerate implementations, need to pass in the parameter context to inner calls to generate export in_place_add! @@ -36,15 +30,15 @@ export get_gradient function in_place_add! 
end function in_place_add!(value::Array, increment) - @simd for i in 1:length(param) + @simd for i in 1:length(value) value[i] += increment[i] end return value end # this exists so user can use the same function on scalars and arrays -function in_place_add!(param::Real, increment::Real) - return param + increment +function in_place_add!(value::Real, increment::Real) + return value + increment end ############################ @@ -97,10 +91,10 @@ function fill_with_zeros!(accum::Accumulator{T}) where {T <: Real} return accum end -function fill_with_zeros!(accum::Accumulator{Array{T}}) where {T} +function fill_with_zeros!(accum::Accumulator{<:Array{T}}) where {T} lock(accum.lock) try - fill!(zero(T), accum.arr) + fill!(accum.value, zero(T)) finally unlock(accum.lock) end @@ -212,10 +206,10 @@ struct CompositeOptimizer optimizers::Dict{Any,Any} function CompositeOptimizer(conf, parameter_stores_to_ids::Dict{Any,Vector}) optimizers = Dict{Any,Any}() - for (store, parameter_ids) in parameters + for (store, parameter_ids) in parameter_stores_to_ids optimizers[store] = init_optimizer(conf, parameter_ids, store) end - new(states, conf) + new(conf, optimizers) end end @@ -276,7 +270,7 @@ const default_julia_parameter_store = JuliaParameterStore() # once a trace is generated, it is bound to use a particular store const JULIA_PARAMETER_STORE_KEY = :julia_parameter_store -function get_julia_store(context::Dict{Symbol,Any}) +function get_julia_store(context::Dict) if haskey(context, JULIA_PARAMETER_STORE_KEY) return context[JULIA_PARAMETER_STORE_KEY] else diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index 2d4ecbd91..808a7e722 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -114,8 +114,8 @@ end push!(generated_functions, quote @generated function $(GlobalRef(Gen, :generate))( gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), - args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap)); - parameter_context=$(QuoteNode(default_parameter_context))) + args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap)), + parameter_context::Dict) $(QuoteNode(codegen_generate))(gen_fn, args, constraints) end end) diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index 5f6106386..4b356653b 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -49,7 +49,9 @@ function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode push!(state.stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end -function codegen_simulate(gen_fn_type::Type{T}, args, parameter_context_type) where {T <: StaticIRGenerativeFunction} +function codegen_simulate( + gen_fn_type::Type{T}, args, + parameter_context_type) where {T <: StaticIRGenerativeFunction} ir = get_ir(gen_fn_type) options = get_options(gen_fn_type) @@ -87,8 +89,8 @@ end push!(generated_functions, quote @generated function $(GlobalRef(Gen, :simulate))( - gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple; - parameter_context=$(QuoteNode(default_parameter_context))) + gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), + args::Tuple, parameter_context::Dict) $(QuoteNode(codegen_simulate))(gen_fn, args, parameter_context) end end) diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 0b7be3332..0e26f8148 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -285,6 +285,7 @@ end z = @trace(normal(mu_z + theta1, 1), :z) return z + mu_z end + register_parameters!(bar, [:theta1]) 
@gen (grad) function foo((grad)(mu_a::Float64)) @param theta2::Float64 @@ -294,6 +295,7 @@ end c = a * b * @trace(bar(a), :bar) return @trace(normal(c, 1), :out) + (theta2 * 3) end + register_parameters!(foo, [(bar, :theta1), :theta2]) init_parameter!((bar, :theta1), 0.0) init_parameter!((foo, :theta2), 0.0) @@ -371,12 +373,13 @@ end @param theta::Float64 return theta end - + register_parameters!(baz, [:theta]) init_parameter!((baz, :theta), 0.0) @gen (grad) function foo() return @trace(baz()) end + register_parameters!(foo, [(baz, :theta)]) (trace, _) = generate(foo, ()) retval_grad = 2. @@ -389,6 +392,7 @@ end @param theta::Float64 return theta end + register_parameters!(foo, [:theta]) init_parameter!((foo, :theta), 0.0) @@ -407,6 +411,7 @@ end @param theta::Float64 return theta end + register_parameters!(foo, [:theta]) init_parameter!((foo, :theta), 0.0) diff --git a/test/inference/train.jl b/test/inference/train.jl index f62641afa..277cb2817 100644 --- a/test/inference/train.jl +++ b/test/inference/train.jl @@ -48,6 +48,7 @@ end @trace(bernoulli(prob_y), :y) end + register_parameters!(student, [:theta1, :theta2, :theta3, :theta4, :theta5]) function data_generator() (choices, _, retval) = propose(teacher, ()) @@ -142,6 +143,7 @@ end z = @trace(normal(x + theta, exp(log_std)), :z) return z end + register_parameters!(q, [:theta, :log_std]) # train simple q using lecture! to compute gradients init_parameter!((q, :theta), 0.0) @@ -164,9 +166,10 @@ end @trace(normal(means[i], exp(log_std)), i => :z) end end + register_parameters!(q_batched, [:theta, :log_std]) # train simple q using lecture_batched! to compute gradients - init_parameter!(q_batched(, :theta, 0).0) + init_parameter!((q_batched, :theta), 0.0) init_parameter!((q_batched, :log_std), 0.0) optimizer = CompositeOptimizer(FixedStepGradientDescent(0.001), q_batched) score = Inf diff --git a/test/inference/variational.jl b/test/inference/variational.jl index 1084b4331..f778ac6f2 100644 --- a/test/inference/variational.jl +++ b/test/inference/variational.jl @@ -15,16 +15,17 @@ @trace(normal(slope_mu, exp(slope_log_std)), :slope) @trace(normal(intercept_mu, exp(intercept_log_std)), :intercept) end + register_parameters!(approx, [:slope_mu, :slope_log_std, :intercept_mu, :intercept_log_std]) # to regular black box variational inference - init_param!((approx, :slope_mu), 0.0) - init_param!((approx, :slope_log_std), 0.0) - init_param!((approx, :intercept_mu), 0.0) - init_param!((approx, :intercept_log_std), 0.0) + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) observations = choicemap() - optimizer = CompositeOptimizer(GradientDescent(1, 100000), approx) - optimizer = CompositeOptimizer(GradientDescent(1., 1000), approx) + optimizer = CompositeOptimizer(DecayStepGradientDescent(1, 100000), approx) + optimizer = CompositeOptimizer(DecayStepGradientDescent(1., 1000), approx) black_box_vi!(model, (), observations, approx, (), optimizer; iters=2000, samples_per_iter=100, verbose=false) slope_mu = get_parameter_value((approx, :slope_mu)) @@ -37,17 +38,17 @@ @test isapprox(intercept_log_std, 2.0, atol=0.001) # smoke test for black box variational inference with Monte Carlo objectives - init_param!((approx, :slope_mu), 0.0) - init_param!((approx, :slope_log_std), 0.0) - init_param!((approx, :intercept_mu), 0.0) - init_param!((approx, :intercept_log_std), 0.0) + init_parameter!((approx, :slope_mu), 0.0) + 
init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) black_box_vimco!(model, (), observations, approx, (), optimizer, 20; iters=50, samples_per_iter=100, verbose=false, geometric=false) - init_param!((approx, :slope_mu), 0.0) - init_param!((approx, :slope_log_std), 0.0) - init_param!((approx, :intercept_mu), 0.0) - init_param!((approx, :intercept_log_std), 0.0) + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) black_box_vimco!(model, (), observations, approx, (), optimizer, 20; iters=50, samples_per_iter=100, verbose=false, geometric=true) @@ -66,6 +67,7 @@ end {(:x, i)} ~ normal(z, 1) end end + register_parameters!(model, [:theta]) @gen function approx(xs) @param mu_coeffs::Vector{Float64} # 2 x 1; should be [opt_theta / 2, 0.5] @@ -75,6 +77,7 @@ end {(:z, i)} ~ normal(mu, exp(log_std)) end end + register_parameters!(approx, [:mu_coeffs, :log_std]) observations = choicemap() xs = Float64[] @@ -97,7 +100,7 @@ end {(:z, i)} ~ normal(posterior_means[i], sqrt(1.0 / posterior_precisions)) end end - init_param!((model, :theta), opt_theta) + init_parameter!((model, :theta), opt_theta) approx_trace = simulate(optimum_approx, ()) (model_trace, _) = generate(model, (), merge(get_choices(approx_trace), observations)) # note that p(z1..zn, x1..xn) / p(z1..zn | x1..xn) = p(x1...xn) - for all z1..zn @@ -105,9 +108,9 @@ end println("true optimum log_marginal_likelihood: $log_marginal_likelihood") # using BBVI with score function estimator - init_param!((model, :theta), 0.0) - init_param!((approx, :mu_coeffs), zeros(2)) - init_param!((approx, :log_std), 0.0) + init_parameter!((model, :theta), 0.0) + init_parameter!((approx, :mu_coeffs), zeros(2)) + init_parameter!((approx, :log_std), 0.0) approx_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.0001), approx) model_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.002), model) @time (_, _, elbo_history, _) = @@ -120,9 +123,9 @@ end @test isapprox(elbo_history[end], log_marginal_likelihood, rtol=0.1) # using VIMCO - init_param!((model, :theta), 0.0) - init_param!((approx, :mu_coeffs), zeros(2)) - init_param!((approx, :log_std), 0.0) + init_parameter!((model, :theta), 0.0) + init_parameter!((approx, :mu_coeffs), zeros(2)) + init_parameter!((approx, :log_std), 0.0) approx_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.001), approx) model_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.01), model) @time (_, _, elbo_history, _) = diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index 3dd197d52..6f4476b31 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -5,6 +5,7 @@ z = @trace(normal(x + y, std), :z) return z end + register_parameters!(foo, [:std]) init_parameter!((foo, :std), 1.0) @@ -393,13 +394,13 @@ # get gradients wrt xs and ys trace = get_initial_trace() - zero_param_grad!(foo, :std) + reset_gradient!((foo, :std)) input_grads = accumulate_param_gradients!(trace, retval_grad) @test isapprox(input_grads[1], expected_xs_grad) @test isapprox(input_grads[2], expected_ys_grad) expected_std_grad = (logpdf_grad(normal, z1, 4., 1.)[3] + logpdf_grad(normal, z2, 6., 1.)[3]) - @test isapprox(get_param_grad(foo, :std), expected_std_grad) + @test isapprox(get_gradient((foo, :std)), expected_std_grad) end end diff --git 
a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 8c183aa16..520c9879a 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -112,13 +112,15 @@ z = @trace(normal(x + y, std), :z) return z end - init_param!(bang1, :std, 3.0) + register_parameters!(bang1, [:std]) + init_parameter!((bang1, :std), 3.0) @gen (grad) function fuzz1((grad)(x::Float64), (grad)(y::Float64)) @param(std::Float64) z = @trace(normal(x + 2 * y, std), :z) return z end - init_param!(fuzz1, :std, 3.0) + register_parameters!(fuzz1, [:std]) + init_parameter!((fuzz1, :std), 3.0) sc = Switch(bang1, fuzz1) @gen (grad) function bam(s::Int) x ~ sc(s, 5.0, 3.0) @@ -195,15 +197,15 @@ for z in [1.0, 3.0, 5.0, 10.0] chm = choicemap((:z, z)) tr, _ = generate(bam, (1, ), chm) - zero_param_grad!(bang1, :std) + reset_gradient!((bang1, :std)) input_grads = accumulate_param_gradients!(tr, 1.0) expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 3.0, 3.0)[3] - @test isapprox(get_param_grad(bang1, :std), expected_std_grad) + @test isapprox(get_gradient((bang1, :std)), expected_std_grad) tr, _ = generate(bam, (2, ), chm) - zero_param_grad!(fuzz1, :std) + reset_gradient!((fuzz1, :std)) input_grads = accumulate_param_gradients!(tr, 1.0) expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)[3] - @test isapprox(get_param_grad(fuzz1, :std), expected_std_grad) + @test isapprox(get_gradient((fuzz1, :std)), expected_std_grad) end end diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index 7e6da6907..2b4aaec7a 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -487,6 +487,7 @@ x = @trace(normal(x_prev * alpha + beta, std), :x) return x end + register_parameters!(kernel, [:std]) foo = Unfold(kernel) @@ -504,7 +505,7 @@ constraints[2 => :x] = x2 (trace, _) = generate(foo, (2, x_init, alpha, beta), constraints) - zero_param_grad!(kernel, :std) + reset_gradient!((kernel, :std)) input_grads = accumulate_param_gradients!(trace, nothing) @test input_grads[1] == nothing # length @test input_grads[2] == nothing # inital state @@ -512,7 +513,7 @@ #@test isapprox(input_grads[4], expected_ys_grad) # beta expected_std_grad = (logpdf_grad(normal, x1, x_init * alpha + beta, std)[3] + logpdf_grad(normal, x2, x1 * alpha + beta, std)[3]) - @test isapprox(get_param_grad(kernel, :std), expected_std_grad) + @test isapprox(get_gradient((kernel, :std)), expected_std_grad) end @gen (grad) function ker(t, (grad)(x::Float64)) diff --git a/test/optional_args.jl b/test/optional_args.jl index f4883d14d..9595c5706 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -8,6 +8,7 @@ using Gen b = @trace(normal(a+theta, 1), :b) return (x, y, z, x+y+z) end + register_parameters!(foo, [:theta]) # initialize theta to zero for non-gradient tests init_parameter!((foo, :theta), 0.0) diff --git a/test/runtests.jl b/test/runtests.jl index 15a62fb6b..6fddf20b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,7 +82,7 @@ include("selection.jl") include("assignment.jl") include("gen_fn_interface.jl") include("dsl/dsl.jl") -#include("optional_args.jl") +include("optional_args.jl") include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") From c045614427125cab794328b63267384457ebf5d8 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Tue, 18 May 2021 12:09:58 -0400 Subject: [PATCH 10/24] finished pass through documentation --- docs/src/ref/learning.md | 26 ++-- docs/src/ref/modeling.md | 
24 ++-- docs/src/ref/parameter_optimization.md | 85 ++++++++++--- src/builtin_optimization.jl | 0 src/dynamic/dynamic.jl | 11 +- src/gen_fn_interface.jl | 25 +--- src/inference/train.jl | 11 +- src/inference/variational.jl | 14 +- src/optimization.jl | 169 +++++++++++++++---------- test/inference/train.jl | 6 +- test/inference/variational.jl | 12 +- test/runtests.jl | 2 +- 12 files changed, 228 insertions(+), 157 deletions(-) delete mode 100644 src/builtin_optimization.jl diff --git a/docs/src/ref/learning.md b/docs/src/ref/learning.md index 464d6d284..efdc09c8a 100644 --- a/docs/src/ref/learning.md +++ b/docs/src/ref/learning.md @@ -51,10 +51,10 @@ end ``` Let's suppose we are training the generative model. -The first step is to initialize the values of the trainable parameters, which for generative functions constructed using the built-in modeling languages, we do with [`init_param!`](@ref): +The first step is to initialize the values of the trainable parameters, which for generative functions constructed using the built-in modeling languages, we do with [`init_parameter!`](@ref): ```julia -init_param!(model, :a, 0.) -init_param!(model, :b, 0.) +init_parameter!((model, :a), 0.0) +init_parameter!((model, :b), 0.0) ``` Each trace in the collection contains the observed data from an independent draw from our model. We can populate each trace with its observed data using [`generate`](@ref): @@ -76,24 +76,24 @@ for trace in traces accumulate_param_gradients!(trace) end ``` -Finally, we can construct and gradient-based update with [`ParamUpdate`](@ref) and apply it with [`apply!`](@ref). +Finally, we can construct and gradient-based update with [`init_optimizer`](@ref) and apply it with [`apply_update!`](@ref). We can put this all together into a function: ```julia function train_model(data::Vector{ChoiceMap}) - init_param!(model, :theta, 0.1) + init_parameter!((model, :theta), 0.1) traces = [] for observations in data trace, = generate(model, model_args, observations) push!(traces, trace) end - update = ParamUpdate(FixedStepSizeGradientDescent(0.001), model) + optimizer = init_optimizer(FixedStepGradientDescent(0.001), model) for iter=1:max_iter objective = sum([get_score(trace) for trace in traces]) println("objective: $objective") for trace in traces accumulate_param_gradients!(trace) end - apply!(update) + apply_update!(optimizer) end end ``` @@ -139,14 +139,14 @@ There are many variants possible, based on which Monte Carlo inference algorithm For example: ```julia function train_model(data::Vector{ChoiceMap}) - init_param!(model, :theta, 0.1) - update = ParamUpdate(FixedStepSizeGradientDescent(0.001), model) + init_parameter!((model, :theta), 0.1) + optimizer = init_optimizer(FixedStepGradientDescent(0.001), model) for iter=1:max_iter traces = do_monte_carlo_inference(data) for trace in traces accumulate_param_gradients!(trace) end - apply!(update) + apply_update!(optimizer) end end @@ -160,14 +160,14 @@ end Note that it is also possible to use a weighted collection of traces directly without resampling: ```julia function train_model(data::Vector{ChoiceMap}) - init_param!(model, :theta, 0.1) - update = ParamUpdate(FixedStepSizeGradientDescent(0.001), model) + init_parameter!((model, :theta), 0.1) + optimizer = init_optimizer(FixedStepGradientDescent(0.001), model) for iter=1:max_iter traces, weights = do_monte_carlo_inference_with_weights(data) for (trace, weight) in zip(traces, weights) accumulate_param_gradients!(trace, nothing, weight) end - apply!(update) + apply_update!(optimizer) 
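+        # note: apply_update! consumes the gradients accumulated above and typically
+        # resets the gradient accumulators to zero before the next iteration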
end end ``` diff --git a/docs/src/ref/modeling.md b/docs/src/ref/modeling.md index 15be1990e..d55015284 100644 --- a/docs/src/ref/modeling.md +++ b/docs/src/ref/modeling.md @@ -254,6 +254,7 @@ See [Generative Function Interface](@ref) for more information about traces. A `@gen` function may begin with an optional block of *trainable parameter declarations*. The block consists of a sequence of statements, beginning with `@param`, that declare the name and Julia type for each trainable parameter. +The Julia type must be either a subtype of `Real` or subtype of `Array{<:Real}`. The function below has a single trainable parameter `theta` with type `Float64`: ```julia @gen function foo(prob::Float64) @@ -264,23 +265,22 @@ The function below has a single trainable parameter `theta` with type `Float64`: end ``` Trainable parameters obey the same scoping rules as Julia local variables defined at the beginning of the function body. -The value of a trainable parameter is undefined until it is initialized using [`init_param!`](@ref). +After the definition of the generative function, you must register all of the parameters used by the generative function using [`register_parameters!`](@ref) (this is not required if you instead use the [Static Modeling Language](@ref)): +```julia +register_parameters!(foo, [:theta]) +``` +The value of a trainable parameter is undefined until it is initialized using [`init_parameter!`](@ref): +```julia +init_parameter!((foo, :theta), 0.0) +``` In addition to the current value, each trainable parameter has a current **gradient accumulator** value. The gradient accumulator value has the same shape (e.g. array dimension) as the parameter value. -It is initialized to all zeros, and is incremented by [`accumulate_param_gradients!`](@ref). - -The following methods are exported for the trainable parameters of `@gen` functions: +It is initialized to all zeros, and is incremented by calling [`accumulate_param_gradients!`](@ref) on a trace. +Additional functions for retrieving and manipulating the values of trainable parameters and their gradient accumulators are described in [Optimizing Trainable Parameters](@ref). ```@docs -init_param! -get_param -get_param_grad -set_param! -zero_param_grad! +register_parameters! ``` -Trainable parameters are designed to be trained using gradient-based methods. -This is discussed in the next section. - ## Differentiable programming Given a trace of a `@gen` function, Gen supports automatic differentiation of the log probability (density) of all of the random choices made in the trace with respect to the following types of inputs: diff --git a/docs/src/ref/parameter_optimization.md b/docs/src/ref/parameter_optimization.md index 60e05f41d..943fbbaa8 100644 --- a/docs/src/ref/parameter_optimization.md +++ b/docs/src/ref/parameter_optimization.md @@ -1,33 +1,82 @@ # Optimizing Trainable Parameters -Trainable parameters of generative functions are initialized differently depending on the type of generative function. -Trainable parameters of the built-in modeling language are initialized with [`init_param!`](@ref). +## Parameter stores -Gradient-based optimization of the trainable parameters of generative functions is based on interleaving two steps: +Multiple traces of a generative function typically reference the same trainable parameters of the generative function, which are stored outside of the trace in a **parameter store**. +Different types of generative functions may use different types of parameter stores. 
+For example, the [`JuliaParameterStore`](@ref) (discussed below) stores parameters as Julia values in the memory of the Julia runtime process. +Other types of parameter stores may store parameters in GPU memory, in a filesystem, or even remotely. -- Incrementing gradient accumulators for trainable parameters by calling [`accumulate_param_gradients!`](@ref) on one or more traces. +When generating a trace of a generative function with [`simulate`](@ref) or [`generate`](@ref), we may pass in an optional **parameter context**, which is a `Dict` that provides information about which parameter store(s) in which to look up the value of parameters. +A generative function obtains a reference to a specific type of parameter store by looking up its key in the parameter context. -- Updating the value of trainable parameters and resetting the gradient accumulators to zero, by calling [`apply!`](@ref) on a *parameter update*, as described below. +If you are just learning Gen, and are only using the built-in modeling language to write generative functions, you can ignore this complexity, because there is a [`default_julia_parameter_store`](@ref) and a default parameter context [`default_parameter_context`](@ref) that points to this default Julia parameter store that will be used if a parameter context is not provided in the call to `simulate` and `generate`. +```@docs +default_parameter_context +default_julia_parameter_store +``` + +## Julia parameter store + +Parameters declared using the `@param` keyword in the built-in modeling language are stored in a type of parameter store called a [`JuliaParameterStore`](@ref). +A generative function can obtain a reference to a `JuliaParameterStore` by looking up the key [`JULIA_PARAMETER_STORE_KEY`](@ref) in a parameter context. +This is how the built-in modeling language implementation finds the parameter stores to use for `@param`-declared parameters. +Note that if you are defining your own [custom generative functions](@ref #Custom-generative-functions), you can also use a [`JuliaParameterStore`](@ref) (including the same parameter store used to store parameters of built-in modeling language generative functions) to store and optimize your trainable parameters. -## Parameter update +Different types of parameter stores provide different APIs for reading, writing, and updating the values of parameters and gradient accumulators for parameters. +The `JuliaParameterStore` API is given below. +The API uses tuples of the form `(gen_fn::GenerativeFunction, name::Symbol)` to identify parameters. +(Note that most user learning code only needs to use [`init_parameter!`](@ref), as the other API functions are called by [Optimizers](@ref) which are discussed below.) -A *parameter update* reads from the gradient accumulators for certain trainable parameters, updates the values of those parameters, and resets the gradient accumulators to zero. -A paramter update is constructed by combining an *update configuration* with the set of trainable parameters to which the update should be applied: ```@docs -ParamUpdate +JuliaParameterStore +init_parameter! +increment_gradient! +reset_gradient! +get_parameter_value +get_gradient +JULIA_PARAMETER_STORE_KEY ``` -The set of possible update configurations is described in [Update configurations](@ref). -An update is applied with: + +### Multi-threaded gradient accumulation + +Note that the [`increment_gradient!`](@ref) call is thread-safe, so that multiple threads can concurrently increment the gradient for the same parameters. 
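For example, gradients for a batch of traces can be accumulated from several threads at once. A minimal sketch, assuming `traces` is a vector of traces whose trainable parameters live in a `JuliaParameterStore` and that Julia was started with more than one thread:
```julia
Threads.@threads for i in eachindex(traces)
    # each call increments the store's gradient accumulators, which are lock-protected
    accumulate_param_gradients!(traces[i])
end
```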
This is helpful for parallelizing gradient computation for a batch of traces within stochastic gradient descent learning algorithms. + +## Optimizers + +Gradient-based optimization typically involves iterating between two steps: +(i) computing gradients or estimates of gradients with respect to parameters, and +(ii) updating the value of the parameters based on the gradient estimates according to some mathematical rule. +Sometimes the optimization algorithm also has its own state that is separate from the value of the parameters and the gradient estimates. +Gradient-based optimization algorithms in Gen are implemented by **optimizers**. +Each type of parameter store provides implementations of optimizers for standard mathematical update rules. + +The mathematical rules are defined in **optimizer configuration** objects. +The currently supported optimizer configurations are: ```@docs -apply! +FixedStepGradientDescent +DecayStepGradientDescent +``` + +The most common way to construct an optimizer is via: +```julia +optimizer = init_optimizer(conf, gen_fn) ``` +which returns an optimizer that applies the mathematical rule defined by `conf` to all parameters used by `gen_fn` (even when the generative function uses parameters that are housed in multiple parameter stores). +You can also pass a parameter context keyword argument to customize the parameter store(s) that the optimizer should use. +Then, after accumulating gradients with [`accumulate_param_gradients!`](@ref), you can apply the update with: +```julia +apply_update!(optimizer) +``` + +The `init_optimizer` method described above constructs an optimizer that actually invokes multiple optimizers, one for each parameter store. +To add support to a parameter store type for a new optimizer configuration type, you must implement the per-parameter-store optimizer methods: -## Update configurations +- `init_optimizer(conf, parameter_ids, store)`, which takes in an optimizer configuration object, and list of parameter IDs, and the parameter store in which to apply the updates, and returns an optimizer thata mutates the given parameter store. + +- `apply_update!(optimizer)`, which takes in an a single argument (the optimizer) and applies its update rule, which mutates the value of the parameters in its parameter store (and typically also resets the values of the gradient accumulators to zero). -Gen has built-in support for the following types of update configurations. ```@docs -FixedStepGradientDescent -GradientDescent -ADAM +init_optimizer +apply_update! ``` -For adding new types of update configurations, see [Optimizing Trainable Parameters (Internal)](@ref optimizing-internal). diff --git a/src/builtin_optimization.jl b/src/builtin_optimization.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index 3cf48963f..2db4a12cc 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -56,13 +56,14 @@ end """ register_parameters!(gen_fn::DynamicDSLFunction, parameters) -Register the altrainable parameters that are used by a DML generative function. +Register the trainable parameters that used by a DML generative function. -This includes all parameters used within any calls made by the generative function. +This includes all parameters used within any calls made by the generative function, and includes any parameters that may be used by any possible trace (stochastic control flow may cause a parameter to be used by one trace but not another). 
-There are two variants: - -# TODO document the variants +The second argument is either a `Vector` or a `Function` that takes a parameter context and returns a `Dict` that maps parameter stores to `Vector`s of parameter IDs. +When the second argument is a `Vector`, each element is either a `Symbol` that is the name of a parameter declared in the body of `gen_fn` using `@param`, or is a tuple `(other_gen_fn::GenerativeFunction, name::Symbol)` where `@param ` was declared in the body of `other_gen_fn`. +The `Function` input is used when `gen_fn` uses parameters that come from more than one parameter store, including parameters that are housed in parameter stores that are not `JuliaParameterStore`s (e.g. if `gen_fn` invokes a generative function that executes in another non-Julia runtime). +See [Optimizing Trainable Parameters](@ref) for details on parameter contexts, and parameter stores. """ function register_parameters!(gen_fn::DynamicDSLFunction, parameters) gen_fn.parameters = parameters diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 0728e4285..f85b5b559 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -160,7 +160,7 @@ Return an iterable over the trainable parameters of the generative function. get_params(::GenerativeFunction) = () """ - trace = simulate(gen_fn, args, parameter_context=Dict()) + trace = simulate(gen_fn, args, parameter_context=default_parameter_context) Execute the generative function and return the trace. @@ -175,14 +175,10 @@ function simulate(::GenerativeFunction, ::Tuple, parameter_context::Dict) error("Not implemented") end -function simulate(gen_fn::GenerativeFunction, args::Tuple) - return simulate(gen_fn, args, Dict()) -end - """ (trace::U, weight) = generate( gen_fn::GenerativeFunction{T,U}, args::Tuple, - constraints=EmptyChoiceMap(), parameter_context=Dict()) + constraints=EmptyChoiceMap(), parameter_context=default_parameter_context) Return a trace of a generative function that is consistent with the given constraints on the random choices, if any. @@ -212,14 +208,6 @@ function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap, parameter_context: error("Not implemented") end -function generate(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) - return generate(gen_fn, args, choices, Dict()) -end - -function generate(gen_fn::GenerativeFunction, args::Tuple) - return generate(gen_fn, args, EmptyChoiceMap(), Dict()) -end - """ weight = project(trace::U, selection::Selection) @@ -241,7 +229,7 @@ end """ (choices, weight, retval) = propose( - gen_fn::GenerativeFunction, args::Tuple, parameter_context=Dict()) + gen_fn::GenerativeFunction, args::Tuple, parameter_context=default_parameter_context) Sample an assignment and compute the probability of proposing that assignment. 
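As a usage note for the new positional `parameter_context` argument: a caller can direct a generative function at a non-default parameter store by building its own context. A minimal sketch, assuming only the API introduced in this patch (`my_model` is an illustrative placeholder, not part of the patch):
```julia
@gen function my_model()
    @param theta::Float64
    x ~ normal(theta, 1.0)
    return x
end
register_parameters!(my_model, [:theta])

store = JuliaParameterStore()                     # isolated store, separate from the global default
init_parameter!((my_model, :theta), 0.0, store)
ctx = Dict{Symbol,Any}(JULIA_PARAMETER_STORE_KEY => store)

trace = simulate(my_model, (), ctx)               # this trace is bound to `store`
(trace2, _) = generate(my_model, (), choicemap((:x, 1.0)), ctx)
```
Calling `simulate(my_model, ())` without a context falls back to `default_parameter_context`, which points at the global Julia store.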
@@ -258,12 +246,11 @@ function propose(gen_fn::GenerativeFunction, args::Tuple, parameter_context::Dic return (get_choices(trace), weight, get_retval(trace)) end -propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, Dict()) """ (weight, retval) = assess( gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap, - parameter_context=Dict()) + parameter_context=default_parameter_context) Return the probability of proposing an assignment @@ -281,10 +268,6 @@ function assess( return (weight, get_retval(trace)) end -function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) - return assess(gen_fn, args, choices, Dict()) -end - """ (new_trace, weight, retdiff, discard) = update( diff --git a/src/inference/train.jl b/src/inference/train.jl index 52735bc89..8c2bf8828 100644 --- a/src/inference/train.jl +++ b/src/inference/train.jl @@ -1,7 +1,8 @@ """ train!(gen_fn::GenerativeFunction, data_generator::Function, - optimizer::CompositeOptimizer, - num_epoch, epoch_size, num_minibatch, minibatch_size; verbose::Bool=false) + optimizer; + num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1; + verbose::Bool=false) Train the given generative function to maximize the expected conditional log probability (density) that `gen_fn` generates the assignment `constraints` @@ -22,7 +23,7 @@ taken under the marginal distribution on `inputs` determined by the data generator. """ function train!(gen_fn::GenerativeFunction, data_generator::Function, - optimizer::CompositeOptimizer; + optimizer; num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1, evaluation_size=epoch_size, verbose=false, callback=(epoch, minibatch, minibatch_objective) -> nothing) @@ -101,7 +102,7 @@ function lecture!( q_args = get_q_args(p_trace) q_trace, score = generate(q, q_args, get_choices(p_trace)) # NOTE: q won't make all the random choices that p does accumulate_param_gradients!(q_trace) - score + return score end """ @@ -128,7 +129,7 @@ function lecture_batched!( q_args = get_q_args(p_traces) q_trace, score = generate(q_batched, q_args, constraints) # NOTE: q won't make all the random choices that p does accumulate_param_gradients!(q_trace) - score / batch_size + return score / batch_size end export train! diff --git a/src/inference/variational.jl b/src/inference/variational.jl index 610140ed8..0be8f7692 100644 --- a/src/inference/variational.jl +++ b/src/inference/variational.jl @@ -88,7 +88,7 @@ function multi_sample_gradient_estimate!( (L, traces, weights_normalized) end -function _maybe_accumulate_param_grad!(trace, optimizer::CompositeOptimizer, scale_factor::Real) +function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real) return accumulate_param_gradients!(trace, nothing, scale_factor) end @@ -98,10 +98,10 @@ end """ (elbo_estimate, traces, elbo_history) = black_box_vi!( model::GenerativeFunction, model_args::Tuple, - [model_optimizer::CompositeOptimizer,] + [model_optimizer,] observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_optimizer::CompositeOptimizer; + var_model_optimizer; options...) Fit the parameters of a variational model (`var_model`) to the posterior @@ -120,10 +120,10 @@ update the parameters of `model`. 
""" function black_box_vi!( model::GenerativeFunction, model_args::Tuple, - model_optimizer::Union{CompositeOptimizer,Nothing}, + model_optimizer, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_optimizer::CompositeOptimizer; + var_model_optimizer; iters=1000, samples_per_iter=100, verbose=false, callback=(iter, traces, elbo_estimate) -> nothing) @@ -173,14 +173,14 @@ end black_box_vi!(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_optimizer::CompositeOptimizer; options...) = + var_model_optimizer; options...) = black_box_vi!(model, model_args, nothing, observations, var_model, var_model_args, var_model_optimizer; options...) """ (iwelbo_estimate, traces, iwelbo_history) = black_box_vimco!( model::GenerativeFunction, model_args::Tuple, - [model_optimizer::CompositeOptimizer,] + [model_optimizer,] observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, var_model_optimizer::CompositeOptimizer, diff --git a/src/optimization.jl b/src/optimization.jl index 107c7e201..2f45d239a 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -1,12 +1,10 @@ import Parameters -# we should modify the semantics of the log probability contribution to the gradient -# so that everything is gradient descent instead of ascent. this will also fix -# the misnomer names +# TODO we should modify the semantics of the log probability contribution to +# the gradient so that everything is gradient descent instead of ascent. this +# will also fix the misnomer names # # TODO add tests specifically for JuliaParameterStore etc. -# -# TODO in all update and regenerate implementations, need to pass in the parameter context to inner calls to generate export in_place_add! @@ -14,7 +12,6 @@ export FixedStepGradientDescent export DecayStepGradientDescent export init_optimizer export apply_update! -export CompositeOptimizer export JuliaParameterStore export init_parameter! @@ -22,24 +19,10 @@ export increment_gradient! export reset_gradient! export get_parameter_value export get_gradient +export JULIA_PARAMETER_STORE_KEY -################# -# in_place_add! # -################# - -function in_place_add! end - -function in_place_add!(value::Array, increment) - @simd for i in 1:length(value) - value[i] += increment[i] - end - return value -end - -# this exists so user can use the same function on scalars and arrays -function in_place_add!(value::Real, increment::Real) - return value + increment -end +export default_julia_parameter_store +export default_parameter_context ############################ # optimizer specifications # @@ -64,8 +47,27 @@ Parameters.@with_kw struct DecayStepGradientDescent step_size_beta::Float64 end +# TODO add gradient descent with momentum +# TODO add update + +################# +# in_place_add! # +################# + +function in_place_add! 
end + +function in_place_add!(value::Array, increment) + @simd for i in 1:length(value) + value[i] += increment[i] + end + return value +end + +# this exists so user can use the same function on scalars and arrays +function in_place_add!(value::Real, increment::Real) + return value + increment +end -# TODO add ADAM update ########################### # thread-safe accumulator # @@ -152,23 +154,16 @@ end # parameter stores and optimizers # ################################### -# TODO create diagram and document the overal framework -# including parameter contexts and parameter stores,and the default beahviors - -abstract type ParameterStore end - """ optimizer = init_optimizer( - conf, parameter_ids, + conf, parameter_ids::Vector, store=default_julia_parameter_store) -Initialize an iterative gradient-based optimizer. +Initialize an iterative gradient-based optimizer that mutates a single parameter store. The first argument defines the mathematical behavior of the update, the second argument defines the set of parameters to which the update should be applied at each iteration, and the third argument gives the location of the parameter values and their gradient accumulators. -See [`apply_update!`](@ref). - -Not thread-safe. +Add support for new parameter store types or new optimization configurations by (i) defining a new Julia type for an optimizer that applies the configuration type to the parameter store type, and (ii) implementing this method so that it returns an instance of your new type, and (iii) implement `apply_update!` for your new type. """ function init_optimizer(conf, parameter_ids, store=default_julia_parameter_store) error("Not implemented") @@ -179,32 +174,16 @@ end Apply one iteration of a gradient-based optimization update. -See [`init_optimizer!`](@ref). - -Not thread-safe. +Extend this method to add support for new parameter store types or new optimization configurations. """ function apply_update!(optimizer) error("Not implemented") end -""" - - optimizer = CompositeOptimizer(conf, parameter_stores_to_ids::Dict{Any,Vector}) - -Construct an optimizer that applies the given update to parameters in multiple parameter stores. - -The first argument defines the mathematical behavior of the update; -the second argument defines the set of parameters to which the update should be applied at each iteration, -as a map from parameter stores to a vector of IDs of parameters within that parameter store. - - optimizer = CompositeOptimizer(conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) - -Constructs a composite optimizer that applies the given update to all parameters used by the given generative function, even when the parameters exist in multiple parameter stores. -""" struct CompositeOptimizer conf::Any optimizers::Dict{Any,Any} - function CompositeOptimizer(conf, parameter_stores_to_ids::Dict{Any,Vector}) + function CompositeOptimizer(conf, parameter_stores_to_ids) optimizers = Dict{Any,Any}() for (store, parameter_ids) in parameter_stores_to_ids optimizers[store] = init_optimizer(conf, parameter_ids, store) @@ -213,15 +192,32 @@ struct CompositeOptimizer end end -function CompositeOptimizer(conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) - return CompositeOptimizer(conf, get_parameters(gen_fn, parameter_context)) +""" + + optimizer = init_optimizer(conf, parameter_stores_to_ids::Dict{Any,Vector}) + +Construct a special type of optimizer that updates parameters in multiple parameter stores. 
+ +The second argument defines the set of parameters to which the update should be applied at each iteration, +The parameters are given in a map from parameter store to a vector of IDs of parameters within that parameter store. + +NOTE: You do _not_ need to extend this method to extend support for new parameter store types or new optimization configurations. +""" +function init_optimizer(conf, parameter_stores_to_ids::Dict) + return CompositeOptimizer(conf, parameter_stores_to_ids) end """ - apply_update!(composite_opt::ComposieOptimizer) + optimizer = init_optimizer( + conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) -Perform one step of an update, possibly mutating the values of parameters in multiple parameter stores. +Convenience method that constructs an optimizer that updates all parameters used by the given generative function, even when the parameters exist in multiple parameter stores. """ +function init_optimizer(conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) + return CompositeOptimizer(conf, get_parameters(gen_fn, parameter_context)) +end + + function apply_update!(composite_opt::CompositeOptimizer) for opt in values(composite_opt.optimizers) apply_update!(opt) @@ -247,7 +243,7 @@ Construct a parameter store stores the state of parameters in the memory of the There is a global Julia parameter store automatically created and named `Gen.default_julia_parameter_store`. -Incrementing the gradients can be safely multi-threaded (see [`increment_gradient!`](@ref)). +Gradient accumulation is thread-safe (see [`increment_gradient!`](@ref)). """ function JuliaParameterStore() return JuliaParameterStore( @@ -263,21 +259,27 @@ function get_local_parameters(store::JuliaParameterStore, gen_fn) end end -const default_parameter_context = Dict{Symbol,Any}() -const default_julia_parameter_store = JuliaParameterStore() - # for looking up in a parameter context when tracing (simulate, generate) # once a trace is generated, it is bound to use a particular store +""" + JULIA_PARAMETER_STORE_KEY + +If a parameter context contains a value for this key, then the value is a `JuliaParameterStore`. +""" const JULIA_PARAMETER_STORE_KEY = :julia_parameter_store function get_julia_store(context::Dict) - if haskey(context, JULIA_PARAMETER_STORE_KEY) - return context[JULIA_PARAMETER_STORE_KEY] - else - return default_julia_parameter_store - end + return context[JULIA_PARAMETER_STORE_KEY]::JuliaParameterStore end +""" + default_julia_parameter_store::JuliaParameterStore + +The default global Julia parameter store. +""" +const default_julia_parameter_store = JuliaParameterStore() + + """ init_parameter!( id::Tuple{GenerativeFunction,Symbol}, value, @@ -285,7 +287,7 @@ end Initialize the the value of a named trainable parameter of a generative function. -Also generates the gradient accumulator for that parameter to `zero(value)`. +Also initializes the gradient accumulator for that parameter to `zero(value)`. Example: ```julia @@ -520,3 +522,38 @@ function apply_update!(opt::DecayStepGradientDescentJulia) end # TODO implement other optimizers (ADAM, etc.) 
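As an illustration of the extension recipe described above, a hypothetical gradient-clipping configuration for `JuliaParameterStore` might look roughly like the sketch below. The `ClippedStepGradientDescent` names are invented for this example and are not part of Gen; only `get_gradient`, `get_parameter_value`, `set_parameter_value!`, and `reset_gradient!` from this file are assumed.
```julia
# illustrative sketch only; not part of Gen
struct ClippedStepGradientDescent
    step_size::Float64
    max_norm::Float64
end

struct ClippedStepGradientDescentJulia
    conf::ClippedStepGradientDescent
    store::JuliaParameterStore
    parameter_ids::Vector
end

function init_optimizer(
        conf::ClippedStepGradientDescent, parameter_ids,
        store::JuliaParameterStore=default_julia_parameter_store)
    return ClippedStepGradientDescentJulia(conf, store, parameter_ids)
end

function apply_update!(opt::ClippedStepGradientDescentJulia)
    for id in opt.parameter_ids
        grad = get_gradient(id, opt.store)
        nrm = sqrt(sum(abs2, grad))
        scale = nrm > opt.conf.max_norm ? opt.conf.max_norm / nrm : 1.0
        value = get_parameter_value(id, opt.store)
        # step in the gradient direction, following the existing ascent convention
        # noted in the TODO at the top of this file
        set_parameter_value!(id, value .+ (opt.conf.step_size * scale) .* grad, opt.store)
        reset_gradient!(id, opt.store)
    end
    return nothing
end
```
Because the composite optimizer returned by `init_optimizer(conf, gen_fn)` dispatches to the per-store `init_optimizer` method, a configuration like this would work with generative functions automatically once these two methods exist.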
+ +############################# +# default parameter context # +############################# + + +""" + default_parameter_context::Dict + +The default global parameter context, which is initialized to contain the mapping: + + JULIA_PARAMETER_STORE_KEY => Gen.default_julia_parameter_store +""" +const default_parameter_context = Dict{Symbol,Any}( + JULIA_PARAMETER_STORE_KEY => default_julia_parameter_store) + +function simulate(gen_fn::GenerativeFunction, args::Tuple) + return simulate(gen_fn, args, default_parameter_context) +end + +function generate(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) + return generate(gen_fn, args, choices, default_parameter_context) +end + +function generate(gen_fn::GenerativeFunction, args::Tuple) + return generate(gen_fn, args, EmptyChoiceMap(), default_parameter_context) +end + +function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) + return assess(gen_fn, args, choices, default_parameter_context) +end + +propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, default_parameter_context) + + diff --git a/test/inference/train.jl b/test/inference/train.jl index 277cb2817..4a09c217e 100644 --- a/test/inference/train.jl +++ b/test/inference/train.jl @@ -109,7 +109,7 @@ end # use stochastic gradient descent - optimizer = CompositeOptimizer(DecayStepGradientDescent(0.01, 1000000), student) + optimizer = init_optimizer(DecayStepGradientDescent(0.01, 1000000), student) train!(student, data_generator, optimizer, num_epoch=2000, epoch_size=50, num_minibatch=1, minibatch_size=50, verbose=false) @@ -148,7 +148,7 @@ end # train simple q using lecture! to compute gradients init_parameter!((q, :theta), 0.0) init_parameter!((q, :log_std), 0.0) - optimizer = CompositeOptimizer(FixedStepGradientDescent(1e-4), q) + optimizer = init_optimizer(FixedStepGradientDescent(1e-4), q) score = Inf for iter=1:100 score = sum([lecture!(p, (), q, tr -> (tr[:x],)) for _=1:1000]) / 1000 @@ -171,7 +171,7 @@ end # train simple q using lecture_batched! 
to compute gradients init_parameter!((q_batched, :theta), 0.0) init_parameter!((q_batched, :log_std), 0.0) - optimizer = CompositeOptimizer(FixedStepGradientDescent(0.001), q_batched) + optimizer = init_optimizer(FixedStepGradientDescent(0.001), q_batched) score = Inf for iter=1:100 score = lecture_batched!(p, (), q_batched, trs -> (map(tr -> tr[:x], trs),), 1000) diff --git a/test/inference/variational.jl b/test/inference/variational.jl index f778ac6f2..f3e4037fe 100644 --- a/test/inference/variational.jl +++ b/test/inference/variational.jl @@ -24,8 +24,8 @@ init_parameter!((approx, :intercept_log_std), 0.0) observations = choicemap() - optimizer = CompositeOptimizer(DecayStepGradientDescent(1, 100000), approx) - optimizer = CompositeOptimizer(DecayStepGradientDescent(1., 1000), approx) + optimizer = init_optimizer(DecayStepGradientDescent(1, 100000), approx) + optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx) black_box_vi!(model, (), observations, approx, (), optimizer; iters=2000, samples_per_iter=100, verbose=false) slope_mu = get_parameter_value((approx, :slope_mu)) @@ -111,8 +111,8 @@ end init_parameter!((model, :theta), 0.0) init_parameter!((approx, :mu_coeffs), zeros(2)) init_parameter!((approx, :log_std), 0.0) - approx_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.0001), approx) - model_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.002), model) + approx_optimizer = init_optimizer(FixedStepGradientDescent(0.0001), approx) + model_optimizer = init_optimizer(FixedStepGradientDescent(0.002), model) @time (_, _, elbo_history, _) = black_box_vi!(model, (), model_optimizer, observations, approx, (xs,), approx_optimizer; @@ -126,8 +126,8 @@ end init_parameter!((model, :theta), 0.0) init_parameter!((approx, :mu_coeffs), zeros(2)) init_parameter!((approx, :log_std), 0.0) - approx_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.001), approx) - model_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.01), model) + approx_optimizer = init_optimizer(FixedStepGradientDescent(0.001), approx) + model_optimizer = init_optimizer(FixedStepGradientDescent(0.01), model) @time (_, _, elbo_history, _) = black_box_vimco!(model, (), model_optimizer, observations, approx, (xs,), approx_optimizer, 10; diff --git a/test/runtests.jl b/test/runtests.jl index 6fddf20b4..15a62fb6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,7 +82,7 @@ include("selection.jl") include("assignment.jl") include("gen_fn_interface.jl") include("dsl/dsl.jl") -include("optional_args.jl") +#include("optional_args.jl") include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") From a4952dee1cd1bb766e40d1b7445ebeaf8b546a3a Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Tue, 18 May 2021 13:49:13 -0400 Subject: [PATCH 11/24] re-enable optional args test --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 15a62fb6b..6fddf20b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,7 +82,7 @@ include("selection.jl") include("assignment.jl") include("gen_fn_interface.jl") include("dsl/dsl.jl") -#include("optional_args.jl") +include("optional_args.jl") include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") From 65a41a94948a161cb24e74d13b8eb358a5e69735 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Tue, 18 May 2021 16:53:07 -0400 Subject: [PATCH 12/24] add multithreading to importance sampling 
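The per-sample `generate` calls are independent, so both variants of `importance_sampling` gain a `multithreaded` keyword argument that distributes the samples across threads with `Threads.@threads`. Rough usage sketch (assumes `model`, `model_args`, and `observations` are already defined, and that Julia was started with more than one thread, e.g. `JULIA_NUM_THREADS=4`):
```julia
(traces, log_norm_weights, lml_est) = importance_sampling(
    model, model_args, observations, 1000; multithreaded=true)
```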
--- src/inference/importance.jl | 67 ++++++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/src/inference/importance.jl b/src/inference/importance.jl index fac3c0e4d..c0103c5b4 100644 --- a/src/inference/importance.jl +++ b/src/inference/importance.jl @@ -17,14 +17,20 @@ The second variant uses a custom proposal distribution defined by the given gene All addresses of random choices sampled by the proposal should also be sampled by the model function. Setting `verbose=true` prints a progress message every sample. """ -function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, - observations::ChoiceMap, - num_samples::Int, verbose=false) where {T,U} +function importance_sampling( + model::GenerativeFunction{T,U}, model_args::Tuple, + observations::ChoiceMap, num_samples::Int; + verbose=false, multithreaded=false) where {T,U} traces = Vector{U}(undef, num_samples) log_weights = Vector{Float64}(undef, num_samples) - for i=1:num_samples - verbose && println("sample: $i of $num_samples") - (traces[i], log_weights[i]) = generate(model, model_args, observations) + if multithreaded + Threads.@threads for i in 1:num_samples + importance_sampling_iter!(traces, log_weights, model, model_args, observations, i, verbose) + end + else + for i=1:num_samples + importance_sampling_iter!(traces, log_weights, model, model_args, observations, i, verbose) + end end log_total_weight = logsumexp(log_weights) log_ml_estimate = log_total_weight - log(num_samples) @@ -32,18 +38,33 @@ function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, return (traces, log_normalized_weights, log_ml_estimate) end -function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, - observations::ChoiceMap, - proposal::GenerativeFunction, proposal_args::Tuple, - num_samples::Int, verbose=false) where {T,U} +function importance_sampling_iter!( + traces::Vector, log_weights::Vector{Float64}, + model::GenerativeFunction, model_args::Tuple, + observations::ChoiceMap, i::Int, verbose::Bool) + (traces[i], log_weights[i]) = generate(model, model_args, observations) + verbose && Core.println("sample: $i of $num_samples completed in thread $(Threads.threadid())") + return nothing +end + +function importance_sampling( + model::GenerativeFunction{T,U}, model_args::Tuple, + observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple, + num_samples::Int; verbose=false, multithreaded=false) where {T,U} traces = Vector{U}(undef, num_samples) log_weights = Vector{Float64}(undef, num_samples) - for i=1:num_samples - verbose && println("sample: $i of $num_samples") - (proposed_choices, proposal_weight, _) = propose(proposal, proposal_args) - constraints = merge(observations, proposed_choices) - (traces[i], model_weight) = generate(model, model_args, constraints) - log_weights[i] = model_weight - proposal_weight + if multithreaded + Threads.@threads for i=1:num_samples + importance_sampling_iter!( + traces, log_weights, model, model_args, + observations, proposal, proposal_args, i, verbose) + end + else + for i=1:num_samples + importance_sampling_iter!( + traces, log_weights, model, model_args, + observations, proposal, proposal_args, i, verbose) + end end log_total_weight = logsumexp(log_weights) log_ml_estimate = log_total_weight - log(num_samples) @@ -51,6 +72,20 @@ function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, return (traces, log_normalized_weights, log_ml_estimate) end +function 
importance_sampling_iter!( + traces::Vector, log_weights::Vector{Float64}, + model::GenerativeFunction, model_args::Tuple, + observations::ChoiceMap, + proposal::GenerativeFunction, proposal_args::Tuple, + i::Int, verbose::Bool) + (proposed_choices, proposal_weight, _) = propose(proposal, proposal_args) + constraints = merge(observations, proposed_choices) + (traces[i], model_weight) = generate(model, model_args, constraints) + log_weights[i] = model_weight - proposal_weight + verbose && Core.println("sample $i of $num_samples completed in thread $(Threads.threadid())") + return nothing +end + """ (trace, lml_est) = importance_resampling(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, num_samples::Int, From 3cc2d27e4d2135db953282f12d4e35421bfd9433 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Tue, 18 May 2021 16:54:29 -0400 Subject: [PATCH 13/24] add unit tests for parameter optimization primitives --- test/optimization.jl | 86 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 test/optimization.jl diff --git a/test/optimization.jl b/test/optimization.jl new file mode 100644 index 000000000..ff6a7bb80 --- /dev/null +++ b/test/optimization.jl @@ -0,0 +1,86 @@ +@testset "optimization" begin + +@testset "in_place_add!" begin + # TODO +end + +@testset "Accumulator" begin + # TODO +end + +@testset "Julia parameter store" begin + + store = JuliaParameterStore() + + @gen function foo() + @param theta::Float64 + @param phi::Vector{Float64} + end + register_parameters!(foo, [:theta, :phi]) + + # before the parameters are initialized in the store + + @test Gen.get_local_parameters(store, foo) == Dict{Symbol,Any}() + + @test_throws KeyError get_gradient((foo, :theta), store) + @test_throws KeyError get_parameter_value((foo, :theta), store) + @test_throws KeyError increment_gradient!((foo, :theta), 1.0, store) + @test_throws KeyError reset_gradient!((foo, :theta), store) + @test_throws KeyError Gen.set_parameter_value!((foo, :theta), 1.0, store) + @test_throws KeyError Gen.get_gradient_accumulator((foo, :theta), store) + + @test_throws KeyError get_gradient((foo, :phi), store) + @test_throws KeyError get_parameter_value((foo, :phi), store) + @test_throws KeyError increment_gradient!((foo, :phi), [1.0, 1.0], store) + @test_throws KeyError reset_gradient!((foo, :phi), store) + @test_throws KeyError Gen.set_parameter_value!((foo, :phi), [1.0, 1.0], store) + @test_throws KeyError Gen.get_gradient_accumulator((foo, :phi), store) + + # after the parameters are initialized in the store + + init_parameter!((foo, :theta), 1.0, store) + init_parameter!((foo, :phi), [1.0, 2.0], store) + + dict = Gen.get_local_parameters(store, foo) + @test length(dict) == 2 + @test dict[:theta] == 1.0 + @test dict[:phi] == [1.0, 2.0] + + @test get_gradient((foo, :theta), store) == 0.0 + @test get_parameter_value((foo, :theta), store) == 1.0 + increment_gradient!((foo, :theta), 1.1, store) + @test get_gradient((foo, :theta), store) == 1.1 + increment_gradient!((foo, :theta), 1.1, 2.0, store) + @test get_gradient((foo, :theta), store) == (1.1 + 2.2) + reset_gradient!((foo, :theta), store) + @test get_gradient((foo, :theta), store) == 0.0 + Gen.set_parameter_value!((foo, :theta), 2.0, store) + @test get_parameter_value((foo, :theta), store) == 2.0 + @test get_value(Gen.get_gradient_accumulator((foo, :theta), store)) == 0.0 + + @test get_gradient((foo, :phi), store) == [0.0, 0.0] + @test get_parameter_value((foo, :phi), store) == [1.0, 2.0] + 
increment_gradient!((foo, :phi), [1.1, 1.2], store) + @test get_gradient((foo, :phi), store) == [1.1, 1.2] + increment_gradient!((foo, :phi), [1.1, 1.2], 2.0, store) + @test get_gradient((foo, :phi), store) == ([1.1, 1.2] .+ (2.0 * [1.1, 1.2])) + reset_gradient!((foo, :phi), store) + @test get_gradient((foo, :phi), store) == [0.0, 0.0] + Gen.set_parameter_value!((foo, :phi), [2.0, 3.0], store) + @test get_parameter_value((foo, :phi), store) == [2.0, 3.0] + @test Gen.get_value(Gen.get_gradient_accumulator((foo, :phi), store)) == [0.0, 0.0] + + # FixedStepGradientDescent + + # DecayStepGradientDescent + + # init_optimizer and apply_update! for FixedStepGradientDescent and DecayStepGradientDescent + # default_parameter_context and default_julia_parameter_store +end + +@testset "composite optimizer" begin + +end + + +end From 22b9c98ea30e43485c2f0e81ee12ba9a53c88a0f Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Tue, 18 May 2021 17:38:34 -0400 Subject: [PATCH 14/24] add multithreading to travis build --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index 36f0f9500..6270d6015 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,6 +6,9 @@ julia: - 1.4 - 1.5 - 1.6 +env: + - JULIA_NUM_THREADS=1 + - JULIA_NUM_THREADS=2 services: - docker From 64ee5e0db78d401e2677f2d7ebeedaa00da917c1 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Tue, 18 May 2021 17:58:13 -0400 Subject: [PATCH 15/24] add multi-threading to VI and importance sampling with tests --- src/inference/variational.jl | 142 ++++++++++++++++++-------- src/optimization.jl | 4 +- test/inference/importance_sampling.jl | 39 +++---- test/inference/variational.jl | 66 ++++++------ test/optimization.jl | 29 ++++-- test/runtests.jl | 1 + 6 files changed, 179 insertions(+), 102 deletions(-) diff --git a/src/inference/variational.jl b/src/inference/variational.jl index 0be8f7692..09713cc7c 100644 --- a/src/inference/variational.jl +++ b/src/inference/variational.jl @@ -16,7 +16,7 @@ function single_sample_gradient_estimate!( accumulate_param_gradients!(var_trace, nothing, log_weight * scale_factor) # unbiased estimate of objective function, and trace - (log_weight, var_trace, model_trace) + return (log_weight, var_trace, model_trace) end function vimco_geometric_baselines(log_weights) @@ -29,12 +29,12 @@ function vimco_geometric_baselines(log_weights) baselines[i] = logsumexp(log_weights) - log(num_samples) log_weights[i] = temp end - baselines + return baselines end function logdiffexp(x, y) m = max(x, y) - m + log(exp(x - m) - exp(y - m)) + return m + log(exp(x - m) - exp(y - m)) end function vimco_arithmetic_baselines(log_weights) @@ -46,7 +46,7 @@ function vimco_arithmetic_baselines(log_weights) log_f_hat = log_sum_f_without_i - log(num_samples - 1) baselines[i] = logsumexp(log_sum_f_without_i, log_f_hat) - log(num_samples) end - baselines + return baselines end # black box, VIMCO gradient estimator @@ -85,7 +85,7 @@ function multi_sample_gradient_estimate!( # collection of traces and normalized importance weights, and estimate of # objective function - (L, traces, weights_normalized) + return (L, traces, weights_normalized) end function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real) @@ -117,6 +117,7 @@ update the parameters of `model`. - `callback`: Callback function that takes `(iter, traces, elbo_estimate)` as input, where `iter` is the iteration number and `traces` are samples from `var_model` for that iteration. 
+- `multithreaded`: if `true`, gradient estimation may use multiple threads. """ function black_box_vi!( model::GenerativeFunction, model_args::Tuple, @@ -125,31 +126,32 @@ function black_box_vi!( var_model::GenerativeFunction, var_model_args::Tuple, var_model_optimizer; iters=1000, samples_per_iter=100, verbose=false, - callback=(iter, traces, elbo_estimate) -> nothing) + callback=(iter, traces, elbo_estimate) -> nothing, + multithreaded=false) var_traces = Vector{Any}(undef, samples_per_iter) model_traces = Vector{Any}(undef, samples_per_iter) + log_weights = Vector{Float64}(undef, samples_per_iter) elbo_history = Vector{Float64}(undef, iters) for iter=1:iters # compute gradient estimate and objective function estimate - elbo_estimate = 0.0 - # TODO multithread (note that this would require accumulate_param_gradients! to be threadsafe) - for sample=1:samples_per_iter - - # accumulate the variational family gradients - (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!( - var_model, var_model_args, - model, model_args, observations, 1/samples_per_iter) - elbo_estimate += (log_weight / samples_per_iter) - - # accumulate the generative model gradients - _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / samples_per_iter) - - # record the traces - var_traces[sample] = var_trace - model_traces[sample] = model_trace + if multithreaded + Threads.@threads for i in 1:samples_per_iter + black_box_vi_iter!( + var_traces, model_traces, log_weights, i, samples_per_iter, + var_model, var_model_args, + model, model_args, observations, model_optimizer) + end + else + for i in 1:samples_per_iter + black_box_vi_iter!( + var_traces, model_traces, log_weights, i, samples_per_iter, + var_model, var_model_args, + model, model_args, observations, model_optimizer) + end end + elbo_estimate = sum(log_weights) elbo_history[iter] = elbo_estimate # print it @@ -167,9 +169,34 @@ function black_box_vi!( end end - (elbo_history[end], var_traces, elbo_history, model_traces) + return (elbo_history[end], var_traces, elbo_history, model_traces) +end + +function black_box_vi_iter!( + var_traces::Vector, model_traces::Vector, log_weights::Vector{Float64}, + i::Int, n::Int, + var_model::GenerativeFunction, var_model_args::Tuple, + model::GenerativeFunction, model_args::Tuple, + observations::ChoiceMap, + model_optimizer) + + # accumulate the variational family gradients + (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!( + var_model, var_model_args, + model, model_args, observations, 1.0 / n) + log_weights[i] = log_weight / n + + # accumulate the generative model gradients + _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / n) + + # record the traces + var_traces[i] = var_trace + model_traces[i] = model_trace + + return nothing end + black_box_vi!(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, @@ -205,6 +232,7 @@ update the parameters of `model`. - `callback`: Callback function that takes `(iter, traces, elbo_estimate)` as input, where `iter` is the iteration number and `traces` are samples from `var_model` for that iteration. +- `multithreaded`: if `true`, gradient estimation may use multiple threads. 
""" function black_box_vimco!( model::GenerativeFunction, model_args::Tuple, @@ -212,35 +240,37 @@ function black_box_vimco!( var_model::GenerativeFunction, var_model_args::Tuple, var_model_optimizer::CompositeOptimizer, grad_est_samples::Int; iters=1000, samples_per_iter=100, geometric=true, verbose=false, - callback=(iter, traces, elbo_estimate) -> nothing) + callback=(iter, traces, elbo_estimate) -> nothing, + multithreaded=false) resampled_var_traces = Vector{Any}(undef, samples_per_iter) model_traces = Vector{Any}(undef, samples_per_iter) + log_weights = Vector{Float64}(undef, samples_per_iter) iwelbo_history = Vector{Float64}(undef, iters) for iter=1:iters # compute gradient estimate and objective function estimate - iwelbo_estimate = 0. - for sample=1:samples_per_iter - - # accumulate the variational family gradients - (est, original_var_traces, weights) = multi_sample_gradient_estimate!( - var_model, var_model_args, - model, model_args, observations, grad_est_samples, - 1/samples_per_iter, geometric) - iwelbo_estimate += (est / samples_per_iter) - - # record a variational trace obtained by resampling from the weighted collection - resampled_var_traces[sample] = original_var_traces[categorical(weights)] - - # accumulate the generative model gradient estimator - for (var_trace, weight) in zip(original_var_traces, weights) - constraints = merge(observations, get_choices(var_trace)) - (model_trace, _) = generate(model, model_args, constraints) - _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter) + if multithreaded + Threads.@threads for i in 1:samples_per_iter + black_box_vimco_iter!( + resampled_var_traces, log_weights, + i, samples_per_iter, + var_model, var_model_args, model, model_args, + observations, geometric, grad_est_samples, + model_optimizer) + end + else + for i in 1:samples_per_iter + black_box_vimco_iter!( + resampled_var_traces, log_weights, + i, samples_per_iter, + var_model, var_model_args, model, model_args, + observations, geometric, grad_est_samples, + model_optimizer) end end + iwelbo_estimate = sum(log_weights) iwelbo_history[iter] = iwelbo_estimate # print it @@ -262,6 +292,34 @@ function black_box_vimco!( (iwelbo_history[end], resampled_var_traces, iwelbo_history, model_traces) end +function black_box_vimco_iter!( + resampled_var_traces::Vector, log_weights::Vector{Float64}, + i::Int, samples_per_iter::Int, + var_model::GenerativeFunction, var_model_args::Tuple, + model::GenerativeFunction, model_args::Tuple, + observations::ChoiceMap, geometric::Bool, grad_est_samples::Int, + model_optimizer) + + # accumulate the variational family gradients + (est, original_var_traces, weights) = multi_sample_gradient_estimate!( + var_model, var_model_args, + model, model_args, observations, grad_est_samples, + 1/samples_per_iter, geometric) + log_weights[i] = est / samples_per_iter + + # record a variational trace obtained by resampling from the weighted collection + resampled_var_traces[i] = original_var_traces[categorical(weights)] + + # accumulate the generative model gradient estimator + for (var_trace, weight) in zip(original_var_traces, weights) + constraints = merge(observations, get_choices(var_trace)) + (model_trace, _) = generate(model, model_args, constraints) + _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter) + end + + return nothing +end + black_box_vimco!(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, diff --git 
a/src/optimization.jl b/src/optimization.jl index 2f45d239a..1721058a1 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -375,7 +375,7 @@ Thread-safe (multiple threads can increment the gradient of the same parameter c function increment_gradient!( id::Tuple{GenerativeFunction,Symbol}, increment, store::JuliaParameterStore=default_julia_parameter_store) - accumulator = get_gradient_accumulator(store, id) + accumulator = get_gradient_accumulator(id, store) in_place_add!(accumulator, increment) return nothing end @@ -555,5 +555,3 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) end propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, default_parameter_context) - - diff --git a/test/inference/importance_sampling.jl b/test/inference/importance_sampling.jl index 4a4842b21..c2e9ac413 100644 --- a/test/inference/importance_sampling.jl +++ b/test/inference/importance_sampling.jl @@ -15,33 +15,36 @@ n = 4 - (traces, log_weights, lml_est) = importance_sampling(model, (), observations, n) - @test length(traces) == n - @test length(log_weights) == n - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) - @test !isnan(lml_est) - for trace in traces - @test get_choices(trace)[:y] == y + for multithreaded in [false, true] + (traces, log_weights, lml_est) = importance_sampling( + model, (), observations, n; multithreaded=multithreaded) + @test length(traces) == n + @test length(log_weights) == n + @test isapprox(logsumexp(log_weights), 0., atol=1e-14) + @test !isnan(lml_est) + for trace in traces + @test get_choices(trace)[:y] == y + end end - (traces, log_weights, lml_est) = importance_sampling(model, (), observations, proposal, (), n) - @test length(traces) == n - @test length(log_weights) == n - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) - @test !isnan(lml_est) - for trace in traces - @test get_choices(trace)[:y] == y + for multithreaded in [false, true] + (traces, log_weights, lml_est) = importance_sampling( + model, (), observations, proposal, (), n; + multithreaded=multithreaded) + @test length(traces) == n + @test length(log_weights) == n + @test isapprox(logsumexp(log_weights), 0., atol=1e-14) + @test !isnan(lml_est) + for trace in traces + @test get_choices(trace)[:y] == y + end end (trace, lml_est) = importance_resampling(model, (), observations, n) - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) @test !isnan(lml_est) @test get_choices(trace)[:y] == y (trace, lml_est) = importance_resampling(model, (), observations, proposal, (), n) - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) @test !isnan(lml_est) @test get_choices(trace)[:y] == y end - - diff --git a/test/inference/variational.jl b/test/inference/variational.jl index f3e4037fe..f44070873 100644 --- a/test/inference/variational.jl +++ b/test/inference/variational.jl @@ -17,40 +17,46 @@ end register_parameters!(approx, [:slope_mu, :slope_log_std, :intercept_mu, :intercept_log_std]) - # to regular black box variational inference - init_parameter!((approx, :slope_mu), 0.0) - init_parameter!((approx, :slope_log_std), 0.0) - init_parameter!((approx, :intercept_mu), 0.0) - init_parameter!((approx, :intercept_log_std), 0.0) - observations = choicemap() - optimizer = init_optimizer(DecayStepGradientDescent(1, 100000), approx) optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx) - black_box_vi!(model, (), observations, approx, (), optimizer; - iters=2000, samples_per_iter=100, verbose=false) - slope_mu = get_parameter_value((approx, 
:slope_mu)) - slope_log_std = get_parameter_value((approx, :slope_log_std)) - intercept_mu = get_parameter_value((approx, :intercept_mu)) - intercept_log_std = get_parameter_value((approx, :intercept_log_std)) - @test isapprox(slope_mu, -1., atol=0.001) - @test isapprox(slope_log_std, 0.5, atol=0.001) - @test isapprox(intercept_mu, 1., atol=0.001) - @test isapprox(intercept_log_std, 2.0, atol=0.001) + + # test regular black box variational inference + for multithreaded in [false, true] + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) + black_box_vi!(model, (), observations, approx, (), optimizer; + iters=2000, samples_per_iter=100, verbose=false, multithreaded=multithreaded) + + slope_mu = get_parameter_value((approx, :slope_mu)) + slope_log_std = get_parameter_value((approx, :slope_log_std)) + intercept_mu = get_parameter_value((approx, :intercept_mu)) + intercept_log_std = get_parameter_value((approx, :intercept_log_std)) + @test isapprox(slope_mu, -1., atol=0.001) + @test isapprox(slope_log_std, 0.5, atol=0.001) + @test isapprox(intercept_mu, 1., atol=0.001) + @test isapprox(intercept_log_std, 2.0, atol=0.001) + end # smoke test for black box variational inference with Monte Carlo objectives - init_parameter!((approx, :slope_mu), 0.0) - init_parameter!((approx, :slope_log_std), 0.0) - init_parameter!((approx, :intercept_mu), 0.0) - init_parameter!((approx, :intercept_log_std), 0.0) - black_box_vimco!(model, (), observations, approx, (), optimizer, 20; - iters=50, samples_per_iter=100, verbose=false, geometric=false) - - init_parameter!((approx, :slope_mu), 0.0) - init_parameter!((approx, :slope_log_std), 0.0) - init_parameter!((approx, :intercept_mu), 0.0) - init_parameter!((approx, :intercept_log_std), 0.0) - black_box_vimco!(model, (), observations, approx, (), optimizer, 20; - iters=50, samples_per_iter=100, verbose=false, geometric=true) + for multithreaded in [false, true] + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) + black_box_vimco!(model, (), observations, approx, (), optimizer, 20; + iters=50, samples_per_iter=100, verbose=false, geometric=false, + multithreaded=multithreaded) + + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) + black_box_vimco!(model, (), observations, approx, (), optimizer, 20; + iters=50, samples_per_iter=100, verbose=false, geometric=true, + multithreaded=multithreaded) + end end diff --git a/test/optimization.jl b/test/optimization.jl index ff6a7bb80..6e531e71b 100644 --- a/test/optimization.jl +++ b/test/optimization.jl @@ -1,13 +1,5 @@ @testset "optimization" begin -@testset "in_place_add!" 
begin - # TODO -end - -@testset "Accumulator" begin - # TODO -end - @testset "Julia parameter store" begin store = JuliaParameterStore() @@ -70,11 +62,30 @@ end @test get_parameter_value((foo, :phi), store) == [2.0, 3.0] @test Gen.get_value(Gen.get_gradient_accumulator((foo, :phi), store)) == [0.0, 0.0] + # check that the default global Julia store was unaffected + @test_throws KeyError get_parameter_value((foo, :theta)) + @test_throws KeyError get_gradient((foo, :theta)) + @test_throws KeyError increment_gradient!((foo, :theta), 1.0) + # FixedStepGradientDescent + init_parameter!((foo, :theta), 1.0, store) + init_parameter!((foo, :phi), [1.0, 2.0], store) + increment_gradient!((foo, :theta), 2.0, store) + increment_gradient!((foo, :phi), [1.0, 3.0], store) + optimizer = init_optimizer(FixedStepGradientDescent(1e-2), [(foo, :theta)], store) + apply_update!(optimizer) # update just theta + @test get_gradient((foo, :theta), store) == 0.0 + @test get_parameter_value((foo, :theta), store) == 1.0 + (2.0 * 1e-2) + @test get_gradient((foo, :phi), store) == [1.0, 3.0] # unchanged + @test get_parameter_value((foo, :phi), store) == [1.0, 2.0] # unchanged + optimizer = init_optimizer(FixedStepGradientDescent(1e-2), [(foo, :phi)], store) + apply_update!(optimizer) # update just phi + @test get_gradient((foo, :phi), store) == [0.0, 0.0] + @test get_parameter_value((foo, :phi), store) == ([1.0, 2.0] .+ 1e-2 * [1.0, 3.0]) # DecayStepGradientDescent + # TODO - # init_optimizer and apply_update! for FixedStepGradientDescent and DecayStepGradientDescent # default_parameter_context and default_julia_parameter_store end diff --git a/test/runtests.jl b/test/runtests.jl index 6fddf20b4..fd8f9ebd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,6 +76,7 @@ end const dx = 1e-6 +include("optimization.jl") include("autodiff.jl") include("diff.jl") include("selection.jl") From 6ef392d3a65fff27beab08997c58e446def58e25 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 01:21:33 -0400 Subject: [PATCH 16/24] fixes to docs --- docs/make.jl | 1 - docs/src/index.md | 6 +++--- docs/src/ref/gfi.md | 3 ++- docs/src/ref/internals/parameter_optimization.md | 7 ------- docs/src/ref/learning.md | 4 ++-- docs/src/ref/selections.md | 1 - 6 files changed, 7 insertions(+), 15 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 3c2448478..7c82d4dfd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -29,7 +29,6 @@ makedocs( "Learning Generative Functions" => "ref/learning.md" ], "Internals" => [ - "Optimizing Trainable Parameters" => "ref/internals/parameter_optimization.md", "Modeling Language Implementation" => "ref/internals/language_implementation.md" ] ] diff --git a/docs/src/index.md b/docs/src/index.md index 161192b99..98aaf390a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -6,7 +6,7 @@ Pages = [ "getting_started.md", "tutorials.md", - "guide.md", + "guide.md" ] Depth = 2 ``` @@ -21,8 +21,8 @@ Pages = [ "ref/parameter_optimization.md", "ref/inference.md", "ref/gfi.md", - "ref/distributions.md" - "ref/extending.md", + "ref/distributions.md", + "ref/extending.md" ] Depth = 2 ``` diff --git a/docs/src/ref/gfi.md b/docs/src/ref/gfi.md index dde929722..a0a40126c 100644 --- a/docs/src/ref/gfi.md +++ b/docs/src/ref/gfi.md @@ -344,6 +344,7 @@ A generative function statically reports whether or not it is able to compute gr The **trainable parameters** of a generative function are (unlike arguments and random choices) *state* of the generative function itself, and are not contained in 
the trace. Generative functions that have trainable parameters maintain *gradient accumulators* for these parameters, which get incremented by the gradient induced by the given trace by a call to [`accumulate_param_gradients!`](@ref). Users then use these accumulated gradients to update to the values of the trainable parameters. +Use [`get_parameters`](@ref) to obtain the full set of trainable parameters that a generative function uses (see [Optimizing Trainable Paramters](@ref) for more details). ### Return value gradient The set of elements (either arguments, random choices, or trainable parameters) for which gradients are available is called the **gradient source set**. @@ -371,5 +372,5 @@ has_argument_grads accepts_output_grad accumulate_param_gradients! choice_gradients -get_params +get_parameters ``` diff --git a/docs/src/ref/internals/parameter_optimization.md b/docs/src/ref/internals/parameter_optimization.md index 48d88d494..e69de29bb 100644 --- a/docs/src/ref/internals/parameter_optimization.md +++ b/docs/src/ref/internals/parameter_optimization.md @@ -1,7 +0,0 @@ -# [Optimizing Trainable Parameters](@id optimizing-internal) - -To add support for a new type of gradient-based parameter update, create a new type with the following methods defined for the types of generative functions that are to be supported. -```@docs -Gen.init_update_state -Gen.apply_update! -``` diff --git a/docs/src/ref/learning.md b/docs/src/ref/learning.md index efdc09c8a..496dd051e 100644 --- a/docs/src/ref/learning.md +++ b/docs/src/ref/learning.md @@ -98,7 +98,7 @@ function train_model(data::Vector{ChoiceMap}) end ``` -Note that using the same primitives ([`generate`](@ref) and [`accumulate_param_gradients!`](@ref)), you can compose various more sophisticated learning algorithms involving e.g. stochastic gradient descent and minibatches, and more sophisticated stochastic gradient optimizers like [`ADAM`](@ref). +Note that using the same primitives ([`generate`](@ref) and [`accumulate_param_gradients!`](@ref)), you can compose various more sophisticated learning algorithms involving e.g. stochastic gradient descent and minibatches, and more sophisticated stochastic gradient optimizers. For example, [`train!`](@ref) trains a generative function from complete data with minibatches. ## Learning from Incomplete Data @@ -209,7 +209,7 @@ Then, the traces of the model can be obtained by simulating from the variational Instead of fitting the variational approximation from scratch for each observation, it is possible to fit an *inference model* instead, that takes as input the observation, and generates a distribution on latent variables as output (as in the wake sleep algorithm). When we train the variational approximation by minimizing the evidence lower bound (ELBO) this is called amortized variational inference. Variational autencoders are an example. -It is possible to perform amortized variational inference using [`black_box_vi`](@ref) or [`black_box_vimco!`](@ref). +It is possible to perform amortized variational inference using [`black_box_vi!`](@ref) or [`black_box_vimco!`](@ref). 
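A minimal sketch of how the `multithreaded` option added in this series can be used to fit a variational family. The `model` and `approx` generative functions are placeholders, and their trainable parameters are assumed to already be registered with `register_parameters!` and initialized with `init_parameter!`:

```julia
# Fit the variational approximation `approx` to the posterior of `model` given
# `observations`, distributing per-sample gradient estimates across Julia threads.
# Actual parallelism requires starting Julia with JULIA_NUM_THREADS > 1.
observations = choicemap()  # constraints on observed choices (empty in this sketch)
optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx)
(elbo_estimate, var_traces, elbo_history, _) = black_box_vi!(
    model, (), observations, approx, (), optimizer;
    iters=2000, samples_per_iter=100, verbose=false, multithreaded=true)
```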
## References diff --git a/docs/src/ref/selections.md b/docs/src/ref/selections.md index 172e31876..986173335 100644 --- a/docs/src/ref/selections.md +++ b/docs/src/ref/selections.md @@ -55,5 +55,4 @@ AllSelection HierarchicalSelection DynamicSelection StaticSelection -ComplementSelection ``` From ca03f61272a9278836da05c35c69ad68d7c84bae Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 14:20:42 -0400 Subject: [PATCH 17/24] fix get_parameters --- src/dynamic/dynamic.jl | 39 +++++++++++++++++++++------------------ src/dynamic/generate.jl | 3 +-- src/dynamic/simulate.jl | 3 +-- src/gen_fn_interface.jl | 4 +++- src/optimization.jl | 4 ++++ src/static_ir/dag.jl | 17 ++++++++++++----- test/dsl/dynamic_dsl.jl | 7 +++++++ 7 files changed, 49 insertions(+), 28 deletions(-) diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index 2db4a12cc..a429dd02d 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -17,7 +17,7 @@ mutable struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace} julia_function::Function has_argument_grads::Vector{Bool} accepts_output_grad::Bool - parameters::Union{Vector,Function} + parameters::Union{Set{Tuple{GenerativeFunction,Symbol}},Function} end function DynamicDSLFunction(arg_types::Vector{Type}, @@ -29,28 +29,19 @@ function DynamicDSLFunction(arg_types::Vector{Type}, return DynamicDSLFunction{T}(arg_types, has_defaults, arg_defaults, julia_function, - has_argument_grads, accepts_output_grad, []) + has_argument_grads, accepts_output_grad, + Set{Tuple{GenerativeFunction,Symbol}}()) end function get_parameters(gen_fn::DynamicDSLFunction, parameter_context) - if isa(gen_fn.parameters, Vector) + if isa(gen_fn.parameters, Set) julia_store = get_julia_store(parameter_context) - parameter_stores_to_ids = Dict{Any,Vector}() - parameter_ids = Tuple{GenerativeFunction,Symbol}[] - for param in gen_fn.parameters - if isa(param, Tuple{GenerativeFunction,Symbol}) - push!(parameter_ids, param) - elseif isa(param, Symbol) - push!(parameter_ids, (gen_fn, param)) - else - throw(ArgumentError("Invalid parameter declaration for DML generative function $gen_fn: $param")) - end - end - parameter_stores_to_ids[julia_store] = parameter_ids + parameter_stores_to_ids = Dict{Any,Set}(julia_store => gen_fn.parameters) return parameter_stores_to_ids elseif isa(gen_fn.parameters, Function) return gen_fn.parameters(parameter_context) end + @assert false end """ @@ -60,13 +51,25 @@ Register the trainable parameters that used by a DML generative function. This includes all parameters used within any calls made by the generative function, and includes any parameters that may be used by any possible trace (stochastic control flow may cause a parameter to be used by one trace but not another). -The second argument is either a `Vector` or a `Function` that takes a parameter context and returns a `Dict` that maps parameter stores to `Vector`s of parameter IDs. -When the second argument is a `Vector`, each element is either a `Symbol` that is the name of a parameter declared in the body of `gen_fn` using `@param`, or is a tuple `(other_gen_fn::GenerativeFunction, name::Symbol)` where `@param ` was declared in the body of `other_gen_fn`. +The second argument is either an iterable collection or a `Function` that takes a parameter context and returns a `Dict` that maps parameter stores to `Set`s of parameter IDs. 
+When the second argument is an iterable collection, each element is either a `Symbol` that is the name of a parameter declared in the body of `gen_fn` using `@param`, or is a tuple `(other_gen_fn::GenerativeFunction, name::Symbol)` where `@param ` was declared in the body of `other_gen_fn`. The `Function` input is used when `gen_fn` uses parameters that come from more than one parameter store, including parameters that are housed in parameter stores that are not `JuliaParameterStore`s (e.g. if `gen_fn` invokes a generative function that executes in another non-Julia runtime). See [Optimizing Trainable Parameters](@ref) for details on parameter contexts, and parameter stores. """ -function register_parameters!(gen_fn::DynamicDSLFunction, parameters) +function register_parameters!(gen_fn::DynamicDSLFunction, parameters::Function) gen_fn.parameters = parameters +end +function register_parameters!(gen_fn::DynamicDSLFunction, parameters) + gen_fn.parameters = Set{Tuple{GenerativeFunction,Symbol}}() + for param in parameters + if isa(param, Tuple{GenerativeFunction,Symbol}) + push!(gen_fn.parameters, param) + elseif isa(param, Symbol) + push!(gen_fn.parameters, (gen_fn, param)) + else + throw(ArgumentError("Invalid parameter declaration for DML generative function $gen_fn: $param")) + end + end return nothing end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index 6098a9012..866ca92f7 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -8,8 +8,7 @@ mutable struct GFGenerateState function GFGenerateState(gen_fn, args, constraints, parameter_context) parameter_store = get_julia_store(parameter_context) - registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}( - get_parameters(gen_fn, parameter_context)[parameter_store]) + registered_julia_parameters = get_parameters(gen_fn, parameter_context)[parameter_store] trace = DynamicDSLTrace( gen_fn, args, parameter_store, parameter_context, registered_julia_parameters) return new(trace, constraints, 0., AddressVisitor(), gen_fn, parameter_context) diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index c366176da..bc75b1ada 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -7,8 +7,7 @@ mutable struct GFSimulateState function GFSimulateState( gen_fn::GenerativeFunction, args::Tuple, parameter_context) parameter_store = get_julia_store(parameter_context) - registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}( - get_parameters(gen_fn, parameter_context)[parameter_store]) + registered_julia_parameters = get_parameters(gen_fn, parameter_context)[parameter_store] trace = DynamicDSLTrace( gen_fn, args, parameter_store, parameter_context, registered_julia_parameters) return new(trace, AddressVisitor(), gen_fn, parameter_context) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index f85b5b559..413df400b 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -124,7 +124,9 @@ get_return_type(::GenerativeFunction{T,U}) where {T,U} = T get_trace_type(::GenerativeFunction{T,U}) where {T,U} = U """ - parameters::Dict{ParameterStore,Vector} = get_parameters(gen_fn::GenerativeFunction, parameter_context) + parameters::Dict{ParameterStore,Set} = get_parameters( + gen_fn::GenerativeFunction, + parameter_context=default_parameter_context) Returns the parameters used by the generative function (including all of its calls). 
""" diff --git a/src/optimization.jl b/src/optimization.jl index 1721058a1..848a7fc00 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -538,6 +538,10 @@ The default global parameter context, which is initialized to contain the mappin const default_parameter_context = Dict{Symbol,Any}( JULIA_PARAMETER_STORE_KEY => default_julia_parameter_store) +function get_parameters(gen_fn::GenerativeFunction) + return get_parameters(gen_fn, default_parameter_context) +end + function simulate(gen_fn::GenerativeFunction, args::Tuple) return simulate(gen_fn, args, default_parameter_context) end diff --git a/src/static_ir/dag.jl b/src/static_ir/dag.jl index 190b40ba4..65c1ab7bc 100644 --- a/src/static_ir/dag.jl +++ b/src/static_ir/dag.jl @@ -205,13 +205,20 @@ function set_accepts_output_grad!(builder::StaticIRBuilder, value::Bool) end function get_parameters(ir::StaticIR, gen_fn::GenerativeFunction, parameter_context) - parameters = Dict() - for call_node in ir.call_nodes - merge!(parameters, get_parameters(call_node.generative_function, parameter_context)) - end julia_store = get_julia_store(parameter_context) + parameters = Dict(julia_store => Set{Tuple{GenerativeFunction,Symbol}}()) for param_node in ir.trainable_param_nodes - parameters[store] = (gen_fn, param_node.name) + push!(parameters[julia_store], (gen_fn, param_node.name)) + end + for call_node in ir.call_nodes + callee_parameters = get_parameters(call_node.generative_function, parameter_context) + for (store, ids::Set) in callee_parameters + if haskey(parameters, store) + union!(parameters[store], ids) + else + parameters[store] = ids + end + end end return parameters end diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 0e26f8148..f7458d6ee 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -297,6 +297,13 @@ end end register_parameters!(foo, [(bar, :theta1), :theta2]) + store_to_ids = Gen.get_parameters(foo, Gen.default_parameter_context) + @test length(store_to_ids) == 1 + ids = store_to_ids[Gen.default_julia_parameter_store] + @test length(ids) == 2 + @test (foo, :theta2) in ids + @test (bar, :theta1) in ids + init_parameter!((bar, :theta1), 0.0) init_parameter!((foo, :theta2), 0.0) From 412c1b045019dc037b59bd2b0621d0cb269b2abe Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 14:35:58 -0400 Subject: [PATCH 18/24] fix init_parameters! 
bug on resize params --- src/optimization.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/optimization.jl b/src/optimization.jl index 848a7fc00..745e05812 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -304,7 +304,10 @@ function init_parameter!( store.values[gen_fn] = Dict{Symbol,Any}() end store.values[gen_fn][name] = value - reset_gradient!(id, store) + if !haskey(store.gradient_accumulators, gen_fn) + store.gradient_accumulators[gen_fn] = Dict{Symbol,Any}() + end + store.gradient_accumulators[gen_fn][name] = Accumulator(zero(value)) return nothing end From c59bd76a25fe0632999393b8479d194185359550 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 14:42:08 -0400 Subject: [PATCH 19/24] add default get_parameters --- src/gen_fn_interface.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 413df400b..54ce4e403 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -124,13 +124,18 @@ get_return_type(::GenerativeFunction{T,U}) where {T,U} = T get_trace_type(::GenerativeFunction{T,U}) where {T,U} = U """ - parameters::Dict{ParameterStore,Set} = get_parameters( + parameters::Dict{Any,Set} = get_parameters( gen_fn::GenerativeFunction, parameter_context=default_parameter_context) Returns the parameters used by the generative function (including all of its calls). + +The parameters are returned in a `Dict` that maps parameter stores to sets of +parameter IDs within each store. """ -function get_parameters end +function get_parameters(gen_fn::GenerativeFunction, parameter_context) + return Dict{Any,Set}() +end """ bools::Tuple = has_argument_grads(gen_fn::Union{GenerativeFunction,Distribution}) From b0909ef856680da2e2d9c89e34ddf3eca08b6a40 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 14:50:05 -0400 Subject: [PATCH 20/24] fix init_optimizer ambuigity --- src/optimization.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/optimization.jl b/src/optimization.jl index 745e05812..1e5b96bff 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -213,7 +213,8 @@ end Convenience method that constructs an optimizer that updates all parameters used by the given generative function, even when the parameters exist in multiple parameter stores. 
""" -function init_optimizer(conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) +function init_optimizer( + conf::T, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) where {T} return CompositeOptimizer(conf, get_parameters(gen_fn, parameter_context)) end From 4c9a8a18cb525127b139720b579c98ab0e869caf Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 14:54:07 -0400 Subject: [PATCH 21/24] fix ambiguity --- src/optimization.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimization.jl b/src/optimization.jl index 1e5b96bff..b1b2430f4 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -186,7 +186,7 @@ struct CompositeOptimizer function CompositeOptimizer(conf, parameter_stores_to_ids) optimizers = Dict{Any,Any}() for (store, parameter_ids) in parameter_stores_to_ids - optimizers[store] = init_optimizer(conf, parameter_ids, store) + optimizers[store] = init_optimizer(conf, collect(parameter_ids), store) end new(conf, optimizers) end @@ -214,7 +214,7 @@ end Convenience method that constructs an optimizer that updates all parameters used by the given generative function, even when the parameters exist in multiple parameter stores. """ function init_optimizer( - conf::T, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) where {T} + conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) return CompositeOptimizer(conf, get_parameters(gen_fn, parameter_context)) end From 2fba354fad773017c268ab36e048c40ac45e43f8 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 16:32:58 -0400 Subject: [PATCH 22/24] add missing file for static_ir gradients tests --- test/static_ir/gradients.jl | 117 ++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 test/static_ir/gradients.jl diff --git a/test/static_ir/gradients.jl b/test/static_ir/gradients.jl new file mode 100644 index 000000000..13cd15dc8 --- /dev/null +++ b/test/static_ir/gradients.jl @@ -0,0 +1,117 @@ +@testset "backprop" begin + + #@gen (static) function bar(mu_z::Float64) + #z = @trace(normal(mu_z, 1), :z) + #return z + mu_z + #end + + # bar + builder = StaticIRBuilder() + mu_z = add_argument_node!(builder, name=:mu_z, typ=:Float64, compute_grad=true) + one = add_constant_node!(builder, 1.) + z = add_addr_node!(builder, normal, inputs=[mu_z, one], addr=:z, name=:z) + retval = add_julia_node!(builder, (z, mu_z) -> z + mu_z, inputs=[z, mu_z], name=:retval) + set_return_node!(builder, retval) + ir = build_ir(builder) + bar = eval(generate_generative_function(ir, :bar, track_diffs=false, cache_julia_nodes=false)) + + #@gen (static) function foo(mu_a::Float64) + #param theta::Float64 + #a = @trace(normal(mu_a, 1), :a) + #b = @trace(normal(a, 1), :b) + #bar = @trace(bar(a), :bar) + #c = a * b * bar * theta + #out = @trace(normal(c, 1), :out) + #return out + #end + + # foo + builder = StaticIRBuilder() + mu_a = add_argument_node!(builder, name=:mu_a, typ=:Float64, compute_grad=true) + theta = add_trainable_param_node!(builder, :theta, typ=QuoteNode(Float64)) + one = add_constant_node!(builder, 1.) 
+ a = add_addr_node!(builder, normal, inputs=[mu_a, one], addr=:a, name=:a) + b = add_addr_node!(builder, normal, inputs=[a, one], addr=:b, name=:b) + bar_val = add_addr_node!(builder, bar, inputs=[a], addr=:bar, name=:bar_val) + c = add_julia_node!(builder, (a, b, bar, theta) -> (a * b * bar * theta), + inputs=[a, b, bar_val, theta], name=:c) + retval = add_addr_node!(builder, normal, inputs=[c, one], addr=:out, name=:out) + set_return_node!(builder, retval) + ir = build_ir(builder) + foo = eval(generate_generative_function(ir, :foo, track_diffs=false, cache_julia_nodes=false)) + + Gen.load_generated_functions() + + # test get_parameters + store_to_ids = Gen.get_parameters(foo, Gen.default_parameter_context) + @test length(store_to_ids) == 1 + @test length(store_to_ids[Gen.default_julia_parameter_store]) == 1 + @test (foo, :theta) in store_to_ids[Gen.default_julia_parameter_store] + + function f(mu_a, theta, a, b, z, out) + lpdf = 0. + mu_z = a + lpdf += logpdf(normal, z, mu_z, 1) + lpdf += logpdf(normal, a, mu_a, 1) + lpdf += logpdf(normal, b, a, 1) + c = a * b * (z + mu_z) * theta + lpdf += logpdf(normal, out, c, 1) + return lpdf + 2 * out + end + + mu_a = 1. + theta = -0.5 + a = 2. + b = 3. + z = 4. + out = 5. + + # initialize the trainable parameter + init_parameter!((foo, :theta), theta) + + # get the initial trace + constraints = choicemap() + constraints[:a] = a + constraints[:b] = b + constraints[:out] = out + constraints[:bar => :z] = z + (trace, _) = generate(foo, (mu_a,), constraints) + + # compute gradients with choice_gradients + selection = select(:bar => :z, :a, :out) + selection = StaticSelection(selection) + retval_grad = 2. + ((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad) + + # check input gradient + @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) + + # check value trie + @test get_value(value_trie, :a) == a + @test get_value(value_trie, :out) == out + @test get_value(value_trie, :bar => :z) == z + @test !has_value(value_trie, :b) # was not selected + @test length(get_submaps_shallow(value_trie)) == 1 + @test length(get_values_shallow(value_trie)) == 2 + + # check gradient trie + @test length(get_submaps_shallow(gradient_trie)) == 1 + @test length(get_values_shallow(gradient_trie)) == 2 + @test !has_value(gradient_trie, :b) # was not selected + @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) + @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) + @test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx)) + + # compute gradients with accumulate_param_gradients! + retval_grad = 2. 
+ (mu_a_grad,) = accumulate_param_gradients!(trace, retval_grad) + + # check input gradient + @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) + + # check trainable parameter gradient + @test isapprox( + get_gradient((foo, :theta)), + finite_diff(f, (mu_a, theta, a, b, z, out), 2, dx)) + +end From d342a7a34c17588c1f39d4800db7ae7a51f1341f Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 14 Jul 2021 18:06:53 -0400 Subject: [PATCH 23/24] only allow registering parameters once --- src/dynamic/dynamic.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index a429dd02d..c3d818262 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -17,7 +17,7 @@ mutable struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace} julia_function::Function has_argument_grads::Vector{Bool} accepts_output_grad::Bool - parameters::Union{Set{Tuple{GenerativeFunction,Symbol}},Function} + parameters::Union{Nothing,Set{Tuple{GenerativeFunction,Symbol}},Function} end function DynamicDSLFunction(arg_types::Vector{Type}, @@ -30,7 +30,7 @@ function DynamicDSLFunction(arg_types::Vector{Type}, has_defaults, arg_defaults, julia_function, has_argument_grads, accepts_output_grad, - Set{Tuple{GenerativeFunction,Symbol}}()) + nothing) end function get_parameters(gen_fn::DynamicDSLFunction, parameter_context) @@ -57,9 +57,17 @@ The `Function` input is used when `gen_fn` uses parameters that come from more t See [Optimizing Trainable Parameters](@ref) for details on parameter contexts, and parameter stores. """ function register_parameters!(gen_fn::DynamicDSLFunction, parameters::Function) + if gen_fn.parameters !== nothing + throw(ArgumentError("parameters for $gen_fn were already registered")) + end gen_fn.parameters = parameters + return nothing end + function register_parameters!(gen_fn::DynamicDSLFunction, parameters) + if gen_fn.parameters !== nothing + throw(ArgumentError("parameters for $gen_fn were already registered")) + end gen_fn.parameters = Set{Tuple{GenerativeFunction,Symbol}}() for param in parameters if isa(param, Tuple{GenerativeFunction,Symbol}) From bf9d1247cac7e31ac77ad194c5268d738b13b2a3 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 14 Jul 2021 18:23:14 -0400 Subject: [PATCH 24/24] fix --- src/dynamic/dynamic.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index c3d818262..37f4a337d 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -40,8 +40,10 @@ function get_parameters(gen_fn::DynamicDSLFunction, parameter_context) return parameter_stores_to_ids elseif isa(gen_fn.parameters, Function) return gen_fn.parameters(parameter_context) + else + # no parameters were registered + return Dict{Any,Set}() end - @assert false end """
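
Taken together, the parameter-store primitives introduced in this series support a training loop along the following lines. This is a minimal sketch with placeholder names (`f`, `theta`, and the training constraint are illustrative), not code taken from the patches:

```julia
@gen function f(x::Float64)
    @param theta::Float64
    return @trace(normal(theta * x, 1.0), :y)
end

# parameters must be registered once, before the function is traced
register_parameters!(f, [:theta])

# set the initial value and zero the gradient accumulator in the default store
init_parameter!((f, :theta), 0.0)

# optimizer over all parameters used by `f` (here just `theta`)
optimizer = init_optimizer(FixedStepGradientDescent(1e-3), f)

constraints = choicemap()
constraints[:y] = 3.0  # observed datum for x = 2.0

for iter in 1:100
    # accumulate the gradient of the trace's log probability with respect to theta;
    # because gradient accumulation is thread-safe, several traces could be
    # processed concurrently before applying an update
    (trace, _) = generate(f, (2.0,), constraints)
    accumulate_param_gradients!(trace, nothing)

    # take a step in the direction of the accumulated gradient and reset it
    apply_update!(optimizer)
end

get_parameter_value((f, :theta))  # read back the trained value
```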