Skip to content

Commit

Permalink
restrict _tryaxes to AbstractArrays since iterators would error
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Jan 17, 2024
1 parent e13e029 commit b37baf7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)

# With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
_tryaxes(x::AbstractArray) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
_restore(dx, ::Val{N}) where {N} = ntuple(i -> get(dx,i,nothing), N)
Expand Down

0 comments on commit b37baf7

Please sign in to comment.