diff --git a/src/programs/JuliaPrograms.jl b/src/programs/JuliaPrograms.jl index b5708a7cc..7e5c4358d 100644 --- a/src/programs/JuliaPrograms.jl +++ b/src/programs/JuliaPrograms.jl @@ -239,6 +239,9 @@ function parse_wiring_diagram(pres::Presentation, expr::Expr)::WiringDiagram end function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::WiringDiagram + # FIXME: Presentations should be uniquely associated with syntax systems. + syntax_module = Syntax.syntax_module(first(pres.generators)) + # Parse argument names and types from call expression. call_args = @match call begin Expr(:call, [name, args...]) => args @@ -249,13 +252,11 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri end parsed_args = map(call_args) do arg @match arg begin - Expr(:(::), [name::Symbol, type::Symbol]) => (name, type) - _ => error("Argument $arg is not simply typed") + Expr(:(::), [name::Symbol, type_expr::Expr0]) => + (name, eval_type_expr(pres, syntax_module, type_expr)) + _ => error("Argument $arg is missing name or type") end end - - # FIXME: Presentations should be uniquely associated with syntax systems. - syntax_module = Syntax.syntax_module(first(pres.generators)) # Compile... args = first.(parsed_args) @@ -265,14 +266,15 @@ function parse_wiring_diagram(pres::Presentation, call::Expr0, body::Expr)::Wiri func = eval(func_expr) # ...and then evaluate function that records the function calls. - arg_types = last.(parsed_args) - arg_obs = syntax_module.Ob[ generator(pres, name) for name in arg_types ] + arg_obs = syntax_module.Ob[ last(arg) for arg in parsed_args ] + arg_blocks = Int[ length(to_wiring_diagram(ob)) for ob in arg_obs ] inputs = to_wiring_diagram(otimes(arg_obs)) diagram = WiringDiagram(inputs, munit(typeof(inputs))) v_in, v_out = input_id(diagram), output_id(diagram) - in_ports = [ Port(v_in, OutputPort, i) for i in eachindex(inputs) ] + arg_ports = [ Tuple(Port(v_in, OutputPort, i) for i in (stop-len+1):stop) + for (len, stop) in zip(arg_blocks, cumsum(arg_blocks)) ] recorder = f -> (args...) -> record_call!(diagram, f, args...) - value = invokelatest(func, recorder, in_ports...; kwargs...) + value = invokelatest(func, recorder, arg_ports...; kwargs...) # Add outgoing wires for return values. out_ports = normalize_arguments((value,)) @@ -304,6 +306,20 @@ function make_lookup_table(pres::Presentation, syntax_module::Module, names) table end +""" Evaluate pseudo-Julia type expression, such as `X` or `otimes{X,Y}`. +""" +function eval_type_expr(pres::Presentation, syntax_module::Module, expr::Expr0) + function eval(expr) + @match expr begin + Expr(:curly, [name, args...]) => + invoke_term(syntax_module, name, map(eval, args)...) + name::Symbol => generator(pres, name) + _ => error("Invalid type expression $expr") + end + end + eval(expr) +end + """ Generate a Julia function expression that will record function calls. Rewrites the function body so that: diff --git a/test/programs/JuliaPrograms.jl b/test/programs/JuliaPrograms.jl index 4cfdd2102..87a632444 100644 --- a/test/programs/JuliaPrograms.jl +++ b/test/programs/JuliaPrograms.jl @@ -114,7 +114,7 @@ parsed = @parse_wiring_diagram(C, (x::X, y::Y) -> n([m(x,y), m(x,y)])) parsed = @parse_wiring_diagram(C, () -> f([])) @test parsed == to_wiring_diagram(compose(create(X),f)) -# Special morphisms: explicit syntax. +# Explicit syntax for special objects and morphisms. parsed = @parse_wiring_diagram(C, (x::X) -> id{X}(x)) @test parsed == to_wiring_diagram(id(X)) @@ -137,6 +137,20 @@ parsed = @parse_wiring_diagram(C, (x::X, y::Y) -> mcopy{otimes{X,Y}}(x,y)) parsed = @parse_wiring_diagram(C, (x::X, y::Y) -> delete{otimes{X,Y}}(x,y)) @test parsed == to_wiring_diagram(delete(otimes(X,Y))) +parsed = @parse_wiring_diagram(C, (xy::otimes{X,Y}) -> m(xy)) +@test parsed == to_wiring_diagram(m) + +parsed = @parse_wiring_diagram C (xy::otimes{X,Y}) begin + x, y = xy + (f(x), g(y)) +end +@test parsed == to_wiring_diagram(otimes(f,g)) + +parsed = @parse_wiring_diagram C (xy::otimes{X,Y}, wz::otimes{W,Z}) begin + (m(xy), n(wz)) +end +@test parsed == to_wiring_diagram(otimes(m,n)) + # Helper function: normalization of arguments. normalize(args...) = normalize_arguments(Tuple(args))