Skip to content

Commit

Permalink
Write quite a lot of tests for IsometricKroneckerProduct
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 28, 2023
1 parent c414577 commit 87c3e63
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 72 deletions.
32 changes: 16 additions & 16 deletions src/fast_linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,19 @@ function fast_X_A_Xt!(out::PSDMatrix, A::PSDMatrix, X::AbstractMatrix)
return out
end

"""
alloc_free_get_U!(C::Cholesky)
Allocation-free version of `C.U`.
THIS MODIFIES `C.factors` SO AFTERWARDS `C` SHOULD NOT BE USED ANYMORE!
"""
function alloc_free_get_U!(C::Cholesky)
Cuplo = getfield(C, :uplo)
Cfactors = getfield(C, :factors)
if Cuplo === LinearAlgebra.char_uplo(:U)
return getupperright!(Cfactors)
else
return getupperright!(Cfactors')
end
end
# """
# alloc_free_get_U!(C::Cholesky)

# Allocation-free version of `C.U`.

# THIS MODIFIES `C.factors` SO AFTERWARDS `C` SHOULD NOT BE USED ANYMORE!
# """
# function alloc_free_get_U!(C::Cholesky)
# Cuplo = getfield(C, :uplo)
# Cfactors = getfield(C, :factors)
# if Cuplo === LinearAlgebra.char_uplo(:U)
# return getupperright!(Cfactors)
# else
# return getupperright!(Cfactors')
# end
# end
124 changes: 68 additions & 56 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,26 @@ IsometricKroneckerProduct(ldim::Integer, B::AbstractVector) =
IsometricKroneckerProduct(ldim, reshape(B, :, 1))

const IKP = IsometricKroneckerProduct
get_right_factor(K::IKP) = K.B
get_left_factor_dim(K::IKP) = K.B

Kronecker.getmatrices(K::IKP) = (I(K.ldim), K.B)

Base.zero(A::IKP) = IsometricKroneckerProduct(A.ldim, zero(A.B))
Base.one(A::IKP) = IsometricKroneckerProduct(A.ldim, one(A.B))
copy!(A::IKP, B::IKP) = begin
check_same_size(A, B)
copy!(A.B, B.B)
return A
end
copy(A::IKP) = IsometricKroneckerProduct(A.ldim, copy(A.B))
similar(A::IKP) = IsometricKroneckerProduct(A.ldim, similar(A.B))
Base.size(K::IKP) = (K.ldim * size(K.B, 1), K.ldim * size(K.B, 2))

# conversion
Base.convert(::Type{T}, K::IKP) where {T<:IKP} =
K isa T ? K : T(K)
function IKP{T,TB}(K::IKP) where {T,TB}
IKP(K.ldim, convert(TB, K.B))
end

function Base.:*(A::IKP, B::IKP)
@assert A.ldim == B.ldim
Expand All @@ -55,7 +68,7 @@ function check_matmul_sizes(A::IKP, B::IKP)
# For A * B
Ad, Bd = A.ldim, B.ldim
An, Am, Bn, Bm = size(A)..., size(B)...
if !(A.ldim == B.ldim) || !(Am == Bnb)
if !(A.ldim == B.ldim) || !(Am == Bn)
throw(
DimensionMismatch(
"Matrix multiplication not compatible: A has size ($Ad$An,$Ad$Am), B has size ($Bd$Bn,$Bd$Bm)",
Expand Down Expand Up @@ -93,6 +106,18 @@ Base.:\(A::IKP, B::IKP) = begin
return IsometricKroneckerProduct(A.ldim, A.B \ B.B)
end

mul!(A::IKP, B::IKP, C::IKP) = begin
check_matmul_sizes(A, B, C)
mul!(A.B, B.B, C.B)
return A
end
mul!(A::IKP, B::IKP, C::IKP, alpha::Number, beta::Number) = begin
check_matmul_sizes(A, B, C)
mul!(A.B, B.B, C.B, alpha, beta)
return A
end

# fast_linalg.jl
_matmul!(A::IKP, B::IKP, C::IKP) = begin
check_matmul_sizes(A, B, C)
_matmul!(A.B, B.B, C.B)
Expand All @@ -105,7 +130,7 @@ _matmul!(A::IKP{T}, B::IKP{T}, C::IKP{T}) where {T<:LinearAlgebra.BlasFloat} = b
end
_matmul!(A::IKP, B::IKP, C::IKP, alpha::Number, beta::Number) = begin
check_matmul_sizes(A, B, C)
_matmul!(A.B, B.B, C.B)
_matmul!(A.B, B.B, C.B, alpha, beta)
return A
end
_matmul!(
Expand All @@ -119,21 +144,6 @@ _matmul!(
_matmul!(A.B, B.B, C.B, alpha, beta)
return A
end
copy!(A::IKP, B::IKP) = begin
check_same_size(A, B)
copy!(A.B, B.B)
return A
end
copy(A::IKP) = IsometricKroneckerProduct(A.ldim, copy(A.B))
similar(A::IKP) = IsometricKroneckerProduct(A.ldim, similar(A.B))
Base.size(K::IKP) = (K.ldim * size(K.B, 1), K.ldim * size(K.B, 2))

# conversion
Base.convert(::Type{T}, K::IKP) where {T<:IKP} =
K isa T ? K : T(K)
function IKP{T,TB}(K::IKP) where {T,TB}
IKP(K.ldim, convert(TB, K.B))
end

"""
Allocation-free reshape
Expand All @@ -154,12 +164,7 @@ function mul_vectrick!(x::AbstractVecOrMat, A::IKP, v::AbstractVecOrMat)
return x
end
function mul_vectrick!(
x::AbstractVecOrMat,
A::IKP,
v::AbstractVecOrMat,
alpha::Number,
beta::Number,
)
x::AbstractVecOrMat, A::IKP, v::AbstractVecOrMat, alpha::Number, beta::Number)
N = A.B
c, d = size(N)

Expand All @@ -169,39 +174,46 @@ function mul_vectrick!(
return x
end

_matmul!(C::AbstractVecOrMat, A::IKP, B::AbstractVecOrMat) = mul_vectrick!(C, A, B)
mul!(C::AbstractMatrix, A::IKP, B::AbstractMatrix) = mul_vectrick!(C, A, B)
mul!(C::AbstractMatrix, A::IKP, B::Adjoint{T,<:AbstractMatrix{T}}) where {T} =
mul_vectrick!(C, A, B)
mul!(C::AbstractVector, A::IKP, B::AbstractVector) = mul_vectrick!(C, A, B)

_matmul!(
C::AbstractVecOrMat{T},
A::IKP{T},
B::AbstractVecOrMat{T},
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B)
_matmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::IKP) = _matmul!(C', B', A')
_matmul!(
C::AbstractVecOrMat{T},
A::AbstractVecOrMat{T},
B::IKP{T},
) where {T<:LinearAlgebra.BlasFloat} = _matmul!(C', B', A')

_matmul!(C::AbstractVecOrMat, A::IKP, B::AbstractVecOrMat, alpha::Number, beta::Number) =
mul_vectrick!(C, A, B, alpha, beta)
_matmul!(
C::AbstractVecOrMat{T},
A::IKP{T},
B::AbstractVecOrMat{T},
alpha::Number,
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B, alpha, beta)
_matmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::IKP, alpha::Number, beta::Number) =
mul_vectrick!(C', B', A', alpha, beta)
_matmul!(
C::AbstractVecOrMat{T},
A::AbstractVecOrMat{T},
B::IKP{T},
alpha::Number,
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C', B', A', alpha, beta)
for TC in [:AbstractVector, :AbstractMatrix]
@eval mul!(C::$TC, A::IKP, B::$TC) = mul_vectrick!(C, A, B)
@eval mul!(C::$TC, A::IKP, B::Adjoint{T,<:$TC{T}}) where {T} = mul_vectrick!(C, A, B)
@eval mul!(C::$TC, A::IKP, B::$TC, alpha::Number, beta::Number) =
mul_vectrick!(C, A, B, alpha, beta)

@eval _matmul!(C::$TC, A::IKP, B::$TC) = mul_vectrick!(C, A, B)
@eval _matmul!(
C::$TC{T},
A::IKP{T},
B::$TC{T},
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B)
@eval _matmul!(C::$TC, A::$TC, B::IKP) = _matmul!(C', B', A')
@eval _matmul!(
C::$TC{T},
A::$TC{T},
B::IKP{T},
) where {T<:LinearAlgebra.BlasFloat} = _matmul!(C', B', A')

@eval _matmul!(C::$TC, A::IKP, B::$TC, alpha::Number, beta::Number) =
mul_vectrick!(C, A, B, alpha, beta)
@eval _matmul!(
C::$TC{T},
A::IKP{T},
B::$TC{T},
alpha::Number,
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C, A, B, alpha, beta)
@eval _matmul!(C::$TC, A::$TC, B::IKP, alpha::Number, beta::Number) =
mul_vectrick!(C', B', A', alpha, beta)
@eval _matmul!(
C::$TC{T},
A::$TC{T},
B::IKP{T},
alpha::Number,
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = mul_vectrick!(C', B', A', alpha, beta)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ const GROUP = get(ENV, "GROUP", "All")
@timedsafetestset "Smoothing" begin
include("smoothing.jl")
end
@timedsafetestset "IsometricKroneckerProduct" begin
include("core/kronecker.jl")
end
end
end

Expand Down

0 comments on commit 87c3e63

Please sign in to comment.