From 96108f2600acaf1e324248d31019b34c83ec5146 Mon Sep 17 00:00:00 2001 From: Nate Nystrom Date: Tue, 12 Nov 2024 18:25:03 +0100 Subject: [PATCH 1/5] Basic extractors --- README.md | 23 ++++++++++ src/Match.jl | 9 ++++ src/binding.jl | 115 +++++++++++++++++++++++++++++++----------------- test/rematch.jl | 39 ++++++++++++++++ 4 files changed, 146 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index f7b0f75..55485b5 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ for examples of this and other features. * `x` (an identifier) matches anything, binds value to the variable `x` * `T(x,y,z)` matches structs of type `T` with fields matching patterns `x,y,z` * `T(y=1)` matches structs of type `T` whose `y` field equals `1` +* `X(x,y,z)` where `X` is not a type, calls `Match.extract(Val(:X), v)` on the value `v` and matches the result against the tuple pattern `(x,y,z)` * `[x,y,z]` matches `AbstractArray`s with 3 entries matching `x,y,z` * `(x,y,z)` matches `Tuple`s with 3 entries matching `x,y,z` * `[x,y...,z]` matches `AbstractArray`s with at least 2 entries, where `x` matches the first entry, `z` matches the last entry and `y` matches the remaining entries. @@ -89,6 +90,28 @@ When `pattern1` matches but `some_failure_condition` is `true`, then the whole c Otherwise, if `some_shortcut_condition` is `true`, then `1` is the result value for this case. Otherwise `2` is the result. +## Extractors + +New patterns can be defined on values by overloading the `extract` function with the new pattern name. 
+For example, to match a pair of numbers using polar coordinates, extracting the radius and angle, you could define: + +```julia +function Match.extract(::Val{:Polar}, p::Tuple{<:Number,<:Number}) + x, y = p + return (sqrt(x^2 + y^2), atan(y, x)) +end +``` +This definition allows you to use a new `Polar` extractor pattern: +```julia +@match Polar(r,θ) = (1,1) +@assert r == sqrt(2) && θ == π / 4 +``` + +The `extract` function should return either a tuple of values to be matched by subpatterns or `nothing`. + +Extractors can be used to implement more user-friendly matching for types defined with `SumTypes.jl` or +other packages. + ## Differences from previous versions of `Match.jl` * If no branches are matched, throws `MatchFailure` instead of returning nothing. diff --git a/src/Match.jl b/src/Match.jl index f083efc..7ca91cb 100644 --- a/src/Match.jl +++ b/src/Match.jl @@ -54,6 +54,7 @@ The following syntactic forms can be used in patterns: * `x` (an identifier) matches anything, binds value to the variable `x` * `T(x,y,z)` matches structs of type `T` with fields matching patterns `x,y,z` * `T(y=1)` matches structs of type `T` whose `y` field equals `1` +* `X(x,y,z)` where `X` is not a type, calls `Match.extract(Val(:X), v)` on the value `v` and matches the result against the tuple pattern `(x,y,z)` * `[x,y,z]` matches `AbstractArray`s with 3 entries matching `x,y,z` * `(x,y,z)` matches `Tuple`s with 3 entries matching `x,y,z` * `[x,y...,z]` matches `AbstractArray`s with at least 2 entries, where `x` matches the first entry, `z` matches the last entry and `y` matches the remaining entries. @@ -247,6 +248,14 @@ struct MatchFailure <: Exception value end +""" + extract(::Val{x}, value) + +Implement extractor with name `x`, returning a tuple of fields of `value`, or `nothing` if +`x` cannot be extracted from `value`. +""" +extract(::Val, value) = nothing + # const fields only suppored >= Julia 1.8 macro _const(x) (VERSION >= v"1.8") ? 
Expr(:const, esc(x)) : esc(x) diff --git a/src/binding.jl b/src/binding.jl index 14a7e8f..025d093 100644 --- a/src/binding.jl +++ b/src/binding.jl @@ -66,6 +66,21 @@ function bind_type(location, T, input, binder) bound_type end +function try_bind_type(location, T, input, binder) + # bind type at macro expansion time. It will be verified at runtime. + bound_type = nothing + try + bound_type = Core.eval(binder.mod, Expr(:block, location, T)) + catch ex + return nothing + end + + if !(bound_type isa Type) + return nothing + end + + return bound_type +end function simple_name(s::Symbol) simple_name(string(s)) @@ -227,7 +242,7 @@ function bind_pattern!( # struct pattern. # TypeName(patterns...) T = source.args[1] - subpatterns = source.args[2:length(source.args)] + subpatterns = source.args[2:end] len = length(subpatterns) named_fields = [pat.args[1] for pat in subpatterns if is_expr(pat, :kw)] named_count = length(named_fields) @@ -241,50 +256,70 @@ function bind_pattern!( match_positionally = named_count == 0 - # bind type at macro expansion time - pattern0, assigned = bind_pattern!(location, :(::($T)), input, binder, assigned) - bound_type = (pattern0::BoundTypeTestPattern).type - patterns = BoundPattern[pattern0] - field_names::Tuple = match_fieldnames(bound_type) - if match_positionally && len != length(field_names) - error("$(location.file):$(location.line): The type `$bound_type` has " * - "$(length(field_names)) fields but the pattern expects $len fields.") - end + is_type = !Base.isnothing(try_bind_type(location, T, input, binder)) + if T isa Symbol && match_positionally && !is_type + # TODO support named tuples + patterns = BoundPattern[] + # call Match.extract(Val(T), input) and match the result against the tuple of subpatterns + extract = BoundExpression(location, Expr(:call, Match.extract, Val(T), input)) + # check that the extractor exists + methods = Base.methods(Match.extract, (Val{T}, Any,)) + if length(methods) <= 1 + # If there's only one matching 
Match.extract method, it's the default one. + # that returns nothing. Worse if there are fewer than one. + error("$(location.file):$(location.line): `$T` is neither a type name " * + "nor is there a `Match.extract(::Val{$T}, _)` implementation.") + end + fetch = BoundFetchExpressionPattern(extract, nothing, Any) + temp1 = push_pattern!(patterns, binder, fetch) + subpattern, assigned = bind_pattern!(location, Expr(:tuple, subpatterns...), temp1, binder, assigned) + patterns = BoundPattern[fetch, subpattern] + else + # bind type at macro expansion time + pattern0, assigned = bind_pattern!(location, :(::($T)), input, binder, assigned) + bound_type = (pattern0::BoundTypeTestPattern).type + patterns = BoundPattern[pattern0] + field_names::Tuple = match_fieldnames(bound_type) + if match_positionally && len != length(field_names) + error("$(location.file):$(location.line): The type `$bound_type` has " * + "$(length(field_names)) fields but the pattern expects $len fields.") + end - for i in 1:len - pat = subpatterns[i] - if match_positionally - field_name = field_names[i] - pattern_source = pat - else - @assert pat.head == :kw - field_name = pat.args[1] - pattern_source = pat.args[2] - if !(field_name in field_names) - error("$(location.file):$(location.line): Type `$bound_type` has " * - "no field `$field_name`.") + for i in 1:len + pat = subpatterns[i] + if match_positionally + field_name = field_names[i] + pattern_source = pat + else + @assert pat.head == :kw + field_name = pat.args[1] + pattern_source = pat.args[2] + if !(field_name in field_names) + error("$(location.file):$(location.line): Type `$bound_type` has " * + "no field `$field_name`.") + end end - end - field_type = nothing - if field_name == match_fieldnames(Symbol)[1] - # special case Symbol's hypothetical name field. 
- field_type = String - else - for (fname, ftype) in zip(Base.fieldnames(bound_type), Base.fieldtypes(bound_type)) - if fname == field_name - field_type = ftype - break + field_type = nothing + if field_name == match_fieldnames(Symbol)[1] + # special case Symbol's hypothetical name field. + field_type = String + else + for (fname, ftype) in zip(Base.fieldnames(bound_type), Base.fieldtypes(bound_type)) + if fname == field_name + field_type = ftype + break + end end end - end - @assert field_type !== nothing + @assert field_type !== nothing - fetch = BoundFetchFieldPattern(location, pattern_source, input, field_name, field_type) - field_temp = push_pattern!(patterns, binder, fetch) - bound_subpattern, assigned = bind_pattern!( - location, pattern_source, field_temp, binder, assigned) - push!(patterns, bound_subpattern) + fetch = BoundFetchFieldPattern(location, pattern_source, input, field_name, field_type) + field_temp = push_pattern!(patterns, binder, fetch) + bound_subpattern, assigned = bind_pattern!( + location, pattern_source, field_temp, binder, assigned) + push!(patterns, bound_subpattern) + end end pattern = BoundAndPattern(location, source, patterns) @@ -410,7 +445,7 @@ function bind_pattern!( pattern0, assigned = bind_pattern!(location, subpattern, input, binder, assigned) pattern1 = shred_where_clause(guard, false, location, binder, assigned) pattern = BoundAndPattern(location, source, BoundPattern[pattern0, pattern1]) - + elseif is_expr(source, :if, 2) # if expr end if !is_empty_block(source.args[2]) diff --git a/test/rematch.jl b/test/rematch.jl index 5b4b9b9..faa4c6e 100644 --- a/test/rematch.jl +++ b/test/rematch.jl @@ -474,6 +474,45 @@ end end) end +@testset "extractor function" begin + @eval function Match.extract(::Val{:polar}, p::Foo) + return (sqrt(p.x^2 + p.y^2), atan(p.y, p.x)) + end + @test (@eval @match Foo(1,1) begin + polar(r,θ) => r == sqrt(2) && θ == π / 4 + _ => false + end) +end + +@testset "extractor function that might fail" begin + 
@eval function Match.extract(::Val{:diff}, p::Foo) + return p.x >= p.y ? (p.x - p.y,) : nothing + end + @test (@eval @match Foo(1,1) begin + diff(2) => false + diff(1) => false + diff(0) => true + _ => false + end) + @test (@eval @match Foo(2,1) begin + diff(2) => false + diff(1) => true + _ => false + end) + @test (@eval @match Foo(1,2) begin + diff(2) => false + diff(1) => false + _ => true + end) +end + +@testset "extractor function missing" begin + @test_throws LoadError (@eval @match Foo(1,1) begin + Bar(0) => true + _ => false + end) +end + @testset "Miscellanea" begin # match against fiddly symbols (https://github.com/JuliaServices/Match.jl/issues/32) @test (@match :(@when a < b) begin From 44bd65fcc6f5e88ab3144f86db2edffeaa74f7dc Mon Sep 17 00:00:00 2001 From: Nate Nystrom Date: Wed, 13 Nov 2024 12:05:34 +0100 Subject: [PATCH 2/5] minor fix --- src/binding.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/binding.jl b/src/binding.jl index 025d093..1e10981 100644 --- a/src/binding.jl +++ b/src/binding.jl @@ -271,9 +271,9 @@ function bind_pattern!( "nor is there a `Match.extract(::Val{$T}, _)` implementation.") end fetch = BoundFetchExpressionPattern(extract, nothing, Any) - temp1 = push_pattern!(patterns, binder, fetch) - subpattern, assigned = bind_pattern!(location, Expr(:tuple, subpatterns...), temp1, binder, assigned) - patterns = BoundPattern[fetch, subpattern] + temp = push_pattern!(patterns, binder, fetch) + subpattern, assigned = bind_pattern!(location, Expr(:tuple, subpatterns...), temp, binder, assigned) + push!(patterns, subpattern) else # bind type at macro expansion time pattern0, assigned = bind_pattern!(location, :(::($T)), input, binder, assigned) From 505900eb6892166703e339aa6e82ccafae5acddf Mon Sep 17 00:00:00 2001 From: Nate Nystrom Date: Wed, 13 Nov 2024 14:46:09 +0100 Subject: [PATCH 3/5] nested extractor test --- test/rematch.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/rematch.jl 
b/test/rematch.jl index faa4c6e..809c34a 100644 --- a/test/rematch.jl +++ b/test/rematch.jl @@ -506,6 +506,16 @@ end end) end +@testset "nested extractor function" begin + @eval function Match.extract(::Val{:foo}, p::Foo) + return (p.x, p.y) + end + @test_throws LoadError (@eval @match Foo(Foo(1,2),3) begin + foo(foo(1,2),3) => true + _ => false + end) +end + @testset "extractor function missing" begin @test_throws LoadError (@eval @match Foo(1,1) begin Bar(0) => true From 515e886e39bba579d23ba62d8a7ed3baeb8b4f56 Mon Sep 17 00:00:00 2001 From: Nate Nystrom Date: Wed, 13 Nov 2024 15:01:26 +0100 Subject: [PATCH 4/5] Explicit AST for extractor fetch --- src/binding.jl | 10 ++++++---- src/bound_pattern.jl | 16 ++++++++++++++++ src/lowering.jl | 3 +++ src/pretty.jl | 7 +++++++ test/rematch.jl | 3 ++- 5 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/binding.jl b/src/binding.jl index 1e10981..88af6e5 100644 --- a/src/binding.jl +++ b/src/binding.jl @@ -110,6 +110,9 @@ end function gentemp(p::BoundFetchLengthPattern)::Symbol gensym(string("length(", simple_name(p.input), ")")) end +function gentemp(p::BoundFetchExtractorPattern)::Symbol + gensym(string("extract(", p.extractor, ", ", simple_name(p.input), ")")) +end # # The following are special bindings used to handle the point where @@ -260,17 +263,16 @@ function bind_pattern!( if T isa Symbol && match_positionally && !is_type # TODO support named tuples patterns = BoundPattern[] - # call Match.extract(Val(T), input) and match the result against the tuple of subpatterns - extract = BoundExpression(location, Expr(:call, Match.extract, Val(T), input)) # check that the extractor exists methods = Base.methods(Match.extract, (Val{T}, Any,)) if length(methods) <= 1 # If there's only one matching Match.extract method, it's the default one. # that returns nothing. Worse if there are fewer than one. 
error("$(location.file):$(location.line): `$T` is neither a type name " * - "nor is there a `Match.extract(::Val{$T}, _)` implementation.") + "nor is there a `Match.extract(::Val{$T}, _)` implementation.") end - fetch = BoundFetchExpressionPattern(extract, nothing, Any) + # call Match.extract(Val(T), input) and match the result against the tuple of subpatterns + fetch = BoundFetchExtractorPattern(location, source, input, T, Any) temp = push_pattern!(patterns, binder, fetch) subpattern, assigned = bind_pattern!(location, Expr(:tuple, subpatterns...), temp, binder, assigned) push!(patterns, subpattern) diff --git a/src/bound_pattern.jl b/src/bound_pattern.jl index 2e59388..a3ec139 100644 --- a/src/bound_pattern.jl +++ b/src/bound_pattern.jl @@ -286,6 +286,22 @@ function Base.:(==)(a::BoundFetchLengthPattern, b::BoundFetchLengthPattern) a.input == b.input end +# Fetch a value using the given extractor function into a temporary. See +# `BoundFetchFieldPattern` for the general idea of how these are used. +struct BoundFetchExtractorPattern <: BoundFetchPattern + location::LineNumberNode + source::Any + input::Symbol + extractor::Symbol + type::Type +end +function Base.hash(a::BoundFetchExtractorPattern, h::UInt64) + hash((a.input, a.extractor, 0xd7882f5b4888d335), h) +end +function Base.:(==)(a::BoundFetchExtractorPattern, b::BoundFetchExtractorPattern) + a.input == b.input && a.extractor == b.extractor +end + # Preserve the value of the expression into a temp. Used # (1) to force the binding on both sides of an or-pattern to be the same (a phi), and # (2) to load the value of a `where` clause. 
diff --git a/src/lowering.jl b/src/lowering.jl index aba1482..f4c9669 100644 --- a/src/lowering.jl +++ b/src/lowering.jl @@ -89,6 +89,9 @@ end function code(bound_pattern::BoundFetchExpressionPattern) code(bound_pattern.bound_expression) end +function code(bound_pattern::BoundFetchExtractorPattern) + Expr(:call, Match.extract, Val(bound_pattern.extractor), bound_pattern.input) +end # Return an expression that computes whether or not the pattern matches. function lower_pattern_to_boolean(bound_pattern::BoundPattern, binder::BinderContext) diff --git a/src/pretty.jl b/src/pretty.jl index 50a488c..e512f10 100644 --- a/src/pretty.jl +++ b/src/pretty.jl @@ -185,6 +185,13 @@ function pretty(io::IO, p::BoundFetchLengthPattern) pretty(io, p.input) print(io, ")") end +function pretty(io::IO, p::BoundFetchExtractorPattern) + print(io, "extract(") + pretty(io, p.extractor) + print(io, ", ") + pretty(io, p.input) + print(io, ")") +end function pretty(io::IO, p::BoundFetchExpressionPattern) pretty(io, p.bound_expression) end diff --git a/test/rematch.jl b/test/rematch.jl index 809c34a..206169a 100644 --- a/test/rematch.jl +++ b/test/rematch.jl @@ -510,13 +510,14 @@ end @eval function Match.extract(::Val{:foo}, p::Foo) return (p.x, p.y) end - @test_throws LoadError (@eval @match Foo(Foo(1,2),3) begin + @test (@eval @match Foo(Foo(1,2),3) begin foo(foo(1,2),3) => true _ => false end) end @testset "extractor function missing" begin + # Bar is neither a type nor an extractor function @test_throws LoadError (@eval @match Foo(1,1) begin Bar(0) => true _ => false From b227f334988ae12dbb3147a6b1d759ca5d074408 Mon Sep 17 00:00:00 2001 From: Nate Nystrom Date: Wed, 13 Nov 2024 20:19:30 +0100 Subject: [PATCH 5/5] PR comment --- src/binding.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/binding.jl b/src/binding.jl index 88af6e5..bd1c71d 100644 --- a/src/binding.jl +++ b/src/binding.jl @@ -67,7 +67,7 @@ function bind_type(location, T, input, binder) bound_type end 
function try_bind_type(location, T, input, binder) - # bind type at macro expansion time. It will be verified at runtime. + # bind type at macro expansion time in the caller's module. bound_type = nothing try bound_type = Core.eval(binder.mod, Expr(:block, location, T))