Skip to content

Commit

Permalink
More fixes for interface tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Apr 22, 2024
1 parent c53e70b commit 20cc8e7
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 10 deletions.
22 changes: 20 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
authors = [
"Pankaj Mishra <[email protected]>",
"Chris Rackauckas <[email protected]>",
]
version = "2.19.0"

[deps]
Expand Down Expand Up @@ -68,6 +71,7 @@ Zygote = "0.6"
julia = "1.10"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -83,4 +87,18 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
test = [
"Test",
"AllocCheck",
"BandedMatrices",
"BlockBandedMatrices",
"Enzyme",
"IterativeSolvers",
"Pkg",
"PolyesterForwardDiff",
"Random",
"SafeTestsets",
"Symbolics",
"Zygote",
"StaticArrays",
]
3 changes: 3 additions & 0 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ end

abstract type AbstractAutoDiffVecProd end

my_dense_ad(ad::AbstractADType) = ad
my_dense_ad(ad::AutoSparse) = ADTypes.dense_ad(ad)

include("coloring/high_level.jl")
include("coloring/backtracking_coloring.jl")
include("coloring/contraction_coloring.jl")
Expand Down
9 changes: 7 additions & 2 deletions src/highlevel/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,13 @@ If `fx` is not specified, it will be computed by calling `f(x)`.
A cache for computing the Jacobian of type `AbstractMaybeSparseJacobianCache`.
"""
function sparse_jacobian_cache(
ad::AbstractADType, sd::AbstractSparsityDetection, args...; kwargs...)
return sparse_jacobian_cache_aux(mode(ad), ad, sd, args...; kwargs...)
ad::AbstractADType, sd::AbstractSparsityDetection, f, x; fx = nothing)
return sparse_jacobian_cache_aux(mode(ad), ad, sd, f, x; fx)
end

function sparse_jacobian_cache(
ad::AbstractADType, sd::AbstractSparsityDetection, f!, x, fx)
return sparse_jacobian_cache_aux(mode(ad), ad, sd, f!, x, fx)
end

function sparse_jacobian_static_array(ad, cache, f, x::SArray)
Expand Down
4 changes: 2 additions & 2 deletions src/highlevel/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function sparse_jacobian_cache_aux(
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
coloring_result = sd(ad, f, x)
fx = fx === nothing ? similar(f(x)) : fx
tag = __standard_tag(ad.tag, f, x)
tag = __standard_tag(my_dense_ad(ad).tag, f, x)
if coloring_result isa NoMatrixColoring
cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x), tag)
jac_prototype = nothing
Expand All @@ -34,7 +34,7 @@ function sparse_jacobian_cache_aux(
::ForwardMode, ad::Union{AutoSparse{<:AutoForwardDiff}, AutoForwardDiff},
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
coloring_result = sd(ad, f!, fx, x)
tag = __standard_tag(ad.tag, f!, x)
tag = __standard_tag(my_dense_ad(ad).tag, f!, x)
if coloring_result isa NoMatrixColoring
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x), tag)
jac_prototype = nothing
Expand Down
3 changes: 0 additions & 3 deletions test/1.10specific/Project.toml

This file was deleted.

1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ if GROUP == "Core" || GROUP == "All"
end

if GROUP == "InterfaceI" || GROUP == "All"
activate_env("1.10specific")
@time @safetestset "Jac Vecs and Hes Vecs" begin
include("test_jaches_products.jl")
end
Expand Down

0 comments on commit 20cc8e7

Please sign in to comment.