Skip to content

Commit

Permalink
N-ary join
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkNahabedian committed Mar 14, 2024
1 parent 1485424 commit 6fc92d9
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 28 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.0.0-DEV"

[deps]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"

Expand Down
48 changes: 29 additions & 19 deletions src/join_nodes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

using Base.Iterators: flatten
using IterTools: repeatedly

export JoinNode

Expand All @@ -12,15 +13,14 @@ to assert.
"""
struct JoinNode <: AbstractReteJoinNode
label::String
a_inputs # ::Set{AbstractMemoryNode{T1}}
b_inputs # ::Set{AbstractMemoryNode{T2}}
inputs
outputs::Set{AbstractReteNode}
join_function

JoinNode(label::String, join_function) =
JoinNode(label::String, input_arity, join_function) =
new(label,
Set{AbstractMemoryNode}(),
Set{AbstractMemoryNode}(),
Tuple(repeatedly(() -> Set{AbstractMemoryNode}(),
input_arity)),
Set{AbstractReteNode}(),
join_function)
end
Expand All @@ -35,26 +35,36 @@ label(n::JoinNode) = n.label

function connect(from::AbstractReteNode, to::JoinNode, input::Int)
@assert input >= 1
@assert input <= 2
@assert input <= length(to.inputs)
push!(from.outputs, to)
if input == 1
push!(to.a_inputs, from)
else
push!(to.b_inputs, from)
end
push!(to.inputs[input], from)
end


function receive(node::JoinNode, fact, from::AbstractMemoryNode)
if from in node.a_inputs
askc(node.b_inputs) do b_fact
node.join_function(node, fact, b_fact)
end
end
if from in node.b_inputs
askc(node.a_inputs) do a_fact
node.join_function(node, a_fact, fact)
args = Vector(undef, length(node.inputs))
last_from_pos = findlast(map(i -> from in i, node.inputs))
function helper(argnumber, hasfact)
if argnumber > length(args)
if hasfact
node.join_function(node, args...)
end
else
# Avoid computing more of the power set of arguments if
# we've not added fact and we've passed the last set of
# inputs that could contain from.
if !hasfact && argnumber > last_from_pos
return
end
for input in node.inputs[argnumber]
askc(input) do i_fact
args[argnumber] = i_fact
helper(argnumber + 1,
hasfact || (i_fact == fact))
end
end
end
end
helper(1, false)
end

2 changes: 1 addition & 1 deletion src/memory_nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ function receive(node::IsaMemoryNode{T}, fact::T) where{T}
if fact in node.memory
return
end
push!(node.memory, fact)
for output in node.outputs
emit(node, output, fact)
end
push!(node.memory, fact)
end


Expand Down
50 changes: 42 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end
ints = IsaMemoryNode{Int}()
connect(root, chars)
connect(root, ints)
join = JoinNode("join char int",
join = JoinNode("join char int", 2,
function(node, c, i)
emit(node, "$c$i")
end)
Expand All @@ -54,15 +54,17 @@ end
for i in 1:3
receive(root, i)
end
println(conclusions.memory)
@test conclusions.memory == Set{String}([
"a1", "b1", "c1", "a2", "b2", "c2", "a3", "b3", "c3"])
results = collecting() do c
askc(c, conclusions)
end
@test sort(results) ==
sort(["a1", "b1", "c1", "a2", "b2", "c2", "a3", "b3", "c3"])
end

@testset "symetric join test" begin
root = BasicReteNode("root")
ints = IsaMemoryNode{Int}()
join = JoinNode("join",
join = JoinNode("join", 2,
function(node, a, b)
if b == a + 1
emit(node, (a, b))
Expand All @@ -77,9 +79,41 @@ end
for i in 1:5
receive(root, i)
end
@test conclusions.memory ==
Set{Tuple{Int, Int}}([
(1, 2), (2, 3), (3, 4), (4, 5)])
results = collecting() do c
askc(c, conclusions)
end
@test sort(results) ==
sort([(1, 2), (2, 3), (3, 4), (4, 5)])
end

@testset "3-ary join" begin
root = BasicReteNode("root")
ints = IsaMemoryNode{Int}()
chars = IsaMemoryNode{Char}()
join = JoinNode("join", 3,
function(node, a, b, c)
if a != c
emit(node, "$a$b$c")
end
end)
conclusions = IsaMemoryNode{String}()
connect(root, ints)
connect(root, chars)
connect(root, conclusions)
connect(ints, join, 1)
connect(chars, join, 2)
connect(ints, join, 3)
connect(join, root)
for i in 1:2
receive(root, i)
end
for c in 'a':'b'
receive(root, c)
end
results = collecting() do c
askc(c, conclusions)
end
@test sort(results) == sort(["1a2", "2a1", "1b2", "2b1"])
end

@testset "ensure_IsaMemoryNode" begin
Expand Down

0 comments on commit 6fc92d9

Please sign in to comment.