Skip to content

Commit

Permalink
tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoct committed May 18, 2021
1 parent d6f636c commit 3fc5a0e
Show file tree
Hide file tree
Showing 39 changed files with 404 additions and 230 deletions.
4 changes: 2 additions & 2 deletions src/dynamic/assess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U},
end

function assess(
gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap;
parameter_context=default_parameter_context)
gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap,
parameter_context::Dict)
state = GFAssessState(gen_fn, choices, parameter_context)
retval = exec(gen_fn, state, args)

Expand Down
5 changes: 4 additions & 1 deletion src/dynamic/backprop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ mutable struct GFBackpropParamsState
end
end

function read_param(state::GFBackpropParamsState, name::Symbol)
function read_param!(state::GFBackpropParamsState, name::Symbol)
parameter_id = (state.active_gen_fn, name)
if !(parameter_id in state.trace.registered_julia_parameters)
throw(ArgumentError("parameter $parameter_id was not registered using register_parameters!"))
end
return state.tracked_params[parameter_id]
end

Expand Down
63 changes: 51 additions & 12 deletions src/dynamic/dynamic.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export register_parameters!

include("trace.jl")

"""
Expand All @@ -8,13 +10,14 @@ A generative function based on a shallowly embedding modeling language based on
Constructed using the `@gen` keyword.
Most methods in the generative function interface involve a end-to-end execution of the function.
"""
struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
mutable struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
arg_types::Vector{Type}
has_defaults::Bool
arg_defaults::Vector{Union{Some{Any},Nothing}}
julia_function::Function
has_argument_grads::Vector{Bool}
accepts_output_grad::Bool
parameters::Union{Vector,Function}
end

function DynamicDSLFunction(arg_types::Vector{Type},
Expand All @@ -26,29 +29,65 @@ function DynamicDSLFunction(arg_types::Vector{Type},
return DynamicDSLFunction{T}(arg_types,
has_defaults, arg_defaults,
julia_function,
has_argument_grads, accepts_output_grad)
has_argument_grads, accepts_output_grad, [])
end

function Base.show(io::IO, gen_fn::DynamicDSLFunction)
return "Gen DML generative function: $(gen_fn.julia_function)"
function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
if isa(gen_fn.parameters, Vector)
julia_store = get_julia_store(parameter_context)
parameter_stores_to_ids = Dict{Any,Vector}()
parameter_ids = Tuple{GenerativeFunction,Symbol}[]
for param in gen_fn.parameters
if isa(param, Tuple{GenerativeFunction,Symbol})
push!(parameter_ids, param)
elseif isa(param, Symbol)
push!(parameter_ids, (gen_fn, param))
else
throw(ArgumentError("Invalid parameter declaration for DML generative function $gen_fn: $param"))
end
end
parameter_stores_to_ids[julia_store] = parameter_ids
return parameter_stores_to_ids
elseif isa(gen_fn.parameters, Function)
return gen_fn.parameters(parameter_context)
end
end

function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
return "Gen DML generative function: $(gen_fn.julia_function)"
"""
register_parameters!(gen_fn::DynamicDSLFunction, parameters)

Register the altrainable parameters that are used by a DML generative function.

This includes all parameters used within any calls made by the generative function.

There are two variants:

# TODO document the variants
"""
function register_parameters!(gen_fn::DynamicDSLFunction, parameters)
gen_fn.parameters = parameters
return nothing
end

function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
# TODO for this, we need to walk the code... (and throw errors when the
function Base.show(io::IO, gen_fn::DynamicDSLFunction)
print(io, "Gen DML generative function: $(gen_fn.julia_function)")
end

function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
print(io, "Gen DML generative function: $(gen_fn.julia_function)")
end

function DynamicDSLTrace(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T<:DynamicDSLFunction}
function DynamicDSLTrace(
gen_fn::T, args, parameter_store::JuliaParameterStore,
parameter_context, registered_julia_parameters) where {T<:DynamicDSLFunction}
# pad args with default values, if available
if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults)
defaults = gen_fn.arg_defaults[length(args)+1:end]
defaults = map(x -> something(x), defaults)
args = Tuple(vcat(collect(args), defaults))
end
return DynamicDSLTrace{T}(gen_fn, args, parameter_store)
return DynamicDSLTrace{T}(
gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
end

accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad
Expand Down Expand Up @@ -118,10 +157,10 @@ end
function dynamic_param_impl(expr::Expr)
@assert expr.head == :genparam "Not a Gen param expression."
name = expr.args[1]
Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param), state, QuoteNode(name)))
Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param!), state, QuoteNode(name)))
end

