Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative extractor implementation #116

Merged
merged 38 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
96108f2
Basic extractors
nystrom Nov 12, 2024
44bd65f
minor fix
nystrom Nov 13, 2024
505900e
nested extractor test
nystrom Nov 13, 2024
515e886
Explicit AST for extractor fetch
nystrom Nov 13, 2024
b227f33
PR comment
nystrom Nov 13, 2024
2bfd172
Merge branch 'main' into nn-extract
gafter Nov 13, 2024
9d1e9cc
improved diags
nystrom Nov 14, 2024
73d1a24
alternative extractor design
nystrom Nov 14, 2024
dc165a2
more tests
nystrom Nov 14, 2024
e32db40
Expand extractor readme
nystrom Nov 14, 2024
a49ed7a
readme fixes
nystrom Nov 14, 2024
8ee738d
readme tweak
nystrom Nov 14, 2024
02cba8f
cleanup
nystrom Nov 14, 2024
8e594d5
test error message
nystrom Nov 14, 2024
2b611b4
fix hash
nystrom Nov 14, 2024
8cd29d0
make type and extractor mutually exclusive
nystrom Nov 14, 2024
47e008c
fix readme for new extractor semantics
nystrom Nov 14, 2024
8756f8b
support named tuples
nystrom Nov 15, 2024
4ce7d70
remove named tuples support
nystrom Nov 16, 2024
1ea128d
remove more named tuples
nystrom Nov 16, 2024
aab8476
remove try_bind_type
nystrom Nov 18, 2024
cbefcc4
coverage for pretty-printing extract
nystrom Nov 18, 2024
06cb952
fix test for 1.11
nystrom Nov 18, 2024
f11d5ea
fix readme
nystrom Nov 18, 2024
1398564
cleanup tests
nystrom Nov 20, 2024
9496b28
lookup extractors by arity
nystrom Nov 21, 2024
c3ee22d
fix readme for extractors
nystrom Nov 21, 2024
e9c35e0
lowering for extractors with arity
nystrom Nov 21, 2024
704c3da
oops
nystrom Nov 21, 2024
4a34131
minor cleanup
nystrom Nov 22, 2024
bc5bc5f
improve error message when extractor not defined
nystrom Nov 24, 2024
33b3606
remove unused type
nystrom Nov 24, 2024
bd34946
delete comment
nystrom Nov 24, 2024
5ac33ca
error message tests
nystrom Nov 24, 2024
62ae036
PR comments
nystrom Nov 24, 2024
0672110
fix docstring
nystrom Nov 24, 2024
87e9fc3
remove .
nystrom Nov 24, 2024
f1d2b3c
bump version
nystrom Nov 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,46 @@ When `pattern1` matches but `some_failure_condition` is `true`, then the whole c
Otherwise, if `some_shortcut_condition` is `true`, then `1` is the result value for this case.
Otherwise `2` is the result.

## Extractors

Struct patterns of the form `T(x,y,z)` can be overridden by defining an _extractor_ function for `T`.
When a value `v` is matched against a pattern `T(x,y,z)`, `Match.extract(T, v)` is called and the result is then matched against the tuple pattern `(x,y,z)`.
The value `v` need not be of type `T`.
If the result of the `extract` call is `nothing`, the value `v` is checked against the struct type `T`, as usual, with its fields checked against the subpatterns `x`, `y`, and `z`.

For example, to match a pair of numbers using Polar coordinates, extracting the radius and angle, you could define:
```julia
struct Polar end
function Match.extract(::Type{Polar}, p::Tuple{<:Number,<:Number})
x, y = p
return (sqrt(x^2 + y^2), atan(y, x))
end
```
This definition allows you to use a new `Polar` extractor pattern:
```julia
@match Polar(r,θ) = (1,1)
@assert r == sqrt(2) && θ == π / 4
```

The `extract` function should return either a tuple of values to be matched by the subpatterns or return `nothing`.

Extractors can also be used to ignore or transform fields of existing types during matching.
For example, this extractor ignores the `annos` field of the `AddExpr` type:
```julia
struct AddExpr
left
right
annos
end
function Match.extract(::Type{AddExpr}, e::AddExpr)
return (e.left, e.right)
end
@match AddExpr(x, y) = node
```

Extractors can be used to implement more user-friendly matching for types defined with `SumTypes.jl` or
other packages.

## Differences from previous versions of `Match.jl`

* If no branches are matched, throws `MatchFailure` instead of returning nothing.
Expand Down
11 changes: 10 additions & 1 deletion src/Match.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ The following syntactic forms can be used in patterns:

* `_` matches anything
* `x` (an identifier) matches anything, binds value to the variable `x`
* `T(x,y,z)` matches structs of type `T` with fields matching patterns `x,y,z`
* `T(x,y,z)` matches structs of type `T` with fields matching patterns `x,y,z`.
nystrom marked this conversation as resolved.
Show resolved Hide resolved
* `T(y=1)` matches structs of type `T` whose `y` field equals `1`
* `[x,y,z]` matches `AbstractArray`s with 3 entries matching `x,y,z`
* `(x,y,z)` matches `Tuple`s with 3 entries matching `x,y,z`
Expand Down Expand Up @@ -247,6 +247,15 @@ struct MatchFailure <: Exception
value
end

