Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bitrounding + Lossless compression #3599

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ weakdeps = ["StaticArrays"]
[deps.Adapt.extensions]
AdaptStaticArraysExt = "StaticArrays"

[[deps.AliasTables]]
deps = ["Random"]
git-tree-sha1 = "82b912bb5215792fd33df26f407d064d3602af98"
uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
version = "1.1.2"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"
Expand All @@ -50,6 +56,24 @@ version = "0.5.0"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.BitInformation]]
deps = ["Distributions", "Random", "StatsBase"]
git-tree-sha1 = "8f98d9d01f50d3a9bf987d7e206c993b390a98bf"
uuid = "de688a37-743e-4ac2-a6f0-bd62414d1aa7"
version = "0.6.1"

[[deps.Blosc_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Lz4_jll", "Zlib_jll", "Zstd_jll"]
git-tree-sha1 = "19b98ee7e3db3b4eff74c5c9c72bf32144e24f10"
uuid = "0b7ba130-8d10-5ba8-a3d6-c5182647fed9"
version = "1.21.5+0"

[[deps.Bzip2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd"
uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0"
version = "1.0.8+1"

[[deps.CEnum]]
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down Expand Up @@ -95,6 +119,12 @@ git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
version = "0.14.1+0"

[[deps.Calculus]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
version = "0.5.1"

[[deps.ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d"
Expand Down Expand Up @@ -203,6 +233,22 @@ version = "0.10.11"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[deps.Distributions]]
deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"]
git-tree-sha1 = "22c595ca4146c07b16bcf9c8bea86f731f7109d2"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.108"

[deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
DistributionsDensityInterfaceExt = "DensityInterface"
DistributionsTestExt = "Test"

[deps.Distributions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
Expand All @@ -214,6 +260,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"

[[deps.DualNumbers]]
deps = ["Calculus", "NaNMath", "SpecialFunctions"]
git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.8"

[[deps.Elliptic]]
git-tree-sha1 = "71c79e77221ab3a29918aaf6db4f217b89138608"
uuid = "b305315f-e792-5b7a-8f41-49f472929428"
Expand Down Expand Up @@ -245,6 +297,18 @@ version = "1.16.3"
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[deps.FillArrays]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.11.0"
weakdeps = ["PDMats", "SparseArrays", "Statistics"]

[deps.FillArrays.extensions]
FillArraysPDMatsExt = "PDMats"
FillArraysSparseArraysExt = "SparseArrays"
FillArraysStatisticsExt = "Statistics"

[[deps.FixedPointNumbers]]
deps = ["Statistics"]
git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172"
Expand Down Expand Up @@ -555,6 +619,12 @@ git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b"
uuid = "e98f9f5b-d649-5603-91fd-7774390e6439"
version = "3.1.0+2"

[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "1.0.2"

[[deps.NetCDF_jll]]
deps = ["Artifacts", "HDF5_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Pkg", "XML2_jll", "Zlib_jll"]
git-tree-sha1 = "072f8371f74c3b9e1b26679de7fbf059d45ea221"
Expand Down Expand Up @@ -652,6 +722,12 @@ git-tree-sha1 = "b437cdb0385ed38312d91d9c00c20f3798b30256"
uuid = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
version = "1.5.1"

[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.9.4"

[[deps.Quaternions]]
deps = ["LinearAlgebra", "Random", "RealDot"]
git-tree-sha1 = "994cc27cdacca10e68feb291673ec3a76aa2fae9"
Expand Down Expand Up @@ -701,6 +777,18 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"

[[deps.Rmath]]
deps = ["Random", "Rmath_jll"]
git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.7.1"

[[deps.Rmath_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da"
uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
version = "0.4.0+0"

[[deps.Rotations]]
deps = ["LinearAlgebra", "Quaternions", "Random", "StaticArrays"]
git-tree-sha1 = "5680a9276685d392c87407df00d57c9924d9f11e"
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.91.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BitInformation = "de688a37-743e-4ac2-a6f0-bd62414d1aa7"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
CubedSphere = "7445602f-e544-4518-8976-18f8e8ae6cdb"
Expand Down Expand Up @@ -44,6 +45,7 @@ OceananigansMakieExt = ["MakieCore", "Makie"]

[compat]
Adapt = "3, 4"
BitInformation = "0.6"
CUDA = "4.1.1, 5"
Crayons = "4"
CubedSphere = "0.1, 0.2"
Expand Down
52 changes: 52 additions & 0 deletions src/OutputWriters/bit_rounding.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using BitInformation

# Wrapper around `keepbits`, the number of mantissa bits to keep when output
# is bit-rounded. The field type is left parametric (`K`), so no particular
# container or number type is enforced here.
struct BitRounder{K}
keepbits :: K
end

# Default number of mantissa bits to keep ("keepbits") for an output variable,
# selected by dispatching on `Val(name)`.
default_bit_rounding(::Val) = 23 # any other name: single precision default
default_bit_rounding(::Union{Val{:u}, Val{:v}, Val{:w}}) = 2
default_bit_rounding(::Val{:T}) = 7
default_bit_rounding(::Val{:S}) = 16 # 12 at the surface, 16 deep ocean
Comment on lines +12 to +13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting. Why is there a difference between T, S? Is this specific to the simulation that this was tested on, or can we be sure this is valid for all simulations, past climates, future climates, idealized simulations at other resolutions, etc?

It seems we need to have default bit rounding for passive tracers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although relatively robust through time and space, this depends on a lot of things, also whether your unit carries some offset around (e.g. Kelvin vs ˚C, density vs density anomaly). So it's tricky to generalise. I suggest to have some reasonable defaults if someone uses bit rounding (default nothing or single precision as you like) but suggest to highlight that this should be checked similar to how I did it here with the bitinformation analysis above.

For global ocean simulations I expect these to be reasonable defaults. I believe for now this is mostly to reduce the filesizes for OMIP simulations

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, I'm just not sure that OMIP is going to be the most common use case, so there's a question about what default is appropriate here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OMIP defaults might belong in the ClimaOcean setup, perhaps

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could set the defaults here as 23 mantissa bits (=Float32 precision, whether you use Float32 or 64) and then lower in ClimaOcean?

# Default keepbits for the output named `:η`.
default_bit_rounding(::Val{:η}) = 6

function BitRounding(outputs = nothing;
user_rounding...)
Comment on lines +16 to +17
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the purpose of figuring out good defaults perhaps we should include model as an input here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then default_bit_rounding can take model as an argument, and dispatch on various things, for example the equation of state (which should know the units of temperature), and perhaps the biogeochemistry model, which may know the units of some important tracers


# Maps each output name (normalized to a `Symbol`) to its number of keepbits.
keepbits = Dict{Symbol, Any}()

# TODO:
# Check that the dimensions of keepbits are
# compatible with outputs if user_rounding
# contains an abstract array (support functions?)

# Without `outputs` only the user-specified roundings are known, so use those
# names; `keys(nothing)` would otherwise throw a MethodError.
names = isnothing(outputs) ? keys(user_rounding) : keys(outputs)

for name in names
    # `Val` cannot wrap a `String`, and `user_rounding` is keyed by `Symbol`.
    key = Symbol(name)
    keepbits[key] = haskey(user_rounding, key) ? user_rounding[key] :
                                                 default_bit_rounding(Val(key))
end

# Construct the `BitRounder` struct. Calling `BitRounding` here would re-enter
# this function with `keepbits` as `outputs` and never terminate.
return BitRounder(keepbits)
end

# Indexing a `BitRounder` by variable name, e.g. `bit_rounder[:u]`, returns a
# `BitRounder` that wraps only the keepbits stored for that variable.
# Dispatch is on the `BitRounder` type (`BitRounding` is a function, not a type)
# and the lookup goes through the `keepbits` field to avoid calling itself.
Base.getindex(bit_rounder::BitRounder, name::Symbol) = BitRounder(bit_rounder.keepbits[name])

# Apply bit rounding to `output_array` using the keepbits stored in
# `bit_rounder`, then hand the same array back to the caller.
function round_data!(output_array, bit_rounder::BitRounder)
    # TODO: make sure that the rounding happens
    # as we expect (priority to the vertical direction!)
    round!(output_array, bit_rounder.keepbits)
    return output_array
end


12 changes: 9 additions & 3 deletions src/OutputWriters/fetch_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,28 @@ function fetch_output(lagrangian_particles::LagrangianParticles, model)
return NamedTuple{names}([getproperty(particle_properties, name) for name in names])
end

convert_output(output, writer) = output
convert_output(output, writer, bit_rounder) = output

function convert_output(output::AbstractArray, writer)
function convert_output(output::AbstractArray, writer, bit_rounder)
if architecture(output) isa GPU
output_array = writer.array_type(undef, size(output)...)
copyto!(output_array, output)
else
output_array = convert(writer.array_type, output)
end

# always happens on the CPU
round_data!(output_array, bit_rounder)

return output_array
end

# Fallback for a `Nothing` bit_rounder (no rounding): the array is left untouched.
# Returns `output_array`, like the `BitRounder` method, so that both methods can be
# used interchangeably by callers that consume the return value.
round_data!(output_array, bit_rounder) = output_array

# Need to broadcast manually because of https://github.com/JuliaLang/julia/issues/30836
convert_output(outputs::NamedTuple, writer) =
NamedTuple(name => convert_output(outputs[name], writer) for name in keys(outputs))
NamedTuple(name => convert_output(outputs[name], writer, writer.bit_rounder[name]) for name in keys(outputs))

function fetch_and_convert_output(output, model, writer)
fetched = fetch_output(output, model)
Expand Down
22 changes: 19 additions & 3 deletions src/OutputWriters/jld2_output_writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ default_included_properties(::NonhydrostaticModel) = [:grid, :coriolis, :buoyanc
default_included_properties(::ShallowWaterModel) = [:grid, :coriolis, :closure]
default_included_properties(::HydrostaticFreeSurfaceModel) = [:grid, :coriolis, :buoyancy, :closure]

mutable struct JLD2OutputWriter{O, T, D, IF, IN, FS, KW} <: AbstractOutputWriter
# Default keyword arguments forwarded to JLD2: compression is switched on.
# Compression currently relies on CodecZlib's ZlibCompressor.
# TODO: update to ZstdCompressor with https://github.com/JuliaIO/JLD2.jl/pull/560 merged
default_jld2_kwargs() = (compress = true,)

mutable struct JLD2OutputWriter{O, T, D, IF, IN, BR, FS, KW} <: AbstractOutputWriter
filepath :: String
outputs :: O
schedule :: T
array_type :: D
init :: IF
including :: IN
part :: Int
bit_rounder :: BR
file_splitting :: FS
overwrite_existing :: Bool
verbose :: Bool
Expand All @@ -33,6 +38,7 @@ ext(::Type{JLD2OutputWriter}) = ".jld2"
with_halos = false,
array_type = Array{Float64},
file_splitting = NoFileSplitting(),
bit_rounder = nothing,
overwrite_existing = false,
init = noinit,
including = [:grid, :coriolis, :buoyancy, :closure],
Expand Down Expand Up @@ -97,6 +103,10 @@ Keyword arguments
- `including`: List of model properties to save with every file.
Default: `[:grid, :coriolis, :buoyancy, :closure]`

## Compressing the output

- `bit_rounder`: Number of mantissa bits to keep (keepbits) per variable and vertical level, used to bit-round the output before it is written. Default: `nothing` (no rounding).

## Miscellaneous keywords

- `verbose`: Log what the output writer is doing with statistics on compute/write times and file sizes.
Expand Down Expand Up @@ -170,10 +180,11 @@ function JLD2OutputWriter(model, outputs; filename, schedule,
file_splitting = NoFileSplitting(),
overwrite_existing = false,
init = noinit,
bit_rounder = nothing,
including = default_included_properties(model),
verbose = false,
part = 1,
jld2_kw = Dict{Symbol, Any}())
jld2_kw = default_jld2_kwargs())

mkpath(dir)
filename = auto_extension(filename, ".jld2")
Expand All @@ -185,12 +196,17 @@ function JLD2OutputWriter(model, outputs; filename, schedule,
outputs = NamedTuple(Symbol(name) => construct_output(outputs[name], model.grid, indices, with_halos)
for name in keys(outputs))

# No rounding for any variable!
if isnothing(bit_rounder)
bit_rounder = Dict{Symbol, Any}(Symbol(name) => nothing for name in keys(outputs))
end

# Convert each output to WindowedTimeAverage if schedule::AveragedTimeWindow is specified
schedule, outputs = time_average_outputs(schedule, outputs, model)

initialize_jld2_file!(filepath, init, jld2_kw, including, outputs, model)

return JLD2OutputWriter(filepath, outputs, schedule, array_type, init,
# Positional arguments must follow the field order of the `JLD2OutputWriter` struct:
# ..., init, including, part, bit_rounder, file_splitting, ...
# Passing `bit_rounder` right after `init` would shift `including` into `part :: Int`.
return JLD2OutputWriter(filepath, outputs, schedule, array_type, init, including, part,
bit_rounder, file_splitting, overwrite_existing, verbose, jld2_kw)
end

Expand Down
Loading