function read_param(state, name::Symbol)
function read_param!(state, name::Symbol)
parameter_id = get_parameter_id(state, name)
store = get_parameter_store(state)
return get_parameter_value(parameter_id, store)
Expand Down
19 changes: 10 additions & 9 deletions src/dynamic/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ mutable struct GFGenerateState

function GFGenerateState(gen_fn, args, constraints, parameter_context)
parameter_store = get_julia_store(parameter_context)
trace = DynamicDSLTrace(gen_fn, args, parameter_store)
registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}(
get_parameters(gen_fn, parameter_context)[parameter_store])
trace = DynamicDSLTrace(
gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
return new(trace, constraints, 0., AddressVisitor(), gen_fn, parameter_context)
end
end
Expand All @@ -23,7 +26,6 @@ function set_active_gen_fn!(state::GFGenerateState, gen_fn::GenerativeFunction)
state.active_gen_fn = gen_fn
end


function traceat(state::GFGenerateState, dist::Distribution{T},
args, key) where {T}
local retval::T
Expand Down Expand Up @@ -53,7 +55,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T},
state.weight += score
end

retval
return retval
end

function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},
Expand All @@ -69,8 +71,7 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},

# get subtrace
(subtrace, weight) = generate(
gen_fn, args, constraints;
parameter_context=state.parameter_context)
gen_fn, args, constraints, state.parameter_context)

# add to the trace
add_call!(state.trace, key, subtrace)
Expand All @@ -81,14 +82,14 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},
# get return value
retval = get_retval(subtrace)

retval
return retval
end

function generate(
gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap;
parameter_context=default_parameter_context)
gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap,
parameter_context::Dict)
state = GFGenerateState(gen_fn, args, constraints, parameter_context)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
(state.trace, state.weight)
return (state.trace, state.weight)
end
8 changes: 3 additions & 5 deletions src/dynamic/propose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function traceat(state::GFProposeState, dist::Distribution{T},
# update weight
state.weight += logpdf(dist, retval, args...)

retval
return retval
end

function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U},
Expand All @@ -56,12 +56,10 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U},
# update weight
state.weight += weight

retval
return retval
end

function propose(
gen_fn::DynamicDSLFunction, args::Tuple;
parameter_context=default_parameter_context)
function propose(gen_fn::DynamicDSLFunction, args::Tuple, parameter_context::Dict)
state = GFProposeState(gen_fn, parameter_context)
retval = exec(gen_fn, state, args)
return (state.choices, state.weight, retval)
Expand Down
9 changes: 4 additions & 5 deletions src/dynamic/regenerate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ mutable struct GFRegenerateState

function GFRegenerateState(gen_fn, args, prev_trace, selection)
visitor = AddressVisitor()
trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace))
return new(prev_trace, trace, selection,
0., visitor, gen_fn)
trace = initialize_from(prev_trace, args)
return new(prev_trace, trace, selection, 0.0, visitor, gen_fn)
end
end

