From ce3d7aaca628dfc5293f48762696235f7eed0172 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 19 Nov 2023 13:21:01 -0500 Subject: [PATCH] Generalizations to make batching work --- src/utils.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 6a43acc80..0f61b2180 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -276,8 +276,10 @@ end __init_identity_jacobian(u::Number, _) = u function __init_identity_jacobian(u, fu) - return convert(parameterless_type(_mutable(u)), - Matrix{eltype(u)}(I, length(fu), length(u))) + J = similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u)) + fill!(J, 0) + J[diagind(J)] .= 1 + return J end function __init_identity_jacobian(u::StaticArray, fu) return convert(MArray{Tuple{length(fu), length(u)}}, @@ -291,8 +293,10 @@ function __init_low_rank_jacobian(u::StaticArray, fu, threshold::Int) return U, Vᵀ end function __init_low_rank_jacobian(u, fu, threshold::Int) - Vᵀ = convert(parameterless_type(_mutable(u)), zeros(eltype(u), length(u), threshold)) - U = convert(parameterless_type(_mutable(u)), zeros(eltype(u), threshold, length(u))) + Vᵀ = similar(u, promote_type(eltype(u), eltype(fu)), length(u), threshold) + U = similar(u, promote_type(eltype(u), eltype(fu)), threshold, length(u)) + fill!(Vᵀ, 0) + fill!(U, 0) return U, Vᵀ end