ext/CUDAext.jl: Add weakdep on CUDA and copy the RoPE cache to the GPU.
mashu committed Nov 23, 2024
1 parent d1654c5 commit c65d81b
Showing 3 changed files with 11 additions and 14 deletions.
7 changes: 4 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "PositionalEmbeddings"
uuid = "d504d84d-5e64-4f13-be9b-c14c41279bd1"
authors = ["Mateusz Kaduk <[email protected]> and contributors"]
version = "0.1.0"
version = "0.3.0"

[deps]
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -10,7 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
CUDAExt = "CUDA"
CUDAext = "CUDA"

[compat]
CUDA = "5.5.2"
@@ -19,8 +19,9 @@ Zygote = "0.6.41"
julia = "1.9"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Zygote"]
test = ["Test", "Zygote", "CUDA"]
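With CUDA listed under [weakdeps] and mapped to the CUDAext extension, Julia 1.9+ loads the extension code only when the user imports CUDA alongside PositionalEmbeddings; without CUDA, the core package carries no GPU dependency. Adding CUDA to [extras] and the test target lets the test suite exercise the extension. A minimal usage sketch follows — the RoPE constructor arguments are illustrative assumptions, not taken from this diff:

using PositionalEmbeddings  # core package loads without CUDA
using CUDA                  # importing CUDA activates the CUDAext extension

rope = RoPE(64, 1024)       # hypothetical signature: feature dim, max sequence length
rope_gpu = cu(rope)         # extension method: copies the cos/sin caches to the GPU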
7 changes: 4 additions & 3 deletions ext/CUDAext.jl
@@ -1,6 +1,7 @@
module CUDAext
import CUDA: cu
using CUDA
import PositionalEmbeddings: cu
using PositionalEmbeddings: RoPE

"""
CUDA.cu(rope::RoPE{T, A}) where {T, A<:AbstractArray{T}}
@@ -10,7 +11,7 @@ module CUDAext
function CUDA.cu(rope::RoPE{T, A}) where {T, A<:AbstractArray{T}}
cos_cached_gpu = CUDA.CuArray(rope.cos_cached)
sin_cached_gpu = CUDA.CuArray(rope.sin_cached)
RoPE{T, typeof(cos_cached_gpu)}(rope.features, cos_cached_gpu, sin_cached_gpu)
RoPE{T, typeof(cos_cached_gpu)}(rope.features, cos_cached_gpu, sin_cached_gpu, rope.scale)
end

end # module
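Read together, the two hunks leave the extension looking roughly like this (a reconstruction from the diff; the docstring body is folded in this view and paraphrased here):

module CUDAext
    import CUDA: cu
    using CUDA
    using PositionalEmbeddings: RoPE

    """
        CUDA.cu(rope::RoPE{T, A}) where {T, A<:AbstractArray{T}}

    Copy a RoPE's cached cos/sin tables to the GPU. (Body folded in the diff.)
    """
    function CUDA.cu(rope::RoPE{T, A}) where {T, A<:AbstractArray{T}}
        cos_cached_gpu = CUDA.CuArray(rope.cos_cached)
        sin_cached_gpu = CUDA.CuArray(rope.sin_cached)
        RoPE{T, typeof(cos_cached_gpu)}(rope.features, cos_cached_gpu, sin_cached_gpu, rope.scale)
    end
end # module

The fourth constructor argument, rope.scale, is the functional change: the RoPE struct evidently carries a scale field that the old three-argument construction predates, so GPU copies now preserve it. The import also moves from the removed PositionalEmbeddings.cu stub to CUDA's own cu.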
11 changes: 3 additions & 8 deletions src/PositionalEmbeddings.jl
@@ -2,14 +2,6 @@ module PositionalEmbeddings
using Functors
export RoPE, AbsolutePE

function cu(args...; opts...)
if !isdefined(@__MODULE__, :CUDA)
error("Import CUDA to enable GPU support")
else
error("Invalid method call")
end
end

"""
compute_frequencies(dim::Int, seq_len::Int, base::Number=10_000)
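The compute_frequencies body is folded out of this view. For orientation, the conventional RoPE frequency table matching that signature would be along these lines — an assumed sketch, not the package's actual code:

function compute_frequencies(dim::Int, seq_len::Int, base::Number=10_000)
    # Inverse frequencies θᵢ = base^(-2(i-1)/dim), one per feature pair
    inv_freq = 1 ./ (base .^ ((0:2:(dim - 1)) ./ dim))
    # Angle matrix of size (dim ÷ 2) × seq_len: angles[i, p+1] = θᵢ * p
    inv_freq .* (0:(seq_len - 1))'
end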
@@ -135,13 +127,16 @@ module PositionalEmbeddings
neg_half(x::AbstractArray{T}, dim::Int=1) where T
Helper function that negates the second half of the array along dimension `dim`.
This implementation uses a half-negative layout rather than interleaving pairs, following LLaMA and the discussion at
https://github.com/huggingface/transformers/issues/25199
# Arguments
- `x::AbstractArray{T}`: Input array
- `dim::Int=1`: Dimension along which to perform the operation
# Returns
- Array with second half negated along specified dimension
"""
function neg_half(x::AbstractArray{T}, dim::Int=1) where T
d_2 = size(x, dim) ÷ 2
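The hunk stops right after d_2 is computed. A plausible completion, assuming the LLaMA-style rotate-half that the docstring's reference points to (the folded body may differ):

function neg_half(x::AbstractArray{T}, dim::Int=1) where T
    d_2 = size(x, dim) ÷ 2
    # (x₁, x₂) -> (-x₂, x₁): the second half is negated and moved in front,
    # the non-interleaved rotation used by LLaMA
    cat(-selectdim(x, dim, (d_2 + 1):size(x, dim)),
        selectdim(x, dim, 1:d_2); dims=dim)
end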
