Skip to content

Commit

Permalink
wip: undo all changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 9, 2024
1 parent 1e672bc commit 6146408
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 225 deletions.
1 change: 0 additions & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import .NodeModule:
constructorof,
with_type_parameters,
preserve_sharing,
max_degree,
leaf_copy,
branch_copy,
leaf_hash,
Expand Down
221 changes: 104 additions & 117 deletions src/Node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,30 @@ module NodeModule
using DispatchDoctor: @unstable

import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: deprecate_varmap, Undefined
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined

const DEFAULT_NODE_TYPE = Float32

"""
AbstractNode{D}
AbstractNode
Abstract type for D-arity trees. Must have the following fields:
Abstract type for binary trees. Must have the following fields:
- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
then `l` needs to be defined as the left child. If 2,
then `r` also needs to be defined as the right child.
- `children`: A collection of D references to children nodes.
# Deprecated fields
- `l::AbstractNode{D}`: Left child of the current node. Should only be
- `l::AbstractNode`: Left child of the current node. Should only be
defined if `degree >= 1`; otherwise, leave it undefined (see the
the constructors of [`Node{T}`](@ref) for an example).
Don't use `nothing` to represent an undefined value
as it will incur a large performance penalty.
- `r::AbstractNode{D}`: Right child of the current node. Should only
- `r::AbstractNode`: Right child of the current node. Should only
be defined if `degree == 2`.
"""
abstract type AbstractNode{D} end
abstract type AbstractNode end

"""
AbstractExpressionNode{T,D} <: AbstractNode{D}
AbstractExpressionNode{T} <: AbstractNode
Abstract type for nodes that represent an expression.
Along with the fields required for `AbstractNode`,
Expand Down Expand Up @@ -71,27 +67,11 @@ You likely do not need to, but you could choose to override the following:
- `with_type_parameters`
"""
abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end

for N in (:Node, :GraphNode)
@eval mutable struct $N{T,D} <: AbstractExpressionNode{T,D}
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
constant::Bool # false if variable
val::T # If is a constant, this stores the actual value
feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index.
op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum
children::NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes

#################
## Constructors:
#################
$N{_T,_D}() where {_T,_D} = new{_T,_D::Int}()
end
end
abstract type AbstractExpressionNode{T} <: AbstractNode end

#! format: off
"""
Node{T,D} <: AbstractExpressionNode{T,D}
Node{T} <: AbstractExpressionNode{T}
Node defines a symbolic expression stored in a binary tree.
A single `Node` instance is one "node" of this tree, and
Expand All @@ -101,42 +81,63 @@ nodes, you can evaluate or print a given expression.
# Fields
- `degree::UInt8`: Degree of the node. 0 for constants, 1 for
unary operators, 2 for binary operators, etc. Maximum of `D`.
unary operators, 2 for binary operators.
- `constant::Bool`: Whether the node is a constant.
- `val::T`: Value of the node. If `degree==0`, and `constant==true`,
this is the value of the constant. It has a type specified by the
overall type of the `Node` (e.g., `Float64`).
- `feature::UInt16`: Index of the feature to use in the
case of a feature node. Only defined if `degree == 0 && constant == false`.
case of a feature node. Only used if `degree==0` and `constant==false`.
Only defined if `degree == 0 && constant == false`.
- `op::UInt8`: If `degree==1`, this is the index of the operator
in `operators.unaops`. If `degree==2`, this is the index of the
operator in `operators.binops`. In other words, this is an enum
of the operators, and is dependent on the specific `OperatorEnum`
object. Only defined if `degree >= 1`
- `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. Only defined up to `degree`
- `l::Node{T}`: Left child of the node. Only defined if `degree >= 1`.
Same type as the parent node.
- `r::Node{T}`: Right child of the node. Only defined if `degree == 2`.
Same type as the parent node. This is to be passed as the right
argument to the binary operator.
# Constructors
Node([T]; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator)
Node{T}(; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator)
Node([T]; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator)
Node{T}(; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator)
Create a new node in an expression tree. If `T` is not specified in either the type or the
first argument, it will be inferred from the value of `val` passed or the children.
The `children` keyword is used to pass in a collection of children nodes.
first argument, it will be inferred from the value of `val` passed or `l` and/or `r`.
If it cannot be inferred from these, it will default to `Float32`.
The `children` keyword can be used instead of `l` and `r` and should be a tuple of children. This
is to permit the use of splatting in constructors.
You may also construct nodes via the convenience operators generated by creating an `OperatorEnum`.
You may also choose to specify a default memory allocator for the node other than simply `Node{T}()`
in the `allocator` keyword argument.
"""
Node

