Skip to content

Commit

Permalink
Merge pull request #105 from sintefmath/battmo
Browse files Browse the repository at this point in the history
Refactor linear solvers to support battery preconditioners
  • Loading branch information
moyner authored Nov 20, 2024
2 parents 1bfe36b + 2094596 commit 7740a3d
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 47 deletions.
3 changes: 2 additions & 1 deletion ext/JutulPartitionedArraysExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ function Jutul.parray_update_preconditioners!(sim, preconditioner_base, precondi
model = Jutul.get_simulator_model(sim)
storage = Jutul.get_simulator_storage(sim)
sys = storage.LinearizedSystem
Jutul.update_preconditioner!(prec, sys, model, storage, recorder, sim.executor)
ctx = model.context
Jutul.update_preconditioner!(prec, sys, ctx, model, storage, recorder, sim.executor)
prec
end
return (preconditioner_base, preconditioners)
Expand Down
12 changes: 11 additions & 1 deletion src/core_types/core_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,17 @@ struct MultiModel{label, T, CT, G, C, GL} <: AbstractMultiModel{label}
group_lookup::GL
end

function MultiModel(models, label::Union{Nothing, Symbol} = nothing; cross_terms = Vector{CrossTermPair}(), groups = nothing, context = nothing, reduction = missing, specialize = false, specialize_ad = false)
function MultiModel(models, label::Union{Nothing, Symbol} = nothing;
cross_terms = Vector{CrossTermPair}(),
groups = nothing,
context = nothing,
reduction = missing,
specialize = false,
specialize_ad = false
)
if isnothing(context)
context = models[first(keys(models))].context
end
group_lookup = Dict{Symbol, Int}()
if isnothing(groups)
num_groups = 1
Expand Down
54 changes: 31 additions & 23 deletions src/linsolve/krylov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,43 +56,41 @@ function Base.show(io::IO, krylov::GenericKrylov)
print(io, "Generic Krylov using $(krylov.solver) (ϵₐ=$atol, ϵ=$rtol) with prec = $(typeof(krylov.preconditioner))")
end

function preconditioner(krylov::AbstractKrylov, sys, model, storage, recorder, float_t = Float64)
function preconditioner(krylov::AbstractKrylov, sys, context, model, storage, recorder)
M = krylov.preconditioner
Ft = float_type(context)
if isnothing(M)
op = I
else
op = PrecondWrapper(linear_operator(M, float_t))
linop = linear_operator(M, Ft, sys, context, model, storage, recorder)
op = PrecondWrapper(linop)
end
return op
end

# export update!
# function update_preconditioner!(prec, sys, model, storage, recorder)
# update!(prec, sys, model, storage, recorder)
# end

function linear_solve!(sys::LSystem,
krylov::GenericKrylov,
model,
storage = nothing,
dt = nothing,
recorder = ProgressRecorder(),
executor = default_executor();
dx = sys.dx_buffer,
r = vector_residual(sys),
atol = linear_solver_tolerance(krylov, :absolute),
rtol = linear_solver_tolerance(krylov, :relative),
rtol_nl = linear_solver_tolerance(krylov, :nonlinear_relative),
rtol_relaxed = linear_solver_tolerance(krylov, :relaxed_relative)
)
krylov::GenericKrylov,
context::JutulContext,
model::JutulModel,
storage = nothing,
dt = nothing,
recorder = ProgressRecorder(),
executor = default_executor();
dx = sys.dx_buffer,
r = vector_residual(sys),
atol = linear_solver_tolerance(krylov, :absolute),
rtol = linear_solver_tolerance(krylov, :relative),
rtol_nl = linear_solver_tolerance(krylov, :nonlinear_relative),
rtol_relaxed = linear_solver_tolerance(krylov, :relaxed_relative)
)
cfg = krylov.config
prec = krylov.preconditioner
Ft = float_type(model.context)
Ft = float_type(context)
sys = krylov_scale_system!(sys, krylov, dt)
t_prep = @elapsed @tic "prepare" prepare_linear_solve!(sys)
op = linear_operator(sys)
t_prec = @elapsed @tic "precond" update_preconditioner!(prec, sys, model, storage, recorder, executor)
prec_op = preconditioner(krylov, sys, model, storage, recorder, Ft)
t_prec = @elapsed @tic "precond" update_preconditioner!(prec, sys, context, model, storage, recorder, executor)
prec_op = preconditioner(krylov, sys, context, model, storage, recorder)
v = Int64(cfg.verbose)
max_it = cfg.max_iterations
min_it = cfg.min_iterations
Expand Down Expand Up @@ -183,6 +181,16 @@ function linear_solve!(sys::LSystem,
return linear_solve_return(solved, n, stats, precond = t_prec_op, precond_count = t_prec_count, prepare = t_prec + t_prep)
end

function linear_solve!(
sys::LSystem,
krylov::GenericKrylov,
model::JutulModel,
arg...;
kwarg...
)
return linear_solve!(sys, krylov, model.context, model, arg...; kwarg...)
end

function krylov_scale_system!(sys, krylov::GenericKrylov, dt)
sys, krylov.storage_scaling = apply_scaling_to_linearized_system!(sys, krylov.storage_scaling, krylov.scaling, dt)
return sys
Expand Down
11 changes: 7 additions & 4 deletions src/linsolve/precond/ilu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
ILU(0) preconditioner on CPU
"""
Expand Down Expand Up @@ -35,17 +34,21 @@ function update_preconditioner!(ilu::ILUZeroPreconditioner, A, b, context, execu
end
end

function update_preconditioner!(ilu::ILUZeroPreconditioner, A::StaticSparsityMatrixCSR, b, context::ParallelCSRContext, executor)
function update_preconditioner!(ilu::ILUZeroPreconditioner, A::StaticSparsityMatrixCSR, b, context, executor)
if isnothing(ilu.factor)
mb = A.minbatch
max_t = max(size(A, 1) ÷ mb, 1)
nt = min(nthreads(context), max_t)
nt = min(A.nthreads, max_t)
if nt == 1
@debug "Setting up serial ILU(0)-CSR"
F = ilu0_csr(A)
else
@debug "Setting up parallel ILU(0)-CSR with $(nthreads(td)) threads"
part = context.partitioner
if context isa ParallelCSRContext
part = context.partitioner
else
part = MetisPartitioner()
end
lookup = generate_lookup(part, A, nt)
F = ilu0_csr(A, lookup)
end
Expand Down
23 changes: 10 additions & 13 deletions src/linsolve/precond/utils.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@

function update_preconditioner!(preconditioner::Nothing, arg...)
# Do nothing.
end
function update_preconditioner!(preconditioner, lsys, model, storage, recorder, executor)
function update_preconditioner!(preconditioner::JutulPreconditioner, lsys::JutulLinearSystem, context, model, storage, recorder, executor)
J = jacobian(lsys)
r = residual(lsys)
ctx = linear_system_context(model, lsys)
update_preconditioner!(preconditioner, J, r, ctx, executor)
update_preconditioner!(preconditioner, J, r, context, executor)
end

function partial_update_preconditioner!(p, A, b, context, executor)
Expand All @@ -20,7 +18,11 @@ end
is_left_preconditioner(::JutulPreconditioner) = true
is_right_preconditioner(::JutulPreconditioner) = false

function linear_operator(precond::JutulPreconditioner, float_t = Float64)
function linear_operator(precond::JutulPreconditioner)
return linear_operator(precond, Float64, nothing, nothing, nothing, nothing, nothing)
end

function linear_operator(precond::JutulPreconditioner, float_t, sys, context, model, storage, recorder)
n = operator_nrows(precond)
function precond_apply!(res, x, α, β::T) where T
if β == zero(T)
Expand All @@ -36,13 +38,8 @@ function linear_operator(precond::JutulPreconditioner, float_t = Float64)
return op
end

function apply!(x, p::JutulPreconditioner, y, arg...)
#nead to be spesilized on type not all JutulPreconditioners has get_factor
function apply!(x, p::JutulPreconditioner, y)
factor = get_factorization(p)
if is_left_preconditioner(p)
ldiv!(x, factor, y)
elseif is_right_preconditioner(p)
error("Not supported.")
else
error("Neither left or right preconditioner?")
end
ldiv!(x, factor, y)
end
12 changes: 11 additions & 1 deletion src/linsolve/precond/various.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,29 @@ function operator_nrows(lup::LUPreconditioner)
return size(f.L, 1)
end

#function operator_nrows(prec::TrivialPreconditioner)
# return prec.dim[1]
#end
# LU factor as precond for wells?

"""
Trivial / identity preconditioner with size for use in subsystems.
"""
function apply!(x,tp::TrivialPreconditioner,r, args...)
x = copy(r)
end


# Trivial precond
function update_preconditioner!(tp::TrivialPreconditioner, lsys, model, storage, recorder, executor)
A = jacobian(lsys)
b = residual(lsys)
tp.dim = size(A).*length(b[1])
end

function linear_operator(id::TrivialPreconditioner, ::Symbol)
export linear_operator

function linear_operator(id::TrivialPreconditioner, args...)
return opEye(id.dim...)
end

Expand Down
6 changes: 3 additions & 3 deletions src/linsolve/scalar_cpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ mutable struct LUSolver
end
end

function linear_solve!(sys, solver::LUSolver, arg...; kwarg...)
function linear_solve!(sys, solver::LUSolver, arg...;dx = sys.dx, r = sys.r, kwargs...)
if length(sys.dx) > solver.max_size
error("System too big for LU solver. You can increase max_size at your own peril.")
end
J = sys.jac
r = sys.r
#r = sys.r
if !solver.reuse_memory
F = lu(J)
else
Expand All @@ -32,7 +32,7 @@ function linear_solve!(sys, solver::LUSolver, arg...; kwarg...)
F = solver.F
end

sys.dx .= -(F\r)
dx .= -(F\r)
@assert all(isfinite, sys.dx) "Linear solve resulted in non-finite values."
return linear_solve_return()
end
3 changes: 2 additions & 1 deletion src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,8 +884,9 @@ end

function solve_and_update!(storage, model::JutulModel, dt = nothing; linear_solver = nothing, recorder = nothing, executor = default_executor(), kwarg...)
lsys = storage.LinearizedSystem
context = model.context
t_solve = @elapsed begin
@tic "linear solve" (ok, n, history) = linear_solve!(lsys, linear_solver, model, storage, dt, recorder, executor)
@tic "linear solve" (ok, n, history) = linear_solve!(lsys, linear_solver, context, model, storage, dt, recorder, executor)
end
t_update = @elapsed @tic "primary variables" update = update_primary_variables!(storage, model; kwarg...)
return (t_solve, t_update, n, history, update)
Expand Down
5 changes: 5 additions & 0 deletions src/multimodel/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ function setup_storage!(storage, model::MultiModel;
return storage
end

mutable struct MutableWrapper
maps
end

function setup_multimodel_maps!(storage, model)
groups = model.groups
if isnothing(groups)
Expand All @@ -138,6 +142,7 @@ function setup_multimodel_maps!(storage, model)
offset_map = map(g -> get_submodel_offsets(model, g), unique(groups))
end
storage[:multi_model_maps] = (offset_map = offset_map, );
storage[:eq_maps] = MutableWrapper(nothing)
end

function setup_equations_and_primary_variable_views!(storage, model::MultiModel, lsys)
Expand Down

0 comments on commit 7740a3d

Please sign in to comment.