diff --git a/Project.toml b/Project.toml index 1932ede..2da2b87 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "TrixiBase" uuid = "9a0f1c46-06d5-4909-a5a3-ce25d3fa3284" authors = ["Michael Schlottke-Lakemper "] -version = "0.1.5-DEV" +version = "0.1.5" [deps] +ChangePrecision = "3cb15238-376d-56a3-8042-d33272777c9a" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] @@ -13,6 +14,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" TrixiBaseMPIExt = "MPI" [compat] +ChangePrecision = "1.1.0" MPI = "0.20" TimerOutputs = "0.5.25" julia = "1.8" diff --git a/src/TrixiBase.jl b/src/TrixiBase.jl index daac03c..d1699fd 100644 --- a/src/TrixiBase.jl +++ b/src/TrixiBase.jl @@ -1,11 +1,12 @@ module TrixiBase +using ChangePrecision: ChangePrecision using TimerOutputs: TimerOutput, TimerOutputs include("trixi_include.jl") include("trixi_timeit.jl") -export trixi_include +export trixi_include, trixi_include_changeprecision export @trixi_timeit, timer, timeit_debug_enabled, disable_debug_timings, enable_debug_timings diff --git a/src/trixi_include.jl b/src/trixi_include.jl index 92e773a..3e12703 100644 --- a/src/trixi_include.jl +++ b/src/trixi_include.jl @@ -3,7 +3,7 @@ # of `TrixiBase`. However, users will want to evaluate in the global scope of `Main` or something # similar to manage dependencies on their own. """ - trixi_include([mod::Module=Main,] elixir::AbstractString; kwargs...) + trixi_include([mapexpr::Function,] [mod::Module=Main,] elixir::AbstractString; kwargs...) `include` the file `elixir` and evaluate its content in the global scope of module `mod`. You can override specific assignments in `elixir` by supplying keyword arguments. @@ -16,6 +16,10 @@ into calls to `solve` with it's default value used in the SciML ecosystem for ODEs, see the "Miscellaneous" section of the [documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/). +The optional first argument `mapexpr` can be used to transform the included code before +it is evaluated: for each parsed expression `expr` in `elixir`, the `include` function +actually evaluates `mapexpr(expr)`. If it is omitted, `mapexpr` defaults to `identity`. + # Examples ```@example @@ -30,7 +34,7 @@ julia> redirect_stdout(devnull) do 0.1 ``` """ -function trixi_include(mod::Module, elixir::AbstractString; kwargs...) +function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString; kwargs...) # Check that all kwargs exist as assignments code = read(elixir, String) expr = Meta.parse("begin \n$code \nend") @@ -45,13 +49,61 @@ function trixi_include(mod::Module, elixir::AbstractString; kwargs...) if !mpi_isparallel(Val{:MPIExt}()) @info "You just called `trixi_include`. Julia may now compile the code, please be patient." end - Base.include(ex -> replace_assignments(insert_maxiters(ex); kwargs...), mod, elixir) + Base.include(ex -> replace_assignments(insert_maxiters(mapexpr(ex)); kwargs...), + mod, elixir) +end + +function trixi_include(mod::Module, elixir::AbstractString; kwargs...) + trixi_include(identity, mod, elixir; kwargs...) end function trixi_include(elixir::AbstractString; kwargs...) trixi_include(Main, elixir; kwargs...) end +""" + trixi_include_changeprecision(T, [mod::Module=Main,] elixir::AbstractString; kwargs...) + +`include` the elixir `elixir` and evaluate its content in the global scope of module `mod`. +You can override specific assignments in `elixir` by supplying keyword arguments, +similar to [`trixi_include`](@ref). + +The only difference to [`trixi_include`](@ref) is that the precision of floating-point +numbers in the included elixir is changed to `T`. +More precisely, the package [ChangePrecision.jl](https://github.com/JuliaMath/ChangePrecision.jl) +is used to convert all `Float64` literals, operations like `/` that produce `Float64` results, +and functions like `ones` that return `Float64` arrays by default, to the desired type `T`. +See the documentation of ChangePrecision.jl for more details. + +The purpose of this function is to conveniently run a full simulation with `Float32`, +which is orders of magnitude faster on most GPUs than `Float64`, by just including +the elixir with `trixi_include_changeprecision(Float32, elixir)`. +Most code in the Trixi framework is written in a way that changing all floating-point +numbers in the elixir to `Float32` manually will run the full simulation with single precision. +""" +function trixi_include_changeprecision(T, mod::Module, filename::AbstractString; kwargs...) + trixi_include(expr -> ChangePrecision.changeprecision(T, replace_trixi_include(T, expr)), + mod, filename; kwargs...) +end + +function trixi_include_changeprecision(T, filename::AbstractString; kwargs...) + trixi_include_changeprecision(T, Main, filename; kwargs...) +end + +function replace_trixi_include(T, expr) + expr = TrixiBase.walkexpr(expr) do x + if x isa Expr + if x.head === :call && x.args[1] === :trixi_include + x.args[1] = :trixi_include_changeprecision + insert!(x.args, 2, :($T)) + end + end + return x + end + + return expr +end + # Insert the keyword argument `maxiters` into calls to `solve` and `Trixi.solve` # with default value `10^5` if it is not already present. function insert_maxiters(expr)