Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add capability to set FieldTimeSeries to a function #3932

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Fields/function_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ fieldify_function(L, a::Function, grid) = FunctionField(L, a, grid)

Adapt.adapt_structure(to, f::FunctionField{LX, LY, LZ}) where {LX, LY, LZ} =
FunctionField{LX, LY, LZ}(Adapt.adapt(to, f.func),
Adapt.adapt(to, f.grid),
clock = Adapt.adapt(to, f.clock),
parameters = Adapt.adapt(to, f.parameters))
Adapt.adapt(to, f.grid),
clock = Adapt.adapt(to, f.clock),
parameters = Adapt.adapt(to, f.parameters))


on_architecture(to, f::FunctionField{LX, LY, LZ}) where {LX, LY, LZ} =
Expand Down
7 changes: 3 additions & 4 deletions src/Fields/set!.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
##### Setting to specific things
#####

function set_to_function!(u, f)
function set_to_function!(u::Field, f)
# Supports serial and distributed
arch = architecture(u)
child_arch = child_architecture(u)
Expand All @@ -53,7 +53,6 @@ function set_to_function!(u, f)
cpu_arch = cpu_architecture(arch)
cpu_grid = on_architecture(cpu_arch, u.grid)
cpu_u = Field(location(u), cpu_grid; indices = indices(u))

elseif child_arch isa CPU
cpu_grid = u.grid
cpu_u = u
Expand Down Expand Up @@ -89,7 +88,7 @@ function set_to_function!(u, f)
return u
end

function set_to_array!(u, f)
function set_to_array!(u::Field, f)
f = on_architecture(architecture(u), f)

try
Expand All @@ -111,7 +110,7 @@ function set_to_array!(u, f)
return u
end

function set_to_field!(u, v)
function set_to_field!(u::Field, v)
# We implement some niceities in here that attempt to copy halo data,
# and revert to copying just interior points if that fails.

Expand Down
2 changes: 1 addition & 1 deletion src/OutputReaders/field_time_series.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using Oceananigans.Utils: launch!

import Oceananigans.Architectures: architecture, on_architecture
import Oceananigans.BoundaryConditions: fill_halo_regions!, BoundaryCondition, getbc
import Oceananigans.Fields: Field, set!, interior, indices, interpolate!
import Oceananigans.Fields: Field, interior, indices, interpolate!

#####
##### Data backends for FieldTimeSeries
Expand Down
1 change: 1 addition & 0 deletions src/OutputReaders/field_time_series_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,4 @@ function getindex(fts::InMemoryFTS, n::Int)

return Field(location(fts), fts.grid; data, fts.boundary_conditions, fts.indices)
end

75 changes: 71 additions & 4 deletions src/OutputReaders/set_field_time_series.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,70 @@
using Printf
using Oceananigans.Architectures: cpu_architecture

import Oceananigans.Fields: set!

function set!(u::InMemoryFTS, v::InMemoryFTS)
if child_architecture(u) === child_architecture(v)
# Note: we could try to copy first halo point even when halo
# regions are a different size. That's a bit more complicated than
# the below so we leave it for the future.

try # to copy halo regions along with interior data
parent(u) .= parent(v)
catch # this could fail if the halo regions are different sizes?
# copy just the interior data
interior(u) .= interior(v)
end
else
v_data = on_architecture(child_architecture(u), v.data)

# As above, we permit ourselves a little ambition and try to copy halo data:
try
parent(u) .= parent(v_data)
catch
interior(u) .= interior(v_data, location(v), v.grid, v.indices)
end
end

return u
end

function set!(u::InMemoryFTS, v::Function)
# Supports serial and distributed
arch = architecture(u)
child_arch = child_architecture(u)
LX, LY, LZ = location(u)

# Determine cpu_grid and cpu_u
if child_arch isa GPU
cpu_arch = cpu_architecture(arch)
cpu_grid = on_architecture(cpu_arch, u.grid)
cpu_times = on_architecture(cpu_arch, u.times)
cpu_u = FieldTimeSeries{LX, LY, LZ}(cpu_grid, cpu_times; indices=indices(u))
elseif child_arch isa CPU
cpu_arch = child_arch
cpu_grid = u.grid
cpu_times = u.times
cpu_u = u
end

launch!(cpu_arch, cpu_grid, size(cpu_u),
_set_fts_to_function!, cpu_u, (LX(), LY(), LZ()), cpu_grid, cpu_times, v)

# Transfer data to GPU if u is on the GPU
child_arch isa GPU && set!(u, cpu_u)

return u
end

@kernel function _set_fts_to_function!(fts, loc, grid, times, func)
i, j, k, n = @index(Global, NTuple)
X = node(i, j, k, grid, loc...)
@inbounds begin
fts[i, j, k, n] = func(X..., times[n])
end
end

#####
##### set!
#####
Expand Down Expand Up @@ -44,10 +108,10 @@ function set!(fts::InMemoryFTS, path::String=fts.path, name::String=fts.name)
set!(fts[n], field_n)
end

return nothing
return fts
end

set!(fts::InMemoryFTS, value, n::Int) = set!(fts[n], value)
set!(fts::InMemoryFTS, v, n::Int) = set!(fts[n], value)

function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField})
raw_data = parent(fts)
Expand All @@ -61,7 +125,7 @@ function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField})

close(file)

return nothing
return fts
end

# Write property only if it does not already exist
Expand Down Expand Up @@ -90,6 +154,8 @@ function set!(fts::OnDiskFTS, field::Field, n::Int, time=fts.times[n])
maybe_write_property!(file, "timeseries/t/$n", time)
maybe_write_property!(file, "timeseries/$name/$n", Array(parent(field)))
end

return fts
end

function initialize_file!(file, name, fts)
Expand All @@ -100,4 +166,5 @@ function initialize_file!(file, name, fts)
return nothing
end

set!(fts::OnDiskFTS, path::String, name::String) = nothing
set!(fts::OnDiskFTS, path::String, name::String) = fts

4 changes: 2 additions & 2 deletions src/OutputWriters/checkpointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ end
##### set! for checkpointer filepaths
#####

set!(model, ::Nothing) = nothing
set!(model::AbstractModel, ::Nothing) = nothing

"""
set!(model, filepath::AbstractString)

Set data in `model.velocities`, `model.tracers`, `model.timestepper.Gⁿ`, and
`model.timestepper.G⁻` to checkpointed data stored at `filepath`.
"""
function set!(model, filepath::AbstractString)
function set!(model::AbstractModel, filepath::AbstractString)

jldopen(filepath, "r") do file

Expand Down