mutable struct Node{T} <: AbstractExpressionNode{T}
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
constant::Bool # false if variable
val::T # If is a constant, this stores the actual value
# ------------------- (possibly undefined below)
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
l::Node{T} # Left child node. Only defined for degree=1 or degree=2.
r::Node{T} # Right child node. Only defined for degree=2.

#################
## Constructors:
#################
Node{_T}() where {_T} = new{_T}()
end

"""
GraphNode{T,D} <: AbstractExpressionNode{T,D}
GraphNode{T} <: AbstractExpressionNode{T}
Exactly the same as [`Node{T,D}`](@ref), but with the assumption that some
Exactly the same as [`Node{T}`](@ref), but with the assumption that some
nodes will be shared. All copies of this graph-like structure will
be performed with this assumption, to preserve structure of the graph.
Expand All @@ -145,7 +146,7 @@ be performed with this assumption, to preserve structure of the graph.
```julia
julia> operators = OperatorEnum(;
binary_operators=[+, -, *], unary_operators=[cos, sin]
);
);
julia> x = GraphNode(feature=1)
x1
Expand All @@ -164,38 +165,17 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes
are created simply by using the same node in multiple places
when constructing or setting properties.
"""
GraphNode

@inline function Base.getproperty(n::Union{Node,GraphNode}, k::Symbol)
if k == :l
# TODO: Should a depwarn be raised here? Or too slow?
return getfield(n, :children)[1][]
elseif k == :r
return getfield(n, :children)[2][]
else
return getfield(n, k)
end
end
@inline function Base.setproperty!(n::Union{Node,GraphNode}, k::Symbol, v)
if k == :l
getfield(n, :children)[1][] = v
elseif k == :r
getfield(n, :children)[2][] = v
elseif k == :degree
setfield!(n, :degree, convert(UInt8, v))
elseif k == :constant
setfield!(n, :constant, convert(Bool, v))
elseif k == :feature
setfield!(n, :feature, convert(UInt16, v))
elseif k == :op
setfield!(n, :op, convert(UInt8, v))
elseif k == :val
setfield!(n, :val, convert(eltype(n), v))
elseif k == :children
setfield!(n, :children, v)
else
error("Invalid property: $k")
end
mutable struct GraphNode{T} <: AbstractExpressionNode{T}
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
constant::Bool # false if variable
val::T # If is a constant, this stores the actual value
# ------------------- (possibly undefined below)
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2.
r::GraphNode{T} # Right child node. Only defined for degree=2.

GraphNode{_T}() where {_T} = new{_T}()
end

################################################################################
Expand All @@ -204,62 +184,59 @@ end
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
Base.eltype(::AbstractExpressionNode{T}) where {T} = T

max_degree(::Type{<:AbstractNode}) = 2 # Default
max_degree(::Type{<:AbstractNode{D}}) where {D} = D

@unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T}
@unstable constructorof(::Type{N}) where {N<:GraphNode} =
GraphNode{T,max_degree(N)} where {T}
@unstable constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper
@unstable constructorof(::Type{<:Node}) = Node
@unstable constructorof(::Type{<:GraphNode}) = GraphNode

with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} = Node{T,max_degree(N)}
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T}
return GraphNode{T,max_degree(N)}
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
return constructorof(N){T}
end

# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D}
# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D}
with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T}
with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}

function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
return with_type_parameters(N, T)()
end
default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}()
default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}()

