Skip to content

Commit

Permalink
Merge pull request #1493 from lkdvos/ld/sort
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack authored Jan 4, 2024
2 parents c1d82be + 03a8ef7 commit 9df7226
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 15 deletions.
14 changes: 0 additions & 14 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,6 @@ end

@adjoint iterate(r::UnitRange, i...) = iterate(r, i...), _ -> nothing

@adjoint function sort(x::AbstractArray; by=identity)
p = sortperm(x, by=by)
return x[p], x̄ -> (x̄[invperm(p)],)
end

@adjoint function filter(f, x::AbstractVector)
t = map(f, x)
x[t], Δ -> begin
dx = _zero(x, eltype(Δ))
dx[t] .= Δ
(nothing, dx)
end
end

# Iterators

@adjoint function enumerate(xs)
Expand Down
6 changes: 5 additions & 1 deletion test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,17 @@ end
[2,3,1],
[1, 2, 3],
[1,2,3],
[2,1,3]
[2,1,3],
[1,3,2],
[3,2,1]
]
for i = 1:3
@test gradient(v->sort(v)[i], [3.,1,2])[1][correct[1][i]] == 1
@test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1
@test gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1
@test gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1
@test gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1
@test gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1
end
end

Expand Down

0 comments on commit 9df7226

Please sign in to comment.