From 7b0aee4b47312de3986ec2a72cb78cc1d661405d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 13 Jan 2025 12:10:11 +0100 Subject: [PATCH 1/3] Fix fused_map_reduce when result is a matrix --- src/reduce.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/reduce.jl b/src/reduce.jl index c1644470..adb3eda1 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -22,7 +22,7 @@ reduce_op(op::AddSubMul) = add_sub_op(op) reduce_op(::typeof(add_dot)) = + -neutral_element(::typeof(+), T::Type) = zero(T) +neutral_element(::typeof(+), T::Type) = Zero() map_op(::AddSubMul) = * From cec63b1aac7d94dc43ef12eb1a69ff955dee00a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 13 Jan 2025 14:32:27 +0100 Subject: [PATCH 2/3] Instantiate zero --- src/MutableArithmetics.jl | 4 +-- src/reduce.jl | 27 ++++++++++++++++++++- test/reduce.jl | 51 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 test/reduce.jl diff --git a/src/MutableArithmetics.jl b/src/MutableArithmetics.jl index 15025542..d99c7ecf 100644 --- a/src/MutableArithmetics.jl +++ b/src/MutableArithmetics.jl @@ -113,8 +113,6 @@ include("implementations/SparseArrays.jl") include("evalpoly.jl") -include("reduce.jl") - """ isequal_canonical(a, b) @@ -181,6 +179,8 @@ end include("rewrite.jl") include("rewrite_generic.jl") + +include("reduce.jl") include("dispatch.jl") # Test that can be used to test an implementation of the interface diff --git a/src/reduce.jl b/src/reduce.jl index adb3eda1..686e0356 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -24,6 +24,27 @@ reduce_op(::typeof(add_dot)) = + neutral_element(::typeof(+), T::Type) = Zero() +""" + instantiate_zero(x, ::Type{T}) where {T} + +If `x` is `Zero` and `zero(::T)` is defined, then returns `zero(T)`. +Otherwise, `zero(x)` is returned. +For instance, `instantiate_zero(Zero(), Matrix{Int})` returns `Zero()` +because `zero(::Matrix)` is not defined. +Types that don't define `zero` should explicitly implement a new method +for this function that return `Zero()`. +""" +function instantiate_zero end + +instantiate_zero(x, ::Type) = x + +instantiate_zero(::Zero, ::Type{T}) where {T} = zero(T) + +# The arrays of `StaticArrays.jl` actually implement `zero` even though they +# are subtypes of `AbstractArray` but with this method, it will be `Zero()` +# anyway. At least it is consistent with other subtypes of `AbstractArray`. +instantiate_zero(::Zero, ::Type{<:AbstractArray}) = Zero() + map_op(::AddSubMul) = * map_op(::typeof(add_dot)) = LinearAlgebra.dot @@ -47,7 +68,11 @@ function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N} accumulator = buffered_operate!!(buffer, op, accumulator, getindex.(args, I)...) end - return accumulator + # If there are no elements, instead of returning `MA.Zero`, we return + # `zero(T)` unless we know `zero(::T)` is not defined like if `T` is `Matrix{...}`. + # Returning `Zero()` could also work but it would be breaking so we opt for + # returning `zero(T)` when possible. + return instantiate_zero(accumulator, T) end function operate(::typeof(sum), a::AbstractArray) diff --git a/test/reduce.jl b/test/reduce.jl new file mode 100644 index 00000000..0c24092f --- /dev/null +++ b/test/reduce.jl @@ -0,0 +1,51 @@ +# Copyright (c) 2019 MutableArithmetics.jl contributors +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain +# one at http://mozilla.org/MPL/2.0/. + +module TestReduce + +using Test + +import MutableArithmetics as MA +using LinearAlgebra + +function runtests() + for name in names(@__MODULE__; all = true) + if startswith("$(name)", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + +function _test_is_zero(x, ::Type{T}) where {T} + @test iszero(x) + @test x isa T + return +end + +function test_empty_dot() + _test_is_zero(MA.operate(dot, Int[], Int[]), Int) + _test_is_zero(MA.operate(dot, BigInt[], Int[]), BigInt) + _test_is_zero(MA.operate(dot, Int[], Float64[]), Float64) + _test_is_zero(MA.operate(dot, Matrix{Int}[], Matrix{Float64}[]), Float64) + @test MA.fused_map_reduce(MA.add_mul, Matrix{Int}[], Float64[]) isa MA.Zero + @test MA.fused_map_reduce(MA.add_mul, Float64[], Matrix{Int}[]) isa MA.Zero + return +end + +function test_add_mul_matrix() + A = [1 2; 3 4] + B = [3 1; 1 2] + @test MA.fused_map_reduce(MA.add_mul, [A, B], [2, -1]) == 2A - B + @test MA.fused_map_reduce(MA.add_mul, [-1, 3], [A, B]) == 3B - A + return +end + +end # module + +TestReduce.runtests() diff --git a/test/runtests.jl b/test/runtests.jl index 9a9bb53c..a1163be4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,7 @@ end end include("matmul.jl") include("dispatch.jl") +include("reduce.jl") include("rewrite.jl") include("rewrite_generic.jl") From 211701be4f990a7a3a9764148ac930d1b1ce7d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 13 Jan 2025 15:52:25 +0100 Subject: [PATCH 3/3] Add Julia v1.10 --- .github/workflows/ci.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a360cf09..f820716a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,6 +24,12 @@ jobs: - version: '1.6' os: ubuntu-latest arch: x64 + - version: '1.10' + os: ubuntu-latest + arch: x86 + - version: '1.10' + os: ubuntu-latest + arch: x64 - version: '1' os: ubuntu-latest arch: x64