Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Battmo #105

Merged
merged 17 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ext/JutulPartitionedArraysExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ function Jutul.parray_update_preconditioners!(sim, preconditioner_base, precondi
model = Jutul.get_simulator_model(sim)
storage = Jutul.get_simulator_storage(sim)
sys = storage.LinearizedSystem
Jutul.update_preconditioner!(prec, sys, model, storage, recorder, sim.executor)
ctx = model.context
Jutul.update_preconditioner!(prec, sys, ctx, model, storage, recorder, sim.executor)
prec
end
return (preconditioner_base, preconditioners)
Expand Down
12 changes: 11 additions & 1 deletion src/core_types/core_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,17 @@ struct MultiModel{label, T, CT, G, C, GL} <: AbstractMultiModel{label}
group_lookup::GL
end

function MultiModel(models, label::Union{Nothing, Symbol} = nothing; cross_terms = Vector{CrossTermPair}(), groups = nothing, context = nothing, reduction = missing, specialize = false, specialize_ad = false)
function MultiModel(models, label::Union{Nothing, Symbol} = nothing;
cross_terms = Vector{CrossTermPair}(),
groups = nothing,
context = nothing,
reduction = missing,
specialize = false,
specialize_ad = false
)
if isnothing(context)
context = models[first(keys(models))].context
end
group_lookup = Dict{Symbol, Int}()
if isnothing(groups)
num_groups = 1
Expand Down
54 changes: 31 additions & 23 deletions src/linsolve/krylov.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,43 +56,41 @@ function Base.show(io::IO, krylov::GenericKrylov)
print(io, "Generic Krylov using $(krylov.solver) (ϵₐ=$atol, ϵ=$rtol) with prec = $(typeof(krylov.preconditioner))")
end

function preconditioner(krylov::AbstractKrylov, sys, model, storage, recorder, float_t = Float64)
function preconditioner(krylov::AbstractKrylov, sys, context, model, storage, recorder)
M = krylov.preconditioner
Ft = float_type(context)
if isnothing(M)
op = I
else
op = PrecondWrapper(linear_operator(M, float_t))
linop = linear_operator(M, Ft, sys, context, model, storage, recorder)
op = PrecondWrapper(linop)
end
return op
end

# export update!
# function update_preconditioner!(prec, sys, model, storage, recorder)
# update!(prec, sys, model, storage, recorder)
# end

function linear_solve!(sys::LSystem,
krylov::GenericKrylov,
model,
storage = nothing,
dt = nothing,
recorder = ProgressRecorder(),
executor = default_executor();
dx = sys.dx_buffer,
r = vector_residual(sys),
atol = linear_solver_tolerance(krylov, :absolute),
rtol = linear_solver_tolerance(krylov, :relative),
rtol_nl = linear_solver_tolerance(krylov, :nonlinear_relative),
rtol_relaxed = linear_solver_tolerance(krylov, :relaxed_relative)
)
krylov::GenericKrylov,
context::JutulContext,
model::JutulModel,
storage = nothing,
dt = nothing,
recorder = ProgressRecorder(),
executor = default_executor();
dx = sys.dx_buffer,
r = vector_residual(sys),
atol = linear_solver_tolerance(krylov, :absolute),
rtol = linear_solver_tolerance(krylov, :relative),
rtol_nl = linear_solver_tolerance(krylov, :nonlinear_relative),
rtol_relaxed = linear_solver_tolerance(krylov, :relaxed_relative)
)
cfg = krylov.config
prec = krylov.preconditioner
Ft = float_type(model.context)
Ft = float_type(context)
sys = krylov_scale_system!(sys, krylov, dt)
t_prep = @elapsed @tic "prepare" prepare_linear_solve!(sys)
op = linear_operator(sys)
t_prec = @elapsed @tic "precond" update_preconditioner!(prec, sys, model, storage, recorder, executor)
prec_op = preconditioner(krylov, sys, model, storage, recorder, Ft)
t_prec = @elapsed @tic "precond" update_preconditioner!(prec, sys, context, model, storage, recorder, executor)
prec_op = preconditioner(krylov, sys, context, model, storage, recorder)
v = Int64(cfg.verbose)
max_it = cfg.max_iterations
min_it = cfg.min_iterations
Expand Down Expand Up @@ -183,6 +181,16 @@ function linear_solve!(sys::LSystem,
return linear_solve_return(solved, n, stats, precond = t_prec_op, precond_count = t_prec_count, prepare = t_prec + t_prep)
end

function linear_solve!(
sys::LSystem,
krylov::GenericKrylov,
model::JutulModel,
arg...;
kwarg...
)
return linear_solve!(sys, krylov, model.context, model, arg...; kwarg...)
end

function krylov_scale_system!(sys, krylov::GenericKrylov, dt)
sys, krylov.storage_scaling = apply_scaling_to_linearized_system!(sys, krylov.storage_scaling, krylov.scaling, dt)
return sys
Expand Down
11 changes: 7 additions & 4 deletions src/linsolve/precond/ilu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
ILU(0) preconditioner on CPU
"""
Expand Down Expand Up @@ -35,17 +34,21 @@ function update_preconditioner!(ilu::ILUZeroPreconditioner, A, b, context, execu
end
end

function update_preconditioner!(ilu::ILUZeroPreconditioner, A::StaticSparsityMatrixCSR, b, context::ParallelCSRContext, executor)
function update_preconditioner!(ilu::ILUZeroPreconditioner, A::StaticSparsityMatrixCSR, b, context, executor)
if isnothing(ilu.factor)
mb = A.minbatch
max_t = max(size(A, 1) ÷ mb, 1)
nt = min(nthreads(context), max_t)
nt = min(A.nthreads, max_t)
if nt == 1
@debug "Setting up serial ILU(0)-CSR"
F = ilu0_csr(A)
else
@debug "Setting up parallel ILU(0)-CSR with $(nthreads(td)) threads"
part = context.partitioner
if context isa ParallelCSRContext
part = context.partitioner
else
part = MetisPartitioner()
end
lookup = generate_lookup(part, A, nt)
F = ilu0_csr(A, lookup)
end
Expand Down
23 changes: 10 additions & 13 deletions src/linsolve/precond/utils.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@

function update_preconditioner!(preconditioner::Nothing, arg...)
# Do nothing.
end
function update_preconditioner!(preconditioner, lsys, model, storage, recorder, executor)
function update_preconditioner!(preconditioner::JutulPreconditioner, lsys::JutulLinearSystem, context, model, storage, recorder, executor)
J = jacobian(lsys)
r = residual(lsys)
ctx = linear_system_context(model, lsys)
update_preconditioner!(preconditioner, J, r, ctx, executor)
update_preconditioner!(preconditioner, J, r, context, executor)
end

function partial_update_preconditioner!(p, A, b, context, executor)
Expand All @@ -20,7 +18,11 @@ end
is_left_preconditioner(::JutulPreconditioner) = true
is_right_preconditioner(::JutulPreconditioner) = false

function linear_operator(precond::JutulPreconditioner, float_t = Float64)
function linear_operator(precond::JutulPreconditioner)
return linear_operator(precond, Float64, nothing, nothing, nothing, nothing, nothing)
end

function linear_operator(precond::JutulPreconditioner, float_t, sys, context, model, storage, recorder)
n = operator_nrows(precond)
function precond_apply!(res, x, α, β::T) where T
if β == zero(T)
Expand All @@ -36,13 +38,8 @@ function linear_operator(precond::JutulPreconditioner, float_t = Float64)
return op
end

function apply!(x, p::JutulPreconditioner, y, arg...)
#nead to be spesilized on type not all JutulPreconditioners has get_factor
function apply!(x, p::JutulPreconditioner, y)
factor = get_factorization(p)
if is_left_preconditioner(p)
ldiv!(x, factor, y)
elseif is_right_preconditioner(p)
error("Not supported.")
else
error("Neither left or right preconditioner?")
end
ldiv!(x, factor, y)
end
12 changes: 11 additions & 1 deletion src/linsolve/precond/various.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,29 @@ function operator_nrows(lup::LUPreconditioner)
return size(f.L, 1)
end

#function operator_nrows(prec::TrivialPreconditioner)
# return prec.dim[1]
#end
# LU factor as precond for wells?

"""
Trivial / identity preconditioner with size for use in subsystems.
"""
function apply!(x,tp::TrivialPreconditioner,r, args...)
x = copy(r)
end


# Trivial precond
function update_preconditioner!(tp::TrivialPreconditioner, lsys, model, storage, recorder, executor)
A = jacobian(lsys)
b = residual(lsys)
tp.dim = size(A).*length(b[1])
end

function linear_operator(id::TrivialPreconditioner, ::Symbol)
export linear_operator

function linear_operator(id::TrivialPreconditioner, args...)
return opEye(id.dim...)
end

Expand Down
6 changes: 3 additions & 3 deletions src/linsolve/scalar_cpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ mutable struct LUSolver
end
end

function linear_solve!(sys, solver::LUSolver, arg...; kwarg...)
function linear_solve!(sys, solver::LUSolver, arg...;dx = sys.dx, r = sys.r, kwargs...)
if length(sys.dx) > solver.max_size
error("System too big for LU solver. You can increase max_size at your own peril.")
end
J = sys.jac
r = sys.r
#r = sys.r
if !solver.reuse_memory
F = lu(J)
else
Expand All @@ -32,7 +32,7 @@ function linear_solve!(sys, solver::LUSolver, arg...; kwarg...)
F = solver.F
end

sys.dx .= -(F\r)
dx .= -(F\r)
@assert all(isfinite, sys.dx) "Linear solve resulted in non-finite values."
return linear_solve_return()
end
3 changes: 2 additions & 1 deletion src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,8 +884,9 @@ end

function solve_and_update!(storage, model::JutulModel, dt = nothing; linear_solver = nothing, recorder = nothing, executor = default_executor(), kwarg...)
lsys = storage.LinearizedSystem
context = model.context
t_solve = @elapsed begin
@tic "linear solve" (ok, n, history) = linear_solve!(lsys, linear_solver, model, storage, dt, recorder, executor)
@tic "linear solve" (ok, n, history) = linear_solve!(lsys, linear_solver, context, model, storage, dt, recorder, executor)
end
t_update = @elapsed @tic "primary variables" update = update_primary_variables!(storage, model; kwarg...)
return (t_solve, t_update, n, history, update)
Expand Down
5 changes: 5 additions & 0 deletions src/multimodel/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ function setup_storage!(storage, model::MultiModel;
return storage
end

mutable struct MutableWrapper
maps
end

function setup_multimodel_maps!(storage, model)
groups = model.groups
if isnothing(groups)
Expand All @@ -138,6 +142,7 @@ function setup_multimodel_maps!(storage, model)
offset_map = map(g -> get_submodel_offsets(model, g), unique(groups))
end
storage[:multi_model_maps] = (offset_map = offset_map, );
storage[:eq_maps] = MutableWrapper(nothing)
end

function setup_equations_and_primary_variable_views!(storage, model::MultiModel, lsys)
Expand Down
Loading