Expand All @@ -24,7 +23,6 @@ function set_active_gen_fn!(state::GFRegenerateState, gen_fn::GenerativeFunction
state.active_gen_fn = gen_fn
end


function traceat(state::GFRegenerateState, dist::Distribution{T},
args, key) where {T}
local prev_retval::T
Expand Down Expand Up @@ -88,7 +86,8 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U},
(subtrace, weight, _) = regenerate(
prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
else
(subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap())
(subtrace, weight) = generate(
gen_fn, args, EmptyChoiceMap(), state.trace.parameter_context)
end

# update weight
Expand Down
11 changes: 6 additions & 5 deletions src/dynamic/simulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ mutable struct GFSimulateState
function GFSimulateState(
gen_fn::GenerativeFunction, args::Tuple, parameter_context)
parameter_store = get_julia_store(parameter_context)
trace = DynamicDSLTrace(gen_fn, args, parameter_store)
registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}(
get_parameters(gen_fn, parameter_context)[parameter_store])
trace = DynamicDSLTrace(
gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
return new(trace, AddressVisitor(), gen_fn, parameter_context)
end
end
Expand Down Expand Up @@ -49,7 +52,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U},
visit!(state.visitor, key)

# get subtrace
subtrace = simulate(gen_fn, args; parameter_context=state.parameter_context)
subtrace = simulate(gen_fn, args, state.parameter_context)

# add to the trace
add_call!(state.trace, key, subtrace)
Expand All @@ -60,9 +63,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U},
retval
end

function simulate(
gen_fn::DynamicDSLFunction, args::Tuple;
parameter_context=default_parameter_context)
function simulate(gen_fn::DynamicDSLFunction, args::Tuple, parameter_context::Dict)
state = GFSimulateState(gen_fn, args, parameter_context)
retval = exec(gen_fn, state, args)
set_retval!(state.trace, retval)
Expand Down
17 changes: 15 additions & 2 deletions src/dynamic/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,27 @@ mutable struct DynamicDSLTrace{T} <: Trace
noise::Float64
args::Tuple
parameter_store::JuliaParameterStore
parameter_context::Dict
registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}} # for runtime cross-check
retval::Any
function DynamicDSLTrace{T}(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T}
function DynamicDSLTrace{T}(
gen_fn::T, args, parameter_store::JuliaParameterStore, parameter_context,
registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}}) where {T}
trie = Trie{Any,ChoiceOrCallRecord}()
# retval is not known yet
new(gen_fn, trie, true, 0, 0, args, parameter_store)
new(
gen_fn, trie, true, 0, 0, args, parameter_store,
parameter_context, registered_julia_parameters)
end
end

function initialize_from(other::DynamicDSLTrace, args)
gen_fn = get_gen_fn(other)
return DynamicDSLTrace(
gen_fn, args, other.parameter_store, other.parameter_context,
other.registered_julia_parameters)
end

get_parameter_store(trace::DynamicDSLTrace) = trace.parameter_store

set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval)
Expand Down
10 changes: 5 additions & 5 deletions src/dynamic/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ mutable struct GFUpdateState
function GFUpdateState(gen_fn, args, prev_trace, constraints)
visitor = AddressVisitor()
discard = choicemap()
trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace))
return new(prev_trace, trace, constraints,
0., visitor, discard, gen_fn)
parameter_store = get_parameter_store(prev_trace)
trace = initialize_from(prev_trace, args)
return new(prev_trace, trace, constraints, 0.0, visitor, discard, gen_fn)
end
end

Expand All @@ -26,7 +26,6 @@ function set_active_gen_fn!(state::GFUpdateState, gen_fn::GenerativeFunction)
state.active_gen_fn = gen_fn
end


function traceat(state::GFUpdateState, dist::Distribution{T},
args::Tuple, key) where {T}

Expand Down Expand Up @@ -101,7 +100,8 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U},
(subtrace, weight, _, discard) = update(prev_subtrace,
args, map((_) -> UnknownChange(), args), constraints)
else
(subtrace, weight) = generate(gen_fn, args, constraints)
(subtrace, weight) = generate(
gen_fn, args, constraints, state.trace.parameter_context)
end

# update the weight
Expand Down
Loading

0 comments on commit 3fc5a0e

Please sign in to comment.