Basic extractors #114

Closed
wants to merge 6 commits into from
Changes from 4 commits
23 changes: 23 additions & 0 deletions README.md
@@ -46,6 +46,7 @@ for examples of this and other features.
* `x` (an identifier) matches anything, binds value to the variable `x`
* `T(x,y,z)` matches structs of type `T` with fields matching patterns `x,y,z`
* `T(y=1)` matches structs of type `T` whose `y` field equals `1`
* `X(x,y,z)` where `X` is not a type, calls `Match.extract(Val(:X), v)` on the value `v` and matches the result against the tuple pattern `(x,y,z)`
* `[x,y,z]` matches `AbstractArray`s with 3 entries matching `x,y,z`
* `(x,y,z)` matches `Tuple`s with 3 entries matching `x,y,z`
* `[x,y...,z]` matches `AbstractArray`s with at least 2 entries, where `x` matches the first entry, `z` matches the last entry and `y` matches the remaining entries.
@@ -89,6 +90,28 @@ When `pattern1` matches but `some_failure_condition` is `true`, then the whole c
Otherwise, if `some_shortcut_condition` is `true`, then `1` is the result value for this case.
Otherwise `2` is the result.

## Extractors

New patterns can be defined by adding a method to the `Match.extract` function that dispatches on `Val` of the new pattern's name.
For example, to match a pair of numbers using polar coordinates, extracting the radius and angle, you could define:

```julia
function Match.extract(::Val{:Polar}, p::Tuple{<:Number,<:Number})
x, y = p
return (sqrt(x^2 + y^2), atan(y, x))
end
```
This definition allows you to use a new `Polar` extractor pattern:
```julia
@match Polar(r,θ) = (1,1)
@assert r == sqrt(2) && θ == π / 4
```

The `extract` function should return either a tuple of values to be matched by the subpatterns, or `nothing` if the value cannot be extracted, in which case the pattern does not match.

Extractors can be used to implement more user-friendly matching for types defined with `SumTypes.jl` or
other packages.
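
The `nothing` case can be illustrated without Match.jl itself. The sketch below is a stand-alone model of the dispatch described in this section, not the package's implementation: `extract_demo`, `describe`, and the `Diff` extractor are hypothetical names chosen for illustration, and the code only mimics how a fallback method returning `nothing` signals a failed extraction.

```julia
# Stand-alone sketch; `extract_demo` is a hypothetical stand-in for `Match.extract`.

# Fallback: any extractor name without a more specific method fails to extract.
extract_demo(::Val, value) = nothing

# A `Polar` extractor for pairs of numbers, mirroring the example above.
function extract_demo(::Val{:Polar}, p::Tuple{<:Number,<:Number})
    x, y = p
    return (sqrt(x^2 + y^2), atan(y, x))
end

# A hypothetical `Diff` extractor that can fail: it yields a 1-tuple only when x >= y.
function extract_demo(::Val{:Diff}, p::Tuple{<:Number,<:Number})
    x, y = p
    return x >= y ? (x - y,) : nothing
end

# Model of what a pattern such as `Diff(d)` does: call the extractor, then treat
# `nothing` as "no match" and a tuple as the values handed to the subpatterns.
function describe(name::Symbol, value)
    result = extract_demo(Val(name), value)
    result === nothing && return "no match"
    return "matched with $(result)"
end

println(describe(:Polar, (3, 4)))
println(describe(:Diff, (2, 1)))
println(describe(:Diff, (1, 2)))    # no match: the extractor returned nothing
println(describe(:Polar, "text"))   # no match: falls through to the fallback method
```

Dispatching on `Val{:Name}` keeps each extractor a separate method, so adding one never requires editing a central table.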

## Differences from previous versions of `Match.jl`

* If no branches are matched, throws `MatchFailure` instead of returning nothing.
9 changes: 9 additions & 0 deletions src/Match.jl
@@ -54,6 +54,7 @@ The following syntactic forms can be used in patterns:
* `x` (an identifier) matches anything, binds value to the variable `x`
* `T(x,y,z)` matches structs of type `T` with fields matching patterns `x,y,z`
* `T(y=1)` matches structs of type `T` whose `y` field equals `1`
* `X(x,y,z)` where `X` is not a type, calls `Match.extract(Val(:X), v)` on the value `v` and matches the result against the tuple pattern `(x,y,z)`
* `[x,y,z]` matches `AbstractArray`s with 3 entries matching `x,y,z`
* `(x,y,z)` matches `Tuple`s with 3 entries matching `x,y,z`
* `[x,y...,z]` matches `AbstractArray`s with at least 2 entries, where `x` matches the first entry, `z` matches the last entry and `y` matches the remaining entries.
@@ -247,6 +248,14 @@ struct MatchFailure <: Exception
value
end

"""
extract(::Val{x}, value)

Implement an extractor named `x`, returning a tuple of fields of `value`, or `nothing` if
`x` cannot be extracted from `value`.
"""
extract(::Val, value) = nothing

# const fields only supported >= Julia 1.8
macro _const(x)
(VERSION >= v"1.8") ? Expr(:const, esc(x)) : esc(x)
117 changes: 77 additions & 40 deletions src/binding.jl
@@ -66,6 +66,21 @@ function bind_type(location, T, input, binder)

bound_type
end
function try_bind_type(location, T, input, binder)
# bind type at macro expansion time. It will be verified at runtime.
bound_type = nothing
try
bound_type = Core.eval(binder.mod, Expr(:block, location, T))
catch ex
return nothing
end

if !(bound_type isa Type)
return nothing
end

return bound_type
end

function simple_name(s::Symbol)
simple_name(string(s))
@@ -95,6 +110,9 @@ end
function gentemp(p::BoundFetchLengthPattern)::Symbol
gensym(string("length(", simple_name(p.input), ")"))
end
function gentemp(p::BoundFetchExtractorPattern)::Symbol
gensym(string("extract(", p.extractor, ", ", simple_name(p.input), ")"))
end

#
# The following are special bindings used to handle the point where
@@ -227,7 +245,7 @@ function bind_pattern!(
# struct pattern.
# TypeName(patterns...)
T = source.args[1]
subpatterns = source.args[2:length(source.args)]
subpatterns = source.args[2:end]
len = length(subpatterns)
named_fields = [pat.args[1] for pat in subpatterns if is_expr(pat, :kw)]
named_count = length(named_fields)
@@ -241,50 +259,69 @@ function bind_pattern!(

match_positionally = named_count == 0

# bind type at macro expansion time
pattern0, assigned = bind_pattern!(location, :(::($T)), input, binder, assigned)
bound_type = (pattern0::BoundTypeTestPattern).type
patterns = BoundPattern[pattern0]
field_names::Tuple = match_fieldnames(bound_type)
if match_positionally && len != length(field_names)
error("$(location.file):$(location.line): The type `$bound_type` has " *
"$(length(field_names)) fields but the pattern expects $len fields.")
end
is_type = !Base.isnothing(try_bind_type(location, T, input, binder))
Review comment (Member):

This line does no good if T is not a symbol. Perhaps

is_extractor = T isa Symbol && Base.isnothing(try_bind_type(location, T, input, binder))
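
The reviewer's point, together with the method-count check in this hunk, can be sketched outside the package. Everything named here is an assumption made for illustration: `extract_demo` stands in for `Match.extract`, `callee` and `is_extractor` are hypothetical helpers, and `isdefined`/`getfield` is a simplification of what `try_bind_type` does with `Core.eval`.

```julia
# Stand-alone sketch; `extract_demo` is a hypothetical stand-in for `Match.extract`.
extract_demo(::Val, value) = nothing                  # default: extraction fails
extract_demo(::Val{:Swap}, p::Tuple{Any,Any}) = (p[2], p[1])

# The callee of a call pattern is only sometimes a bare name.
callee(pattern::Expr) = pattern.args[1]

# A callee can name an extractor only if it is a Symbol, is not bound to a type
# in `mod`, and has some `extract_demo` method besides the default one.
function is_extractor(T, mod::Module)
    T isa Symbol || return false                      # skip `Base.Pair`, `Foo{Int}`, ...
    isdefined(mod, T) && getfield(mod, T) isa Type && return false
    return length(methods(extract_demo, (Val{T}, Any))) > 1
end

for pattern in (:(Swap(a, b)), :(Int(x)), :(Nope(x)), :(Base.Pair(a, b)), :(Foo{Int}(x)))
    T = callee(pattern)
    println(repr(T), " => ", is_extractor(T, Base))
end
```

Testing `T isa Symbol` first means the type lookup is never attempted for qualified or parameterized callees, which is the short-circuit the comment asks for.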

if T isa Symbol && match_positionally && !is_type
# TODO support named tuples
patterns = BoundPattern[]
# check that the extractor exists
methods = Base.methods(Match.extract, (Val{T}, Any,))
if length(methods) <= 1
# If there's only one matching Match.extract method, it's the default one
# that returns nothing. It is worse still if there are none.
error("$(location.file):$(location.line): `$T` is neither a type name " *
"nor is there a `Match.extract(::Val{$T}, _)` implementation.")
end
# call Match.extract(Val(T), input) and match the result against the tuple of subpatterns
fetch = BoundFetchExtractorPattern(location, source, input, T, Any)
temp = push_pattern!(patterns, binder, fetch)
subpattern, assigned = bind_pattern!(location, Expr(:tuple, subpatterns...), temp, binder, assigned)
push!(patterns, subpattern)
else
# bind type at macro expansion time
pattern0, assigned = bind_pattern!(location, :(::($T)), input, binder, assigned)
bound_type = (pattern0::BoundTypeTestPattern).type
patterns = BoundPattern[pattern0]
field_names::Tuple = match_fieldnames(bound_type)
if match_positionally && len != length(field_names)
error("$(location.file):$(location.line): The type `$bound_type` has " *
"$(length(field_names)) fields but the pattern expects $len fields.")
end

for i in 1:len
pat = subpatterns[i]
if match_positionally
field_name = field_names[i]
pattern_source = pat
else
@assert pat.head == :kw
field_name = pat.args[1]
pattern_source = pat.args[2]
if !(field_name in field_names)
error("$(location.file):$(location.line): Type `$bound_type` has " *
"no field `$field_name`.")
for i in 1:len
pat = subpatterns[i]
if match_positionally
field_name = field_names[i]
pattern_source = pat
else
@assert pat.head == :kw
field_name = pat.args[1]
pattern_source = pat.args[2]
if !(field_name in field_names)
error("$(location.file):$(location.line): Type `$bound_type` has " *
"no field `$field_name`.")
end
end
end

field_type = nothing
if field_name == match_fieldnames(Symbol)[1]
# special case Symbol's hypothetical name field.
field_type = String
else
for (fname, ftype) in zip(Base.fieldnames(bound_type), Base.fieldtypes(bound_type))
if fname == field_name
field_type = ftype
break
field_type = nothing
if field_name == match_fieldnames(Symbol)[1]
# special case Symbol's hypothetical name field.
field_type = String
else
for (fname, ftype) in zip(Base.fieldnames(bound_type), Base.fieldtypes(bound_type))
if fname == field_name
field_type = ftype
break
end
end
end
end
@assert field_type !== nothing
@assert field_type !== nothing

fetch = BoundFetchFieldPattern(location, pattern_source, input, field_name, field_type)
field_temp = push_pattern!(patterns, binder, fetch)
bound_subpattern, assigned = bind_pattern!(
location, pattern_source, field_temp, binder, assigned)
push!(patterns, bound_subpattern)
fetch = BoundFetchFieldPattern(location, pattern_source, input, field_name, field_type)
field_temp = push_pattern!(patterns, binder, fetch)
bound_subpattern, assigned = bind_pattern!(
location, pattern_source, field_temp, binder, assigned)
push!(patterns, bound_subpattern)
end
end

pattern = BoundAndPattern(location, source, patterns)
@@ -410,7 +447,7 @@ function bind_pattern!(
pattern0, assigned = bind_pattern!(location, subpattern, input, binder, assigned)
pattern1 = shred_where_clause(guard, false, location, binder, assigned)
pattern = BoundAndPattern(location, source, BoundPattern[pattern0, pattern1])

elseif is_expr(source, :if, 2)
# if expr end
if !is_empty_block(source.args[2])
16 changes: 16 additions & 0 deletions src/bound_pattern.jl
@@ -286,6 +286,22 @@ function Base.:(==)(a::BoundFetchLengthPattern, b::BoundFetchLengthPattern)
a.input == b.input
end

# Fetch a value using the given extractor function into a temporary. See
# `BoundFetchFieldPattern` for the general idea of how these are used.
struct BoundFetchExtractorPattern <: BoundFetchPattern
location::LineNumberNode
source::Any
input::Symbol
extractor::Symbol
type::Type
end
function Base.hash(a::BoundFetchExtractorPattern, h::UInt64)
hash((a.input, a.extractor, 0xd7882f5b4888d335), h)
end
function Base.:(==)(a::BoundFetchExtractorPattern, b::BoundFetchExtractorPattern)
a.input == b.input && a.extractor == b.extractor
end
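
One consequence of defining `hash` and `==` over only `input` and `extractor` is that two fetches that differ only in their source location compare equal. A minimal sketch of that behaviour follows. `FetchDemo` is a hypothetical, simplified stand-in for `BoundFetchExtractorPattern`, with an `Int` in place of the `LineNumberNode` and an arbitrary salt constant.

```julia
# Hypothetical, simplified stand-in for BoundFetchExtractorPattern.
struct FetchDemo
    location::Int        # stands in for a LineNumberNode; ignored by == and hash
    input::Symbol
    extractor::Symbol
end

# Hash and equality use the same fields, so equal values always hash alike.
Base.hash(a::FetchDemo, h::UInt) = hash((a.input, a.extractor, 0x5eed), h)
Base.:(==)(a::FetchDemo, b::FetchDemo) =
    a.input == b.input && a.extractor == b.extractor

a = FetchDemo(10, :v, :Polar)
b = FetchDemo(20, :v, :Polar)   # same fetch, written at a different location
c = FetchDemo(10, :v, :Diff)    # different extractor applied to the same input

seen = Set([a, b, c])
println(length(seen))           # `a` and `b` collapse into a single entry
```

Keeping the two definitions in step matters because hashed collections such as `Set` and `Dict` assume that values which compare equal also hash equal.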

# Preserve the value of the expression into a temp. Used
# (1) to force the binding on both sides of an or-pattern to be the same (a phi), and
# (2) to load the value of a `where` clause.
3 changes: 3 additions & 0 deletions src/lowering.jl
@@ -89,6 +89,9 @@ end
function code(bound_pattern::BoundFetchExpressionPattern)
code(bound_pattern.bound_expression)
end
function code(bound_pattern::BoundFetchExtractorPattern)
Expr(:call, Match.extract, Val(bound_pattern.extractor), bound_pattern.input)
end
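
The lowering above builds a call expression whose callee and first argument are spliced in as values rather than names. A stand-alone sketch of the same construction, using the hypothetical names `extract_demo` and `lowered` instead of the package's own functions:

```julia
# Stand-alone sketch; `extract_demo` is a hypothetical stand-in for `Match.extract`.
extract_demo(::Val, value) = nothing
extract_demo(::Val{:Swap}, p::Tuple{Any,Any}) = (p[2], p[1])

# Build the expression `extract_demo(Val(:Swap), input)`. The function object and
# the `Val` instance are embedded directly, so the generated code does not depend
# on what the names `extract_demo` or `Val` mean where it is expanded.
lowered(extractor::Symbol, input::Symbol) =
    Expr(:call, extract_demo, Val(extractor), input)

ex = lowered(:Swap, :v)

# Evaluate inside a `let` so that the input temporary `v` is bound when the call runs.
result = eval(:(let v = (1, 2); $ex end))
failed = eval(:(let v = "not a pair"; $ex end))

println(result)
println(failed === nothing)
```

Only the input stays symbolic, since it names a temporary that exists at run time.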

# Return an expression that computes whether or not the pattern matches.
function lower_pattern_to_boolean(bound_pattern::BoundPattern, binder::BinderContext)
7 changes: 7 additions & 0 deletions src/pretty.jl
@@ -185,6 +185,13 @@ function pretty(io::IO, p::BoundFetchLengthPattern)
pretty(io, p.input)
print(io, ")")
end
function pretty(io::IO, p::BoundFetchExtractorPattern)
print(io, "extract(")
pretty(io, p.extractor)
print(io, ", ")
pretty(io, p.input)
print(io, ")")
end
function pretty(io::IO, p::BoundFetchExpressionPattern)
pretty(io, p.bound_expression)
end
50 changes: 50 additions & 0 deletions test/rematch.jl
@@ -474,6 +474,56 @@ end
end)
end

@testset "extractor function" begin
@eval function Match.extract(::Val{:polar}, p::Foo)
return (sqrt(p.x^2 + p.y^2), atan(p.y, p.x))
end
@test (@eval @match Foo(1,1) begin
polar(r,θ) => r == sqrt(2) && θ == π / 4
_ => false
end)
end

@testset "extractor function that might fail" begin
@eval function Match.extract(::Val{:diff}, p::Foo)
return p.x >= p.y ? (p.x - p.y,) : nothing
end
@test (@eval @match Foo(1,1) begin
diff(2) => false
diff(1) => false
diff(0) => true
_ => false
end)
@test (@eval @match Foo(2,1) begin
diff(2) => false
diff(1) => true
_ => false
end)
@test (@eval @match Foo(1,2) begin
diff(2) => false
diff(1) => false
_ => true
end)
end

@testset "nested extractor function" begin
@eval function Match.extract(::Val{:foo}, p::Foo)
return (p.x, p.y)
end
@test (@eval @match Foo(Foo(1,2),3) begin
foo(foo(1,2),3) => true
_ => false
end)
end

@testset "extractor function missing" begin
# Bar is neither a type nor an extractor function
Review comment (Member):

Please test the location and text of the diagnostic.

@test_throws LoadError (@eval @match Foo(1,1) begin
Bar(0) => true
_ => false
end)
end
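
Checking the location and text of a diagnostic, as the review comment requests, could look roughly like the following. This is a sketch under assumptions. `report_missing` is a hypothetical helper that only reproduces the message format used in `src/binding.jl`, and the file name and line number are made up. How the real test would unwrap the error raised during macro expansion is not covered here.

```julia
using Test

# Hypothetical helper that raises a diagnostic in the same "file:line: message" shape.
function report_missing(location::LineNumberNode, name::Symbol)
    error("$(location.file):$(location.line): `$name` is neither a type name " *
          "nor is there a `Match.extract(::Val{$name}, _)` implementation.")
end

loc = LineNumberNode(42, Symbol("demo.jl"))   # made-up location for illustration

# Capture the exception so that its message can be inspected.
err = try
    report_missing(loc, :Bar)
    nothing
catch ex
    ex
end

@test err isa ErrorException
@test startswith(err.msg, "demo.jl:42: ")                 # location prefix
@test occursin("`Bar` is neither a type name", err.msg)   # diagnostic text
```

Asserting on the message prefix and a distinctive substring keeps the test robust to small wording changes elsewhere in the diagnostic.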

@testset "Miscellanea" begin
# match against fiddly symbols (https://github.com/JuliaServices/Match.jl/issues/32)
@test (@match :(@when a < b) begin