Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
simone-silvestri authored Jan 30, 2025
1 parent bb121c8 commit ac8cd58
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 19 deletions.
8 changes: 4 additions & 4 deletions src/Architectures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ device(::GPU) = CUDA.CUDABackend(; always_inline=true)
# for execution and will import whatever host or kernel is called.
# this says that reactant will import the cuda kernel version of the code
# which makes some optimizations easier. Reactant may still execute the
# code on CPU GPU or TPU dependong on what its default client is.
# code on CPU GPU or TPU depending on what its default client is.
function device(::ReactantState)
    # Construct the CUDA backend with inlining always enabled.
    return CUDA.CUDABackend(; always_inline=true)
end

# Zero-argument method: there is no object to query, so the result is `nothing`.
function architecture()
    return nothing
end
Expand Down Expand Up @@ -134,9 +134,9 @@ end
# Generic method: accepts any argument, does no work, and returns `nothing`.
@inline function unsafe_free!(a)
    return nothing
end

# Convert arguments to GPU-compatible types
# On a `CPU` architecture `args` is returned unchanged.
@inline convert_args(::CPU, args) = args
# On a `GPU` architecture `args` is passed through `CUDA.cudaconvert`.
@inline convert_args(::GPU, args) = CUDA.cudaconvert(args)
# On a `GPU` architecture with a `Tuple`, `CUDA.cudaconvert` is mapped over each element.
@inline convert_args(::GPU, args::Tuple) = map(CUDA.cudaconvert, args)
# Generic method for any architecture: `args` is returned unchanged.
@inline convert_to_device(arch, args) = args
# On a `GPU` architecture `args` is passed through `CUDA.cudaconvert`.
@inline convert_to_device(::GPU, args) = CUDA.cudaconvert(args)
# On a `GPU` architecture with a `Tuple`, `CUDA.cudaconvert` is mapped over each element.
@inline convert_to_device(::GPU, args::Tuple) = map(CUDA.cudaconvert, args)

# Deprecated functions
function arch_array(arch, arr)
Expand Down
10 changes: 5 additions & 5 deletions src/DistributedComputations/distributed_architectures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Oceananigans.Architectures
using Oceananigans.Grids: topology, validate_tupled_argument
using CUDA: ndevices, device!

import Oceananigans.Architectures: device, cpu_architecture, on_architecture, array_type, child_architecture, convert_args
import Oceananigans.Architectures: device, cpu_architecture, on_architecture, array_type, child_architecture, convert_to_device
import Oceananigans.Grids: zeros
import Oceananigans.Utils: sync_device!, tupleit

Expand Down Expand Up @@ -306,10 +306,10 @@ ranks(arch::Distributed) = ranks(arch.partition)
# Accessor for the architecture wrapped by a `Distributed` architecture.
function child_architecture(arch::Distributed)
    return arch.child_architecture
end

# The device of a `Distributed` architecture is the device of its child architecture.
function device(arch::Distributed)
    child = child_architecture(arch)
    return device(child)
end

# Methods for `Distributed` that forward to the same function on the child architecture.
zeros(FT, arch::Distributed, N...) = zeros(FT, child_architecture(arch), N...)
array_type(arch::Distributed) = array_type(child_architecture(arch))
# Reads the `child_architecture` field directly rather than calling the accessor.
sync_device!(arch::Distributed) = sync_device!(arch.child_architecture)
convert_args(arch::Distributed, arg) = convert_args(child_architecture(arch), arg)
# Methods for `Distributed` that forward to the same function on the child architecture.
zeros(FT, arch::Distributed, N...)           = zeros(FT, child_architecture(arch), N...)
array_type(arch::Distributed)                = array_type(child_architecture(arch))
# Reads the `child_architecture` field directly rather than calling the accessor.
sync_device!(arch::Distributed)              = sync_device!(arch.child_architecture)
convert_to_device(arch::Distributed, arg)    = convert_to_device(child_architecture(arch), arg)

# Switch to a synchronized architecture
function synchronized(arch)
    # Generic method: `arch` is handed back as-is.
    return arch
end
Expand Down
11 changes: 4 additions & 7 deletions src/Grids/zeros_and_ones.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
using CUDA
using Oceananigans.Architectures: CPU, GPU, AbstractArchitecture
using Oceananigans.Architectures: device, AbstractArchitecture

import Base: zeros
import KernelAbstractions: zeros

# `CPU`: delegates to `zeros(FT, N...)`.
zeros(FT, ::CPU, N...) = zeros(FT, N...)
# `GPU`: delegates to `CUDA.zeros(FT, N...)`.
zeros(FT, ::GPU, N...) = CUDA.zeros(FT, N...)
# Any `AbstractArchitecture`: delegates to `zeros(device(arch), FT, N...)`.
zeros(FT, arch::AbstractArchitecture, N...) = zeros(device(arch), FT, N...)

# Element type is taken from `eltype(grid)`; the architecture is given explicitly.
zeros(arch::AbstractArchitecture, grid, N...) = zeros(eltype(grid), arch, N...)
# Element type is taken from `eltype(grid)` and the architecture from `architecture(grid)`.
zeros(grid::AbstractGrid, N...) = zeros(eltype(grid), architecture(grid), N...)

# `zero` and `one` of a grid are the `zero` and `one` of its element type.
@inline Base.zero(grid::AbstractGrid) = zero(eltype(grid))
@inline Base.one(grid::AbstractGrid) = one(eltype(grid))
# NOTE(review): identical to the definition on the previous line; this looks like a
# diff artifact (e.g. a trailing-newline change). Confirm whether it is needed.
@inline Base.one(grid::AbstractGrid) = one(eltype(grid))
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export FixedSubstepNumber, FixedTimeStepSize

using Oceananigans
using Oceananigans.Architectures
using Oceananigans.Architectures: convert_args
using Oceananigans.Architectures: convert_to_device
using Oceananigans.Fields
using Oceananigans.Utils
using Oceananigans.Grids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ function iterate_split_explicit!(free_surface, grid, GUⁿ, GVⁿ, Δτᴮ, weig
# launching ~100 very small kernels: we are limited by
# latency of argument conversion to GPU-compatible values.
# To alleviate this penalty we convert first and then we substep!
converted_η_args = convert_args(arch, η_args)
converted_U_args = convert_args(arch, U_args)
converted_η_args = convert_to_device(arch, η_args)
converted_U_args = convert_to_device(arch, U_args)

@unroll for substep in 1:Nsubsteps
Base.@_inline_meta
Expand Down

0 comments on commit ac8cd58

Please sign in to comment.