"""
extract(T::Type, value)

Implement extractor for type `T`, returning a tuple of fields of `value`, or `nothing` if
the match fails. This can be used to override matching on type `T`. If `extract(T, value)`
returns a tuple, it will be used instead of the default field extraction.
"""
extract(::Type, ::Any) = nothing

# const fields only suppored >= Julia 1.8
macro _const(x)
(VERSION >= v"1.8") ? Expr(:const, esc(x)) : esc(x)
Expand Down
142 changes: 105 additions & 37 deletions src/binding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ function bind_type(location, T, input, binder)

bound_type
end
function try_bind_type(location, T, input, binder)
nystrom marked this conversation as resolved.
Show resolved Hide resolved
# bind type at macro expansion time in the caller's module.
bound_type = nothing
try
bound_type = Core.eval(binder.mod, Expr(:block, location, T))
catch ex
return nothing
end

if !(bound_type isa Type)
return nothing
end

return bound_type
end

function simple_name(s::Symbol)
simple_name(string(s))
Expand Down Expand Up @@ -95,6 +110,9 @@ end
function gentemp(p::BoundFetchLengthPattern)::Symbol
gensym(string("length(", simple_name(p.input), ")"))
end
function gentemp(p::BoundFetchExtractorPattern)::Symbol
gensym(string("extract(", p.extractor, ", ", simple_name(p.input), ")"))
end

