diff --git a/src/Fields/function_field.jl b/src/Fields/function_field.jl index 41b2c8b5fc..4f2ac384e2 100644 --- a/src/Fields/function_field.jl +++ b/src/Fields/function_field.jl @@ -58,9 +58,9 @@ fieldify_function(L, a::Function, grid) = FunctionField(L, a, grid) Adapt.adapt_structure(to, f::FunctionField{LX, LY, LZ}) where {LX, LY, LZ} = FunctionField{LX, LY, LZ}(Adapt.adapt(to, f.func), - Adapt.adapt(to, f.grid), - clock = Adapt.adapt(to, f.clock), - parameters = Adapt.adapt(to, f.parameters)) + Adapt.adapt(to, f.grid), + clock = Adapt.adapt(to, f.clock), + parameters = Adapt.adapt(to, f.parameters)) on_architecture(to, f::FunctionField{LX, LY, LZ}) where {LX, LY, LZ} = diff --git a/src/Fields/set!.jl b/src/Fields/set!.jl index e311f659f1..d1e43a6daa 100644 --- a/src/Fields/set!.jl +++ b/src/Fields/set!.jl @@ -43,7 +43,7 @@ end ##### Setting to specific things ##### -function set_to_function!(u, f) +function set_to_function!(u::Field, f) # Supports serial and distributed arch = architecture(u) child_arch = child_architecture(u) @@ -53,7 +53,6 @@ function set_to_function!(u, f) cpu_arch = cpu_architecture(arch) cpu_grid = on_architecture(cpu_arch, u.grid) cpu_u = Field(location(u), cpu_grid; indices = indices(u)) - elseif child_arch isa CPU cpu_grid = u.grid cpu_u = u @@ -89,7 +88,7 @@ function set_to_function!(u, f) return u end -function set_to_array!(u, f) +function set_to_array!(u::Field, f) f = on_architecture(architecture(u), f) try @@ -111,7 +110,7 @@ function set_to_array!(u, f) return u end -function set_to_field!(u, v) +function set_to_field!(u::Field, v) # We implement some niceities in here that attempt to copy halo data, # and revert to copying just interior points if that fails. diff --git a/src/OutputReaders/field_time_series.jl b/src/OutputReaders/field_time_series.jl index 19a1bdb2c8..ea3d7d7201 100644 --- a/src/OutputReaders/field_time_series.jl +++ b/src/OutputReaders/field_time_series.jl @@ -23,7 +23,7 @@ using Oceananigans.Utils: launch! import Oceananigans.Architectures: architecture, on_architecture import Oceananigans.BoundaryConditions: fill_halo_regions!, BoundaryCondition, getbc -import Oceananigans.Fields: Field, set!, interior, indices, interpolate! +import Oceananigans.Fields: Field, interior, indices, interpolate! ##### ##### Data backends for FieldTimeSeries diff --git a/src/OutputReaders/field_time_series_indexing.jl b/src/OutputReaders/field_time_series_indexing.jl index d849192f67..e29ab8a404 100644 --- a/src/OutputReaders/field_time_series_indexing.jl +++ b/src/OutputReaders/field_time_series_indexing.jl @@ -283,3 +283,4 @@ function getindex(fts::InMemoryFTS, n::Int) return Field(location(fts), fts.grid; data, fts.boundary_conditions, fts.indices) end + diff --git a/src/OutputReaders/set_field_time_series.jl b/src/OutputReaders/set_field_time_series.jl index 577082782e..91bfe7722a 100644 --- a/src/OutputReaders/set_field_time_series.jl +++ b/src/OutputReaders/set_field_time_series.jl @@ -1,6 +1,70 @@ using Printf using Oceananigans.Architectures: cpu_architecture +import Oceananigans.Fields: set! + +function set!(u::InMemoryFTS, v::InMemoryFTS) + if child_architecture(u) === child_architecture(v) + # Note: we could try to copy first halo point even when halo + # regions are a different size. That's a bit more complicated than + # the below so we leave it for the future. + + try # to copy halo regions along with interior data + parent(u) .= parent(v) + catch # this could fail if the halo regions are different sizes? + # copy just the interior data + interior(u) .= interior(v) + end + else + v_data = on_architecture(child_architecture(u), v.data) + + # As above, we permit ourselves a little ambition and try to copy halo data: + try + parent(u) .= parent(v_data) + catch + interior(u) .= interior(v_data, location(v), v.grid, v.indices) + end + end + + return u +end + +function set!(u::InMemoryFTS, v::Function) + # Supports serial and distributed + arch = architecture(u) + child_arch = child_architecture(u) + LX, LY, LZ = location(u) + + # Determine cpu_grid and cpu_u + if child_arch isa GPU + cpu_arch = cpu_architecture(arch) + cpu_grid = on_architecture(cpu_arch, u.grid) + cpu_times = on_architecture(cpu_arch, u.times) + cpu_u = FieldTimeSeries{LX, LY, LZ}(cpu_grid, cpu_times; indices=indices(u)) + elseif child_arch isa CPU + cpu_arch = child_arch + cpu_grid = u.grid + cpu_times = u.times + cpu_u = u + end + + launch!(cpu_arch, cpu_grid, size(cpu_u), + _set_fts_to_function!, cpu_u, (LX(), LY(), LZ()), cpu_grid, cpu_times, v) + + # Transfer data to GPU if u is on the GPU + child_arch isa GPU && set!(u, cpu_u) + + return u +end + +@kernel function _set_fts_to_function!(fts, loc, grid, times, func) + i, j, k, n = @index(Global, NTuple) + X = node(i, j, k, grid, loc...) + @inbounds begin + fts[i, j, k, n] = func(X..., times[n]) + end +end + ##### ##### set! ##### @@ -44,10 +108,10 @@ function set!(fts::InMemoryFTS, path::String=fts.path, name::String=fts.name) set!(fts[n], field_n) end - return nothing + return fts end -set!(fts::InMemoryFTS, value, n::Int) = set!(fts[n], value) +set!(fts::InMemoryFTS, v, n::Int) = set!(fts[n], value) function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField}) raw_data = parent(fts) @@ -61,7 +125,7 @@ function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField}) close(file) - return nothing + return fts end # Write property only if it does not already exist @@ -90,6 +154,8 @@ function set!(fts::OnDiskFTS, field::Field, n::Int, time=fts.times[n]) maybe_write_property!(file, "timeseries/t/$n", time) maybe_write_property!(file, "timeseries/$name/$n", Array(parent(field))) end + + return fts end function initialize_file!(file, name, fts) @@ -100,4 +166,5 @@ function initialize_file!(file, name, fts) return nothing end -set!(fts::OnDiskFTS, path::String, name::String) = nothing +set!(fts::OnDiskFTS, path::String, name::String) = fts + diff --git a/src/OutputWriters/checkpointer.jl b/src/OutputWriters/checkpointer.jl index d2da594e08..25d93c9925 100644 --- a/src/OutputWriters/checkpointer.jl +++ b/src/OutputWriters/checkpointer.jl @@ -190,7 +190,7 @@ end ##### set! for checkpointer filepaths ##### -set!(model, ::Nothing) = nothing +set!(model::AbstractModel, ::Nothing) = nothing """ set!(model, filepath::AbstractString) @@ -198,7 +198,7 @@ set!(model, ::Nothing) = nothing Set data in `model.velocities`, `model.tracers`, `model.timestepper.Gⁿ`, and `model.timestepper.G⁻` to checkpointed data stored at `filepath`. """ -function set!(model, filepath::AbstractString) +function set!(model::AbstractModel, filepath::AbstractString) jldopen(filepath, "r") do file