Skip to content

Commit

Permalink
partial abstraction, merging+substituting variables
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Brown committed Oct 14, 2023
1 parent 8f785c7 commit 909ba03
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 88 deletions.
54 changes: 28 additions & 26 deletions src/categorical_algebra/CSets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,6 @@ using StructEquality

# Morphism search
#################

"""
Given a span of morphisms, we seek to find morphisms B → C that make a
commuting triangle if possible.
B
g ↗ ↘ ?
A ⟶ C
f
Accepts homomorphism search keyword arguments to constrain the Hom(B,C) search.
"""
function extend_morphisms(f::ACSetTransformation, g::ACSetTransformation;
initial=Dict(), kw...
)::Vector{ACSetTransformation}
init = combine_dicts!(extend_morphism_constraints(f,g), initial)
isnothing(init) ? [] : homomorphisms(codom(g), codom(f); initial=init, kw...)
end

"""
Combine a user-specified initial dict with that generated by constraints
Expand Down Expand Up @@ -357,32 +339,52 @@ end

"""
Given a value for each variable, create a morphism X → X′ which applies the
substitution. We do this via pushout.
substitution. We do this via pushout.
O --> X where C has AttrVars for `merge` equivalence classes
↓ and O has only AttrVars (sent to concrete values or eq classes
C in the map to C.
`subs` and `merge` are dictionaries keyed by attrtype names
`subs` is a dictionary (keyed by attrtype names) of int-keyed dictionaries
`subs` values are int-keyed dictionaries indicating binding, e.g.
`; subs = (Weight = Dict(1 => 3.20, 5 => 2.32), ...)`
`merge` values are vectors of vectors indicating equivalence classes, e.g.
`; merge = (Weight = [[2,3], [4,6]], ...)`
"""
function sub_vars(X::ACSet, subs::AbstractDict)
function sub_vars(X::ACSet, subs::AbstractDict=Dict(), merge::AbstractDict=Dict())
S = acset_schema(X)
O, C = [constructor(X)() for _ in 1:2]
ox_, oc_ = Dict(), Dict()
ox_, oc_ = Dict{Symbol, Any}(), Dict{Symbol,Any}()
for at in attrtypes(S)
d = get(subs, at, Dict())
ox_[at] = AttrVar.(filter(p->p keys(d) && !(d[p] isa AttrVar), parts(X,at)))
oc_[at] = [d[p.val] for p in ox_[at]]
oc_[at] = Any[d[p.val] for p in ox_[at]]
add_parts!(O, at, length(oc_[at]))
end

for eq in get(merge, at, [])
isempty(eq) && error("Cannot have empty eq class")
c = AttrVar(add_part!(C, at))
for var in eq
add_part!(O, at)
push!(ox_[at], AttrVar(var))
push!(oc_[at], c)
end
end
end
ox = ACSetTransformation(O,X; ox_...)
oc = ACSetTransformation(O,C; oc_...)
return first(legs(pushout(ox, oc)))
end


# TODO replace with CSetTransformation limit when Catlab 0.16 is released

"""
Take an ACSet pullback combinatorially and freely add variables for all
attribute subparts.
TODO do var_limit, more generally
This relies on implementation details of `abstract`.
"""
function var_pullback(c::Cospan{<:StructACSet{S,Ts}}) where {S,Ts}
Expand Down
195 changes: 145 additions & 50 deletions src/rewrite/PBPO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import Catlab.CategoricalAlgebra: left, right
using Catlab.CategoricalAlgebra.CSets: backtracking_search, abstract_attributes

using StructEquality
using DataStructures: DefaultDict

using ACSets.DenseACSets: types, attrtype_type
using ..Utils
import ..Utils:
rewrite_match_maps, get_matches, get_expr_binding_map, AbsRule, ruletype
Expand Down Expand Up @@ -89,9 +91,9 @@ function canon(l,r,tl,tk,l′)::PBPORule
end

"""
PBPO matches consist of *two* morphisms. First, a match L → G and secondly
PBPO matches consist of *two* morphisms. First, a match m: L → G, and secondly
a typing G → L′. With attributes, it is not so simple because G has concrete
values for attributes and L′ may have variables. Therefore, we actually the
values for attributes and L′ may have variables. Therefore, we actually change the
typing to map out of A, an abstracted version of G (with its attributes replaced
by variables). So we lift matches L->G to matches L->A, then search α∈Hom(A,L′).
Expand All @@ -104,9 +106,9 @@ is set to true.
L ⟶ L′
tl
m
m
L ⟶ G
tl ↓ ↘a ↑ (abstraction)
tl ↓ ↘a ↑ (abs = partial abstraction. Note `a` is `Labs` in the code.)
L′⟵ A
α
Expand All @@ -117,7 +119,7 @@ we can deduce precisely what m is by looking at α.
function get_matches(rule::PBPORule, G::ACSet; initial=nothing,
α_unique=true, random=false, n=-1, kw...)
S = acset_schema(G)
res = [] # Pairs of (m,α)
res = [] # Quadruples of of (m, Labs, abs, α)
L = codom(left(rule))

# Process the initial constraints for match morphism and typing morphism
Expand All @@ -137,54 +139,48 @@ function get_matches(rule::PBPORule, G::ACSet; initial=nothing,
m_seen = false # keeps track if α_unique is violated for each new m
if all(ac->apply_constraint(ac, m), rule.acs)
@debug "m: $([k=>collect(v) for (k,v) in pairs(components(m))])"
# Construct abtract version of G. ab: A->G
ab = abstract_attributes(G)
A = dom(ab) # not completely abstract: fill in where L has concrete attrs
for (a, cd, _) in attrs(S)
for (v, fv) in filter(v_->!(v_[2] isa AttrVar),collect(enumerate(L[a])))
A[m[cd](v), a] = fv

# Construct partially-abtract version of G. Labs: L->A and abs: A->G
Labs, abs = partial_abstract(m)
A = codom(Labs)

# If we have a built in function to deduce the adherence from the match
if !isnothing(rule.adherence)
init = rule.adherence(m)
# Return nothing if failure
if !isnothing(init)
αs = homomorphisms(A, codom(rule.tl); initial=init)
# Also return nothing if the result is not unique
if length(αs) ==1
push!(res, deepcopy((m, Labs, abs, only(αs))))
end
end
end
ab = remove_freevars(ab)
A = dom(ab) # now with free variables removed
# Construct a:L->A such that m = a;ab
ainit = NamedTuple(Dict(o=>collect(m[o]) for o in ob(S)))
a = only(homomorphisms(L, A; initial=ainit))
# Search for maps α: A -> L′ such that a;α=tl
init = combine_dicts!(extend_morphism_constraints(rule.tl,a), typinit)
if !isnothing(init)
# If we have a built in function to deduce the adherence from the match
if !isnothing(rule.adherence)
init = rule.adherence(m) # return nothing if failure
if !isnothing(init)
αs = homomorphisms(codom(a), codom(rule.tl); initial=init)
if length(αs) ==1
push!(res, deepcopy((m,a,ab,only(αs))))
end
end
else
# Search for adherence morphisms.
backtracking_search(codom(a), codom(rule.tl); initial=init, kw...) do α
@debug "\tα: ", [k=>collect(v) for (k,v) in pairs(components(α))]
strong_match = all(ob(S)) do o
all(parts(A,o)) do i
p1 = preimage(rule.tl[o],α[o](i))
p2 = preimage(a[o], i)
sort(p1) == sort(p2)
end
end
if strong_match && all(lc -> apply_constraint(lc, α), rule.lcs)
all(is_natural, [m,a,ab,α]) || error("Unnatural match")
if m_seen error("Multiple α for a single match $m") end
@debug "\tSUCCESS"
push!(res, deepcopy((m,a,ab,α)))
m_seen |= α_unique
return length(res) == n
else
@debug "\tFAILURE (strong $strong_match)"
return false
else
# Search for adherence morphisms: A -> L′
init = extend_morphism_constraints(rule.tl, Labs)
backtracking_search(A, codom(rule.tl); initial=init, kw...) do α
@debug "\tα: ", [k=>collect(v) for (k,v) in pairs(components(α))]

# Check strong match condition
strong_match = all(types(S)) do o
prt = o ob(S) ? identity : AttrVar
all(prt.(parts(A,o))) do i
p1 = preimage(rule.tl[o],α[o](i))
p2 = preimage(Labs[o], i)
p1 == p2
end
end
if strong_match && all(lc -> apply_constraint(lc, α), rule.lcs)
all(is_natural, [m, Labs, abs, α]) || error("Unnatural match")
if m_seen error("Multiple α for a single match $m") end
@debug "\tSUCCESS"
push!(res, deepcopy((m, Labs, abs, α)))
m_seen |= α_unique
return length(res) == n
else
@debug "\tFAILURE (strong $strong_match)"
return false
end
end
end
end
Expand All @@ -193,6 +189,105 @@ function get_matches(rule::PBPORule, G::ACSet; initial=nothing,
return res
end


"""
This construction addresses the following problem: ideally when we 'abstract'
an ACSet from X to A->X, maps *into* X, say B->X, can be canonically pulled back
to maps B->A which commute. However, A won't do
here, because there may not even exist any maps B->A. If B has concrete
attributes, then those cannot be sent to an AttrVar in A. Furthermore, if B
has multiple 'references' to an AttrVar (two different edges, each with
AttrVar(1), sent to two different edges with the same atttribute value in X),
then there is no longer a *canonical* place to send AttrVar(1) to in A, as there
is a distinct AttrVar for every single part+attr in X. So we need a construction
which does two things to A->X, starting with a map B->X. 1.) replaces exactly the
variables we need with concrete values in order to allow a map B->A, 2.) quotients
variables in A so that there is exactly one choice for where to send attrvars in
B such that the triangle commutes.
Starting with a map L -> G (where G has no AttrVars),
we want the analogous map into a "partially abstracted" version of G that
has concrete attributes replaced with AttrVars *EXCEPT* for those attributes
which are mapped to by concrete attributes of L. Likewise, multiple occurences
of the same variable in L correspond to AttrVars which should be merged in the
partially-abstracted G.
For example, for a schema with a single Ob and Attr (where all combinatorial
maps are just {1↦1, 2↦2}):
- L = [AttrVar(1), :foo]
- G = [:bar, :foo, :baz]
- abs(G) = [AttrVar(1), AttrVar(2), AttrVar(3)]
- expected result: [AttrVar(1), :foo, AttrVar(2)]
L -> Partial_abs(G)
↓ ↑
G <- abs(G)
This function computes the top arrow of this diagram starting with the left
arrow. The bottom arrow is computed by `abstract_attributes` and the right
arrow by `sub_vars`. Furthermore, a map from Partial_abs(G) to G is provided.
This is the factorization system arising from a coreflective subcategory.
(see https://ncatlab.org/nlab/show/reflective+factorization+system
and https://blog.algebraicjulia.org/post/2023/06/varacsets/)
"""
function partial_abstract(lg::ACSetTransformation)
L, G = dom(lg), codom(lg)
S = acset_schema(L)
abs_G = abstract_attributes(G)
A = dom(abs_G)

# Construct partially-abstracted G
#---------------------------------
subs = Dict{Symbol,Dict{Int}}()
merges = Dict{Symbol,Vector{Vector{Int}}}()
for at in attrtypes(S)
subdict = Dict{Int, Any}()
mergelist = DefaultDict{Int,Vector{Int}}(()->Int[])
for (f, o, _) in attrs(S; to=at)
for iₒ in parts(L, o)
var = A[lg[o](iₒ), f].val
val = L[iₒ, f]
if val isa AttrVar
push!(mergelist[val.val], var)
else
subdict[var] = val
end
end
end
subs[at] = subdict
merges[at] = collect(filter(l->!isempty(l), collect(values(mergelist))))
end
pabs_G = sub_vars(dom(abs_G), subs, merges)

# Construct maps
#---------------
prt(o) = o ob(S) ? identity : AttrVar
T(o) = o ob(S) ? Int : Union{AttrVar,attrtype_type(L, o)}

# The quotienting via `sub_vars` means L->PA determined purely by ob components
to_pabs_init = Dict{Symbol,Vector{Int}}(map(ob(S)) do o
o => map(prt(o).(collect(lg[o]))) do i
pabs_G[o](only(preimage(abs_G[o], i)))
end
end)

from_pabs_comps = Dict(map(types(S)) do o
comp = Vector{T(o)}(map(prt(o).(parts(codom(pabs_G), o))) do Pᵢ
only(unique([abs_G[o](prt(o)(pi)) for pi in preimage(pabs_G[o], Pᵢ)]))
end)
o => comp
end)

to_pabs = only(homomorphisms(L, codom(pabs_G); initial=to_pabs_init))
from_pabs = ACSetTransformation(codom(pabs_G), codom(lg); from_pabs_comps...)
ComposablePair(to_pabs, from_pabs)
end

"""
r
K ----> R
Expand Down
2 changes: 1 addition & 1 deletion src/rewrite/Representable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end
function yoneda_cache(T::Type,S=nothing; clear=false, cache="cache")
S = isnothing(S) ? Presentation(T) : S
tname = nameof(T) |> string
cache_dict = Dict{Symbol,Tuple{T,Int}}(Iterators.map(generators(S, :Ob)) do ob
cache_dict = Dict{Symbol,Tuple{T,Int}}(map(generators(S, :Ob)) do ob
name = nameof(ob)
cache_dir = mkpath(joinpath(cache, "$tname"))
path, ipath = joinpath.(cache_dir, ["$name.json", "_id_$name.json"])
Expand Down
2 changes: 1 addition & 1 deletion src/schedules/Basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct Initialize <: AgentBox
name::Symbol
state::StructACSet
in_agent::Union{Nothing,StructACSet}
Initialize(s, in_agent=nothing, n="") = new(n,s,in_agent)
Initialize(s, in_agent=nothing, n=Symbol("")) = new(n,s,in_agent)

Check warning on line 66 in src/schedules/Basic.jl

View check run for this annotation

Codecov / codecov/patch

src/schedules/Basic.jl#L66

Added line #L66 was not covered by tests
end
input_ports(r::Initialize) = isnothing(r.in_agent) ? [] : [r.in_agent]
output_ports(r::Initialize) = [typeof(r.state)()]
Expand Down
16 changes: 8 additions & 8 deletions src/schedules/Wiring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ Names(;kw...) = Names(Dict([string(k)=>v for (k,v) in pairs(kw)]))
Base.getindex(n::Names,s::String) = n.from_name[s]
Base.getindex(n::Names,s::Symbol) = n[string(s)]
Base.getindex(n::Names,x)::String = get(n.to_name,x,"?")
function Base.setindex(n::Names, x::String, y)
Base.length(n::Names) = length(n.from_name)
function Base.setindex!(n::Names{T}, y::T, x::String) where T

Check warning on line 39 in src/schedules/Wiring.jl

View check run for this annotation

Codecov / codecov/patch

src/schedules/Wiring.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
n.from_name[x] = y
n.to_name[y] = x
end
Expand All @@ -49,11 +50,11 @@ Make a wiring diagram with ob/hom generators using @program macro
TODO double check that this does not introduce any wire splitting.
"""
function mk_sched(t_args::NamedTuple,args::NamedTuple,names::Names,
kw::Union{NamedTuple,AbstractDict}, wd::Expr)
function mk_sched(t_args::NamedTuple,args::NamedTuple,names::Names{T},
kw::Union{NamedTuple,AbstractDict}, wd::Expr) where T
n_trace=length(t_args)
os = Dict(Symbol(k)=>v for (k,v) in collect(names.from_name))
hs = Dict(Symbol(k)=>v isa AgentBox ? singleton(v) : v for (k,v) in pairs(kw))
os = Dict{Symbol, T}(Symbol(k)=>v for (k,v) in collect(names.from_name))
hs = Dict{Symbol, Schedule}(Symbol(k)=>v isa AgentBox ? singleton(v) : v for (k,v) in pairs(kw))
P = Presentation(TM)
os_ = Dict(v=>add_generator!(P, Ob(TM,k)) for (k,v) in collect(os))

Expand All @@ -65,7 +66,6 @@ function mk_sched(t_args::NamedTuple,args::NamedTuple,names::Names,
add_generator!(P, Hom(k, i, o))
end
args_ = Expr(:tuple,[Expr(Symbol("::"), k,v) for (k,v) in pairs(merge(t_args,args))]...)

tmp = parse_wiring_diagram(P, args_, wd)
Xports = Ports{ThTracedMonoidalWithBidiagonals}(input_ports(tmp)[1:n_trace])
newer_x = Ob(TM,Xports) # arbitrary gatexpr
Expand Down Expand Up @@ -95,7 +95,7 @@ function mk_sched(t_args::NamedTuple,args::NamedTuple,names::Names,
end

new_d = trace(Xports, tmp)
sub = ocompose(new_d, [hs[Symbol(b.value)].d for b in boxes(new_d)])
sub = ocompose(new_d, WiringDiagram[hs[Symbol(b.value)].d for b in boxes(new_d)])
sub.diagram[:wire_value] = nothing
for x in Symbol.(["$(x)_port_type" for x in [:outer_in,:outer_out,]])
sub.diagram[:,x] = [names[v] for v in sub.diagram[x]]
Expand Down Expand Up @@ -240,4 +240,4 @@ merge_wires(agent::StructACSet, n::Int=2)::Schedule =
id(agents::AbstractVector{<:StructACSet}) = id(SPorts(Ports(agents)))


end # module
end # module
Loading

0 comments on commit 909ba03

Please sign in to comment.