#
# The following are special bindings used to handle the point where
Expand Down Expand Up @@ -227,7 +245,7 @@ function bind_pattern!(
# struct pattern.
# TypeName(patterns...)
T = source.args[1]
subpatterns = source.args[2:length(source.args)]
subpatterns = source.args[2:end]
len = length(subpatterns)
named_fields = [pat.args[1] for pat in subpatterns if is_expr(pat, :kw)]
named_count = length(named_fields)
Expand All @@ -242,52 +260,102 @@ function bind_pattern!(
match_positionally = named_count == 0

# bind type at macro expansion time
pattern0, assigned = bind_pattern!(location, :(::($T)), input, binder, assigned)
bound_type = (pattern0::BoundTypeTestPattern).type
patterns = BoundPattern[pattern0]
field_names::Tuple = match_fieldnames(bound_type)
if match_positionally && len != length(field_names)
error("$(location.file):$(location.line): The type `$bound_type` has " *
"$(length(field_names)) fields but the pattern expects $len fields.")
bound_type = bind_type(location, T, input, binder)

# First try the extractor, then try the struct type.
disjuncts = BoundPattern[]
nystrom marked this conversation as resolved.
Show resolved Hide resolved
extractor_temp = nothing

# Try the extractor first.
# This only works with positional arguments.
# TODO support named tuples
if match_positionally
# Check if there is an extractor method for the pattern type.
methods = Base.methods(Match.extract, (Type{bound_type}, Any,))
# There is always at least one method (the default), so we know the extractor
# method is implemented if there's at least two methods.
if length(methods) >= 2
conjuncts = BoundPattern[]
# call Match.extract(Val(T), input) and match the result against the tuple of subpatterns
fetch = BoundFetchExtractorPattern(location, source, input, bound_type, Any)
extractor_temp = push_pattern!(conjuncts, binder, fetch)
subpattern, assigned = bind_pattern!(location,
Expr(:tuple, subpatterns...), extractor_temp, binder, assigned)
push!(conjuncts, subpattern)
push!(disjuncts, BoundAndPattern(location, source, conjuncts))
end
end

for i in 1:len
pat = subpatterns[i]
if match_positionally
field_name = field_names[i]
pattern_source = pat
else
@assert pat.head == :kw
field_name = pat.args[1]
pattern_source = pat.args[2]
if !(field_name in field_names)
error("$(location.file):$(location.line): Type `$bound_type` has " *
"no field `$field_name`.")
end
# If the extractor failed, try the field-by-field match.
type_conjuncts = BoundPattern[]

if isnothing(extractor_temp)
# Ensure that the extractor actually failed.
# Otherwise, we could have an ambiguity between the type pattern and the extractor pattern.
if !isnothing(extractor_temp)
nystrom marked this conversation as resolved.
Show resolved Hide resolved
pattern0 = BoundTypeTestPattern(location, :( Base.Nothing ), extractor_temp, Nothing)
push!(type_conjuncts, pattern0)
end

field_type = nothing
if field_name == match_fieldnames(Symbol)[1]
# special case Symbol's hypothetical name field.
field_type = String
field_names::Tuple = match_fieldnames(bound_type)

if match_positionally && len != length(field_names)
# If the extractor is defined, silently fail if the field-by-field match fails.
gafter marked this conversation as resolved.
Show resolved Hide resolved
if isnothing(extractor_temp)
error("$(location.file):$(location.line): The type `$bound_type` has " *
"$(length(field_names)) fields but the pattern expects $len fields.")
else
pattern0 = BoundFalsePattern(location, source)
nystrom marked this conversation as resolved.
Show resolved Hide resolved
push!(type_conjuncts, pattern0)
end
else
nystrom marked this conversation as resolved.
Show resolved Hide resolved
for (fname, ftype) in zip(Base.fieldnames(bound_type), Base.fieldtypes(bound_type))
if fname == field_name
field_type = ftype
break
pattern0 = BoundTypeTestPattern(location, T, input, bound_type)
push!(type_conjuncts, pattern0)

for i in 1:len
pat = subpatterns[i]
if match_positionally
field_name = field_names[i]
pattern_source = pat
else
@assert pat.head == :kw
field_name = pat.args[1]
pattern_source = pat.args[2]
if !(field_name in field_names)
if isnothing(extractor_temp)
error("$(location.file):$(location.line): Type `$bound_type` has " *
"no field `$field_name`.")
end
end
end

field_type = nothing
if field_name == match_fieldnames(Symbol)[1]
# special case Symbol's hypothetical name field.
field_type = String
else
for (fname, ftype) in zip(Base.fieldnames(bound_type), Base.fieldtypes(bound_type))
if fname == field_name
field_type = ftype
break
end
end
end
@assert field_type !== nothing

fetch = BoundFetchFieldPattern(location, pattern_source, input, field_name, field_type)
field_temp = push_pattern!(type_conjuncts, binder, fetch)
bound_subpattern, assigned = bind_pattern!(
location, pattern_source, field_temp, binder, assigned)
push!(type_conjuncts, bound_subpattern)
end
end
@assert field_type !== nothing

fetch = BoundFetchFieldPattern(location, pattern_source, input, field_name, field_type)
field_temp = push_pattern!(patterns, binder, fetch)
bound_subpattern, assigned = bind_pattern!(
location, pattern_source, field_temp, binder, assigned)
push!(patterns, bound_subpattern)
pattern = BoundAndPattern(location, source, type_conjuncts)
push!(disjuncts, pattern)
end

pattern = BoundAndPattern(location, source, patterns)
pattern = BoundOrPattern(location, source, disjuncts)

elseif is_expr(source, :(&&), 2)
# conjunction: `(a && b)` where `a` and `b` are patterns.
Expand Down Expand Up @@ -410,7 +478,7 @@ function bind_pattern!(
pattern0, assigned = bind_pattern!(location, subpattern, input, binder, assigned)
pattern1 = shred_where_clause(guard, false, location, binder, assigned)
pattern = BoundAndPattern(location, source, BoundPattern[pattern0, pattern1])

elseif is_expr(source, :if, 2)
# if expr end
if !is_empty_block(source.args[2])
Expand Down
16 changes: 16 additions & 0 deletions src/bound_pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,22 @@ function Base.:(==)(a::BoundFetchLengthPattern, b::BoundFetchLengthPattern)
a.input == b.input
end

# Fetch a value using the given extractor function into a temporary. See
# `BoundFetchFieldPattern` for the general idea of how these are used.
struct BoundFetchExtractorPattern <: BoundFetchPattern
location::LineNumberNode
source::Any
input::Symbol
extractor::Type
type::Type
end
function Base.hash(a::BoundFetchExtractorPattern, h::UInt64)
hash((a.input, a.extractor, 0xd7882f5b4888d335), h)
end
function Base.:(==)(a::BoundFetchExtractorPattern, b::BoundFetchExtractorPattern)
a.input == b.input && a.extractor == b.extractor
end

# Preserve the value of the expression into a temp. Used
# (1) to force the binding on both sides of an or-pattern to be the same (a phi), and
# (2) to load the value of a `where` clause.
Expand Down
5 changes: 4 additions & 1 deletion src/lowering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function code(bound_pattern::BoundTypeTestPattern, binder::BinderContext)
$string($(bound_pattern.type)), " at macro expansion time but ",
$src, " later."))))
push!(binder.assertions, Expr(:block, bound_pattern.location, :($test || $thrown)))
push!(binder.asserted_types, )
push!(binder.asserted_types, src)
end
:($(bound_pattern.input) isa $(bound_pattern.type))
end
Expand Down Expand Up @@ -89,6 +89,9 @@ end
function code(bound_pattern::BoundFetchExpressionPattern)
code(bound_pattern.bound_expression)
end
function code(bound_pattern::BoundFetchExtractorPattern)
Expr(:call, Match.extract, bound_pattern.extractor, bound_pattern.input)
end

# Return an expression that computes whether or not the pattern matches.
function lower_pattern_to_boolean(bound_pattern::BoundPattern, binder::BinderContext)
Expand Down
7 changes: 7 additions & 0 deletions src/pretty.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,13 @@ function pretty(io::IO, p::BoundFetchLengthPattern)
pretty(io, p.input)
print(io, ")")
end
function pretty(io::IO, p::BoundFetchExtractorPattern)
nystrom marked this conversation as resolved.
Show resolved Hide resolved
print(io, "extract(")
pretty(io, p.extractor)
print(io, ", ")
pretty(io, p.input)
print(io, ")")
end
function pretty(io::IO, p::BoundFetchExpressionPattern)
pretty(io, p.bound_expression)
end
Loading
Loading