diff --git a/Project.toml b/Project.toml index f776185..3351c90 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.0.0-DEV" [deps] InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" diff --git a/src/join_nodes.jl b/src/join_nodes.jl index fab528e..84f009f 100644 --- a/src/join_nodes.jl +++ b/src/join_nodes.jl @@ -1,5 +1,6 @@ using Base.Iterators: flatten +using IterTools: repeatedly export JoinNode @@ -12,15 +13,14 @@ to assert. """ struct JoinNode <: AbstractReteJoinNode label::String - a_inputs # ::Set{AbstractMemoryNode{T1}} - b_inputs # ::Set{AbstractMemoryNode{T2}} + inputs outputs::Set{AbstractReteNode} join_function - JoinNode(label::String, join_function) = + JoinNode(label::String, input_arity, join_function) = new(label, - Set{AbstractMemoryNode}(), - Set{AbstractMemoryNode}(), + Tuple(repeatedly(() -> Set{AbstractMemoryNode}(), + input_arity)), Set{AbstractReteNode}(), join_function) end @@ -35,26 +35,36 @@ label(n::JoinNode) = n.label function connect(from::AbstractReteNode, to::JoinNode, input::Int) @assert input >= 1 - @assert input <= 2 + @assert input <= length(to.inputs) push!(from.outputs, to) - if input == 1 - push!(to.a_inputs, from) - else - push!(to.b_inputs, from) - end + push!(to.inputs[input], from) end function receive(node::JoinNode, fact, from::AbstractMemoryNode) - if from in node.a_inputs - askc(node.b_inputs) do b_fact - node.join_function(node, fact, b_fact) - end - end - if from in node.b_inputs - askc(node.a_inputs) do a_fact - node.join_function(node, a_fact, fact) + args = Vector(undef, length(node.inputs)) + last_from_pos = findlast(map(i -> from in i, node.inputs)) + function helper(argnumber, hasfact) + if argnumber > length(args) + if hasfact + node.join_function(node, args...) + end + else + # Avoid computing more of the power set of arguments if + # we've not added fact and we've passed the last set of + # inputs that could contain from. + if !hasfact && argnumber > last_from_pos + return + end + for input in node.inputs[argnumber] + askc(input) do i_fact + args[argnumber] = i_fact + helper(argnumber + 1, + hasfact || (i_fact == fact)) + end + end end end + helper(1, false) end diff --git a/src/memory_nodes.jl b/src/memory_nodes.jl index 076d2e7..3fcbfa6 100644 --- a/src/memory_nodes.jl +++ b/src/memory_nodes.jl @@ -43,10 +43,10 @@ function receive(node::IsaMemoryNode{T}, fact::T) where{T} if fact in node.memory return end + push!(node.memory, fact) for output in node.outputs emit(node, output, fact) end - push!(node.memory, fact) end diff --git a/test/runtests.jl b/test/runtests.jl index 91ef604..6a17962 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,7 +39,7 @@ end ints = IsaMemoryNode{Int}() connect(root, chars) connect(root, ints) - join = JoinNode("join char int", + join = JoinNode("join char int", 2, function(node, c, i) emit(node, "$c$i") end) @@ -54,15 +54,17 @@ end for i in 1:3 receive(root, i) end - println(conclusions.memory) - @test conclusions.memory == Set{String}([ - "a1", "b1", "c1", "a2", "b2", "c2", "a3", "b3", "c3"]) + results = collecting() do c + askc(c, conclusions) + end + @test sort(results) == + sort(["a1", "b1", "c1", "a2", "b2", "c2", "a3", "b3", "c3"]) end @testset "symetric join test" begin root = BasicReteNode("root") ints = IsaMemoryNode{Int}() - join = JoinNode("join", + join = JoinNode("join", 2, function(node, a, b) if b == a + 1 emit(node, (a, b)) @@ -77,9 +79,41 @@ end for i in 1:5 receive(root, i) end - @test conclusions.memory == - Set{Tuple{Int, Int}}([ - (1, 2), (2, 3), (3, 4), (4, 5)]) + results = collecting() do c + askc(c, conclusions) + end + @test sort(results) == + sort([(1, 2), (2, 3), (3, 4), (4, 5)]) +end + +@testset "3-ary join" begin + root = BasicReteNode("root") + ints = IsaMemoryNode{Int}() + chars = IsaMemoryNode{Char}() + join = JoinNode("join", 3, + function(node, a, b, c) + if a != c + emit(node, "$a$b$c") + end + end) + conclusions = IsaMemoryNode{String}() + connect(root, ints) + connect(root, chars) + connect(root, conclusions) + connect(ints, join, 1) + connect(chars, join, 2) + connect(ints, join, 3) + connect(join, root) + for i in 1:2 + receive(root, i) + end + for c in 'a':'b' + receive(root, c) + end + results = collecting() do c + askc(c, conclusions) + end + @test sort(results) == sort(["1a2", "2a1", "1b2", "2b1"]) end @testset "ensure_IsaMemoryNode" begin