From 5b4e6a576e3433b994cd5a67629891e5456da151 Mon Sep 17 00:00:00 2001 From: MarkNahabedian Date: Sat, 4 Jan 2025 11:39:09 -0500 Subject: [PATCH] Clearer thinking: is_memory_for_type should use == rather than <:. --- src/memory_nodes.jl | 11 ++++++++--- src/node_abstraction.jl | 3 +++ test/rule_example_2.jl | 6 ++++-- test/runtests.jl | 2 +- test/test_thing_subthing.jl | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 test/test_thing_subthing.jl diff --git a/src/memory_nodes.jl b/src/memory_nodes.jl index 58e68a8..66e1bd4 100644 --- a/src/memory_nodes.jl +++ b/src/memory_nodes.jl @@ -44,11 +44,12 @@ label(node::IsaMemoryNode{T}) where {T} = "isa $T memory" is_memory_for_type(node, typ::Type)::Bool returns `true` if `node` stores objects of the specified type. +`typ` must batch the type stured by `node`, not merely be a subtype. Used by [`find_memory_for_type`](@ref). """ is_memory_for_type(node::IsaMemoryNode, typ::Type)::Bool = - typeof(node).parameters[1] <: typ + typeof(node).parameters[1] == typ Rete.is_memory_for_type(::AbstractMemoryNode, ::Any) = false @@ -77,8 +78,12 @@ end """ askc(continuation, root::AbstractReteRootNode, t::Type) -calls `continuation` on every fact of the specified type (or its -subtypes) that are stored in the network rooted at `root`. +calls `continuation` on every fact of the specified type that are +stored in the network rooted at `root`. + +Does not consider subtypes because that could lead to `continuation` +being called on the same fact more than once (from the memory node for +the type itself and from the memory nodes of subtypes). Assumes all memory nodes are direct outputs of `root`. diff --git a/src/node_abstraction.jl b/src/node_abstraction.jl index 383d888..d2877cf 100644 --- a/src/node_abstraction.jl +++ b/src/node_abstraction.jl @@ -22,6 +22,9 @@ determine if it stores that type of fact. A memory node should remember exactly one copy of each fact it receives and return each fact it has remembered exactly once for any given call to [`askc`](@ref). + +A memory node should only remember facts which match the type that the +memory node is defined to store. Not any of its subtypes. """ abstract type AbstractMemoryNode <: AbstractReteNode end diff --git a/test/rule_example_2.jl b/test/rule_example_2.jl index 6ced910..cd4feec 100644 --- a/test/rule_example_2.jl +++ b/test/rule_example_2.jl @@ -40,8 +40,10 @@ inputs \toutputs \tfacts \tlabel sort(["abc", "bcd", "cde", "def", "efg"]) # Test askc for subtypes: count_by_type = DefaultDict{Type, Int}(0) - askc(root, Any) do fact - count_by_type[typeof(fact)] += 1 + for o in root.outputs + askc(o) do fact + count_by_type[typeof(fact)] += 1 + end end @test Set(collect(count_by_type)) == Set([Char => 7, diff --git a/test/runtests.jl b/test/runtests.jl index 371e148..ef2c491 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -176,4 +176,4 @@ include("rule_grouping.jl") include("three_parameter_rule_example.jl") include("test_rule_decls.jl") include("test_copy_facts.jl") - +include("test_thing_subthing.jl") diff --git a/test/test_thing_subthing.jl b/test/test_thing_subthing.jl new file mode 100644 index 0000000..0e6f3ce --- /dev/null +++ b/test/test_thing_subthing.jl @@ -0,0 +1,37 @@ + +abstract type AbstractThing end + +struct Thing1 <: AbstractThing + c::Char +end + +struct Thing2 <: AbstractThing + c::Char +end + +@rule JuxtaposeAbstractThingsRule(t1::AbstractThing, t2::AbstractThing, ::Tuple) begin + emit((t1.c, t2.c)) +end + +@rule JuxtaposeThings12Rule(t1::Thing1, t2::Thing2, ::String) begin + emit("$(t1.c)$(t2.c)") +end + +@testset "test thing/subthing" begin + root = ReteRootNode("root") + install(root, JuxtaposeAbstractThingsRule) + install(root, JuxtaposeThings12Rule) + for c in 'a':'c' + receive(root, Thing1(c)) + end + for c in 'd':'e' + receive(root, Thing2(c)) + end + @test askc(Counter(), root, Thing1) == 3 + @test askc(Counter(), root, Thing2) == 2 + @test askc(Counter(), root, AbstractThing) == 5 + @test askc(Counter(), root, String) == 6 + @test askc(Counter(), root, Tuple) == 25 +end + +