Skip to content

Commit

Permalink
Revert "src/PositionalEmbeddings.jl: Fix neg_half to correctly work w…
Browse files Browse the repository at this point in the history
…ith pairs."

Actually, using concatenation is more efficient and claimed to be equivalent:
huggingface/transformers#25199
  • Loading branch information
mashu committed Nov 23, 2024
1 parent 385e849 commit fd4124c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 24 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "PositionalEmbeddings"
uuid = "d504d84d-5e64-4f13-be9b-c14c41279bd1"
authors = ["Mateusz Kaduk <[email protected]> and contributors"]
version = "0.2.0"
version = "0.3.0"

[deps]
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "5.5.2"
Expand All @@ -13,9 +14,8 @@ Zygote = "0.6.41"
julia = "1.9"

[extras]
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[targets]
test = ["Test", "CUDA", "Zygote"]
test = ["Test", "CUDA"]
15 changes: 5 additions & 10 deletions src/PositionalEmbeddings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ module PositionalEmbeddings
neg_half(x::AbstractArray{T}, dim::Int=1) where T
Helper function that returns the second half of the array along dimension `dim`, negated, followed by the first half (the rotate-half operation).
This implementation splits the array into two contiguous halves instead of interleaving even/odd pairs, as in LLaMA:
https://github.com/huggingface/transformers/issues/25199
# Arguments
- `x::AbstractArray{T}`: Input array
Expand All @@ -136,16 +138,9 @@ module PositionalEmbeddings
- Array with second half negated along specified dimension
"""
function neg_half(x::AbstractArray{T}, dim::Int=1) where T
    # Rotate-half helper: returns `[-second_half; first_half]` along `dim`.
    #
    # - `x`:   input array of any dimensionality.
    # - `dim`: axis that is split into two halves (defaults to the first axis).
    #
    # Throws `ArgumentError` when `dim` is not a valid axis of `x`.
    1 <= dim <= ndims(x) ||
        throw(ArgumentError("dim must be in 1:$(ndims(x)), got $dim"))

    n = size(x, dim)
    half = n ÷ 2

    # `selectdim` builds views along the requested axis for any number of
    # dimensions, so `dim` is honoured instead of always slicing axis 1.
    upper = selectdim(x, dim, (half + 1):n)
    lower = selectdim(x, dim, 1:half)

    # Concatenate along the same axis that was split; for `dim == 1` this is
    # the same as `vcat(-upper, lower)`.
    return cat(-upper, lower; dims=dim)
end

"""
Expand Down
20 changes: 10 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ end
(features*seq_len*batch_size)

reference = Float32[
9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281
-0.000582162 0.00106326 0.00163171 0.00187805 0.00202529
-0.00436025 0.000599911 0.00261259 0.00326171 0.00353053
-0.00542633 -0.00128126 0.00319962 0.00453726 0.00500356
0.000728724 -0.00439145 0.00336443 0.00570055 0.00644391
9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281;
-0.00115739 0.000881045 0.00158297 0.00186569 0.00202236;
-0.00498184 0.000253548 0.00251558 0.00323702 0.00352467;
-0.0055228 -0.00175742 0.00305532 0.00450025 0.00499477;
0.00124607 -0.00495019 0.00317429 0.00565127 0.00643219
]

rope = RoPE(features, seq_len)
Expand Down Expand Up @@ -76,11 +76,11 @@ end
x = cu(x)

reference = Float32[
9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281
-0.000582162 0.00106326 0.00163171 0.00187805 0.00202529
-0.00436025 0.000599911 0.00261259 0.00326171 0.00353053
-0.00542633 -0.00128126 0.00319962 0.00453726 0.00500356
0.000728724 -0.00439145 0.00336443 0.00570055 0.00644391
9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281;
-0.00115739 0.000881045 0.00158297 0.00186569 0.00202236;
-0.00498184 0.000253548 0.00251558 0.00323702 0.00352467;
-0.0055228 -0.00175742 0.00305532 0.00450025 0.00499477;
0.00124607 -0.00495019 0.00317429 0.00565127 0.00643219
]

rope = RoPE(features, seq_len)
Expand Down

0 comments on commit fd4124c

Please sign in to comment.