Skip to content

Commit

Permalink
Remove GPU sum() rule
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir authored Sep 1, 2023
1 parent 69c2616 commit 73ebf4b
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,6 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
@adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
T(xs), Δ -> (convert(Array, Δ), )

@adjoint function sum(xs::AbstractGPUArray; dims = :)
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end

# Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
function _pullback(cx::AContext, ::typeof(sum), f, xs::AbstractGPUArray)
Expand Down

0 comments on commit 73ebf4b

Please sign in to comment.