"""Trait declaring whether nodes share children or not."""
preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false
preserve_sharing(::Union{Type{<:Node},Node}) = false
preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true

include("base.jl")

#! format: off
@inline function (::Type{N})(
::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator,
) where {T1,N<:AbstractExpressionNode{T} where T,F}
_children = if l !== nothing && r === nothing
@assert children === nothing
(l,)
elseif l !== nothing && r !== nothing
@assert children === nothing
(l, r)
else
children
) where {T1,N<:AbstractExpressionNode,F}
validate_not_all_defaults(N, val, feature, op, l, r, children)
if children !== nothing
@assert l === nothing && r === nothing
if length(children) == 1
return node_factory(N, T1, val, feature, op, only(children), nothing, allocator)
else
return node_factory(N, T1, val, feature, op, children..., allocator)
end
end
validate_not_all_defaults(N, val, feature, op, _children)
return node_factory(N, T1, val, feature, op, _children, allocator)
return node_factory(N, T1, val, feature, op, l, r, allocator)
end
function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode}
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {N<:AbstractExpressionNode}
return nothing
end
function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {T,N<:AbstractExpressionNode{T}}
if val === nothing && feature === nothing && op === nothing && children === nothing
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {T,N<:AbstractExpressionNode{T}}
if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing
error(
"Encountered the call for $N() inside the generic constructor. "
* "Did you forget to define `$(Base.typename(N).wrapper){T,D}() where {T,D} = new{T,D}()`?"
* "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?"
)
end
return nothing
end
"""Create a constant leaf."""
@inline function node_factory(
::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F,
::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, ::Nothing, allocator::F,
) where {N,T1,T2,F}
T = node_factory_type(N, T1, T2)
n = allocator(N, T)
Expand All @@ -270,7 +247,7 @@ end
end
"""Create a variable leaf, to store data."""
@inline function node_factory(
::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F,
::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, ::Nothing, allocator::F,
) where {N,T1,F}
T = node_factory_type(N, T1, DEFAULT_NODE_TYPE)
n = allocator(N, T)
Expand All @@ -279,18 +256,28 @@ end
n.feature = feature
return n
end
"""Create an operator node."""
"""Create a unary operator node."""
@inline function node_factory(
::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, ::Nothing, allocator::F,
) where {N,T1,T2,F}
@assert l isa N
T = T2 # Always prefer existing nodes, so we don't mess up references from conversion
n = allocator(N, T)
n.degree = 1
n.op = op
n.l = l
return n
end
"""Create a binary operator node."""
@inline function node_factory(
::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Tuple, allocator::F,
) where {N<:AbstractExpressionNode,F}
T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion
D2 = length(children)
@assert D2 <= max_degree(N)
NT = with_type_parameters(N, T)
::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, r::AbstractExpressionNode{T3}, allocator::F,
) where {N,T1,T2,T3,F}
T = promote_type(T2, T3)
n = allocator(N, T)
n.degree = D2
n.degree = 2
n.op = op
n.children = ntuple(i -> i <= D2 ? Ref(convert(NT, children[i])) : Ref{NT}(), Val(max_degree(N)))
n.l = T2 === T ? l : convert(with_type_parameters(N, T), l)
n.r = T3 === T ? r : convert(with_type_parameters(N, T), r)
return n
end

Expand Down Expand Up @@ -331,14 +318,14 @@ function (::Type{N})(
return N(; feature=i)
end

function Base.promote_rule(::Type{Node{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D}
return Node{promote_type(T1, T2),D}
function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2}
return Node{promote_type(T1, T2)}
end
function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D}
return GraphNode{promote_type(T1, T2),D}
function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{Node{T2}}) where {T1,T2}
return GraphNode{promote_type(T1, T2)}
end
function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{GraphNode{T2,D}}) where {T1,T2,D}
return GraphNode{promote_type(T1, T2),D}
function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2}
return GraphNode{promote_type(T1, T2)}
end

# TODO: Verify using this helps with garbage collection
Expand Down
Loading

0 comments on commit 6146408

Please sign in to comment.