Skip to content

Commit

Permalink
ENH: Allow compound argument types when parsing wiring diagrams.
Browse files Browse the repository at this point in the history
  • Loading branch information
epatters committed Jan 28, 2020
1 parent 223fa72 commit 95537a6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
34 changes: 25 additions & 9 deletions src/programs/JuliaPrograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ function parse_wiring_diagram(pres::Presentation, expr::Expr)::WiringDiagram
end

function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::WiringDiagram
# FIXME: Presentations should be uniquely associated with syntax systems.
syntax_module = Syntax.syntax_module(first(pres.generators))

# Parse argument names and types from call expression.
call_args = @match call begin
Expr(:call, [name, args...]) => args
Expand All @@ -249,13 +252,11 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri
end
parsed_args = map(call_args) do arg
@match arg begin
Expr(:(::), [name::Symbol, type::Symbol]) => (name, type)
_ => error("Argument $arg is not simply typed")
Expr(:(::), [name::Symbol, type_expr::Expr0]) =>
(name, eval_type_expr(pres, syntax_module, type_expr))
_ => error("Argument $arg is missing name or type")
end
end

# FIXME: Presentations should be uniquely associated with syntax systems.
syntax_module = Syntax.syntax_module(first(pres.generators))

# Compile...
args = first.(parsed_args)
Expand All @@ -265,14 +266,15 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri
func = eval(func_expr)

# ...and then evaluate function that records the function calls.
arg_types = last.(parsed_args)
arg_obs = syntax_module.Ob[ generator(pres, name) for name in arg_types ]
arg_obs = syntax_module.Ob[ last(arg) for arg in parsed_args ]
arg_blocks = Int[ length(to_wiring_diagram(ob)) for ob in arg_obs ]
inputs = to_wiring_diagram(otimes(arg_obs))
diagram = WiringDiagram(inputs, munit(typeof(inputs)))
v_in, v_out = input_id(diagram), output_id(diagram)
in_ports = [ Port(v_in, OutputPort, i) for i in eachindex(inputs) ]
arg_ports = [ Tuple(Port(v_in, OutputPort, i) for i in (stop-len+1):stop)
for (len, stop) in zip(arg_blocks, cumsum(arg_blocks)) ]
recorder = f -> (args...) -> record_call!(diagram, f, args...)
value = invokelatest(func, recorder, in_ports...; kwargs...)
value = invokelatest(func, recorder, arg_ports...; kwargs...)

# Add outgoing wires for return values.
out_ports = normalize_arguments((value,))
Expand Down Expand Up @@ -304,6 +306,20 @@ function make_lookup_table(pres::Presentation, syntax_module::Module, names)
table
end

""" Evaluate pseudo-Julia type expression, such as `X` or `otimes{X,Y}`.
"""
function eval_type_expr(pres::Presentation, syntax_module::Module, expr::Expr0)
function eval(expr)
@match expr begin
Expr(:curly, [name, args...]) =>
invoke_term(syntax_module, name, map(eval, args)...)
name::Symbol => generator(pres, name)
_ => error("Invalid type expression $expr")
end
end
eval(expr)
end

""" Generate a Julia function expression that will record function calls.
Rewrites the function body so that:
Expand Down
16 changes: 15 additions & 1 deletion test/programs/JuliaPrograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ parsed = @parse_wiring_diagram(C, (x::X, y::Y) -> n([m(x,y), m(x,y)]))
parsed = @parse_wiring_diagram(C, () -> f([]))
@test parsed == to_wiring_diagram(compose(create(X),f))

# Special morphisms: explicit syntax.
# Explicit syntax for special objects and morphisms.

parsed = @parse_wiring_diagram(C, (x::X) -> id{X}(x))
@test parsed == to_wiring_diagram(id(X))
Expand All @@ -137,6 +137,20 @@ parsed = @parse_wiring_diagram(C, (x::X, y::Y) -> mcopy{otimes{X,Y}}(x,y))
parsed = @parse_wiring_diagram(C, (x::X, y::Y) -> delete{otimes{X,Y}}(x,y))
@test parsed == to_wiring_diagram(delete(otimes(X,Y)))

parsed = @parse_wiring_diagram(C, (xy::otimes{X,Y}) -> m(xy))
@test parsed == to_wiring_diagram(m)

parsed = @parse_wiring_diagram C (xy::otimes{X,Y}) begin
x, y = xy
(f(x), g(y))
end
@test parsed == to_wiring_diagram(otimes(f,g))

parsed = @parse_wiring_diagram C (xy::otimes{X,Y}, wz::otimes{W,Z}) begin
(m(xy), n(wz))
end
@test parsed == to_wiring_diagram(otimes(m,n))

# Helper function: normalization of arguments.

normalize(args...) = normalize_arguments(Tuple(args))
Expand Down

0 comments on commit 95537a6

Please sign in to comment.