diff --git a/ext/JutulPartitionedArraysExt/linalg.jl b/ext/JutulPartitionedArraysExt/linalg.jl index 39dccffc..827c4b6a 100644 --- a/ext/JutulPartitionedArraysExt/linalg.jl +++ b/ext/JutulPartitionedArraysExt/linalg.jl @@ -67,7 +67,8 @@ function Jutul.parray_update_preconditioners!(sim, preconditioner_base, precondi model = Jutul.get_simulator_model(sim) storage = Jutul.get_simulator_storage(sim) sys = storage.LinearizedSystem - Jutul.update_preconditioner!(prec, sys, model, storage, recorder, sim.executor) + ctx = model.context + Jutul.update_preconditioner!(prec, sys, ctx, model, storage, recorder, sim.executor) prec end return (preconditioner_base, preconditioners) diff --git a/src/core_types/core_types.jl b/src/core_types/core_types.jl index 9e5d5996..6173147b 100644 --- a/src/core_types/core_types.jl +++ b/src/core_types/core_types.jl @@ -965,7 +965,17 @@ struct MultiModel{label, T, CT, G, C, GL} <: AbstractMultiModel{label} group_lookup::GL end -function MultiModel(models, label::Union{Nothing, Symbol} = nothing; cross_terms = Vector{CrossTermPair}(), groups = nothing, context = nothing, reduction = missing, specialize = false, specialize_ad = false) +function MultiModel(models, label::Union{Nothing, Symbol} = nothing; + cross_terms = Vector{CrossTermPair}(), + groups = nothing, + context = nothing, + reduction = missing, + specialize = false, + specialize_ad = false + ) + if isnothing(context) + context = models[first(keys(models))].context + end group_lookup = Dict{Symbol, Int}() if isnothing(groups) num_groups = 1 diff --git a/src/linsolve/krylov.jl b/src/linsolve/krylov.jl index 6bbce698..4fca8ead 100644 --- a/src/linsolve/krylov.jl +++ b/src/linsolve/krylov.jl @@ -56,43 +56,41 @@ function Base.show(io::IO, krylov::GenericKrylov) print(io, "Generic Krylov using $(krylov.solver) (ϵₐ=$atol, ϵ=$rtol) with prec = $(typeof(krylov.preconditioner))") end -function preconditioner(krylov::AbstractKrylov, sys, model, storage, recorder, float_t = Float64) +function preconditioner(krylov::AbstractKrylov, sys, context, model, storage, recorder) M = krylov.preconditioner + Ft = float_type(context) if isnothing(M) op = I else - op = PrecondWrapper(linear_operator(M, float_t)) + linop = linear_operator(M, Ft, sys, context, model, storage, recorder) + op = PrecondWrapper(linop) end return op end -# export update! -# function update_preconditioner!(prec, sys, model, storage, recorder) -# update!(prec, sys, model, storage, recorder) -# end - function linear_solve!(sys::LSystem, - krylov::GenericKrylov, - model, - storage = nothing, - dt = nothing, - recorder = ProgressRecorder(), - executor = default_executor(); - dx = sys.dx_buffer, - r = vector_residual(sys), - atol = linear_solver_tolerance(krylov, :absolute), - rtol = linear_solver_tolerance(krylov, :relative), - rtol_nl = linear_solver_tolerance(krylov, :nonlinear_relative), - rtol_relaxed = linear_solver_tolerance(krylov, :relaxed_relative) - ) + krylov::GenericKrylov, + context::JutulContext, + model::JutulModel, + storage = nothing, + dt = nothing, + recorder = ProgressRecorder(), + executor = default_executor(); + dx = sys.dx_buffer, + r = vector_residual(sys), + atol = linear_solver_tolerance(krylov, :absolute), + rtol = linear_solver_tolerance(krylov, :relative), + rtol_nl = linear_solver_tolerance(krylov, :nonlinear_relative), + rtol_relaxed = linear_solver_tolerance(krylov, :relaxed_relative) + ) cfg = krylov.config prec = krylov.preconditioner - Ft = float_type(model.context) + Ft = float_type(context) sys = krylov_scale_system!(sys, krylov, dt) t_prep = @elapsed @tic "prepare" prepare_linear_solve!(sys) op = linear_operator(sys) - t_prec = @elapsed @tic "precond" update_preconditioner!(prec, sys, model, storage, recorder, executor) - prec_op = preconditioner(krylov, sys, model, storage, recorder, Ft) + t_prec = @elapsed @tic "precond" update_preconditioner!(prec, sys, context, model, storage, recorder, executor) + prec_op = preconditioner(krylov, sys, context, model, storage, recorder) v = Int64(cfg.verbose) max_it = cfg.max_iterations min_it = cfg.min_iterations @@ -183,6 +181,16 @@ function linear_solve!(sys::LSystem, return linear_solve_return(solved, n, stats, precond = t_prec_op, precond_count = t_prec_count, prepare = t_prec + t_prep) end +function linear_solve!( + sys::LSystem, + krylov::GenericKrylov, + model::JutulModel, + arg...; + kwarg... + ) + return linear_solve!(sys, krylov, model.context, model, arg...; kwarg...) +end + function krylov_scale_system!(sys, krylov::GenericKrylov, dt) sys, krylov.storage_scaling = apply_scaling_to_linearized_system!(sys, krylov.storage_scaling, krylov.scaling, dt) return sys diff --git a/src/linsolve/precond/ilu.jl b/src/linsolve/precond/ilu.jl index 6eb2ff0a..9757f823 100644 --- a/src/linsolve/precond/ilu.jl +++ b/src/linsolve/precond/ilu.jl @@ -1,4 +1,3 @@ - """ ILU(0) preconditioner on CPU """ @@ -35,17 +34,21 @@ function update_preconditioner!(ilu::ILUZeroPreconditioner, A, b, context, execu end end -function update_preconditioner!(ilu::ILUZeroPreconditioner, A::StaticSparsityMatrixCSR, b, context::ParallelCSRContext, executor) +function update_preconditioner!(ilu::ILUZeroPreconditioner, A::StaticSparsityMatrixCSR, b, context, executor) if isnothing(ilu.factor) mb = A.minbatch max_t = max(size(A, 1) ÷ mb, 1) - nt = min(nthreads(context), max_t) + nt = min(A.nthreads, max_t) if nt == 1 @debug "Setting up serial ILU(0)-CSR" F = ilu0_csr(A) else @debug "Setting up parallel ILU(0)-CSR with $(nthreads(td)) threads" - part = context.partitioner + if context isa ParallelCSRContext + part = context.partitioner + else + part = MetisPartitioner() + end lookup = generate_lookup(part, A, nt) F = ilu0_csr(A, lookup) end diff --git a/src/linsolve/precond/utils.jl b/src/linsolve/precond/utils.jl index 1cf60e22..c2c999d4 100644 --- a/src/linsolve/precond/utils.jl +++ b/src/linsolve/precond/utils.jl @@ -1,12 +1,10 @@ - function update_preconditioner!(preconditioner::Nothing, arg...) # Do nothing. end -function update_preconditioner!(preconditioner, lsys, model, storage, recorder, executor) +function update_preconditioner!(preconditioner::JutulPreconditioner, lsys::JutulLinearSystem, context, model, storage, recorder, executor) J = jacobian(lsys) r = residual(lsys) - ctx = linear_system_context(model, lsys) - update_preconditioner!(preconditioner, J, r, ctx, executor) + update_preconditioner!(preconditioner, J, r, context, executor) end function partial_update_preconditioner!(p, A, b, context, executor) @@ -20,7 +18,11 @@ end is_left_preconditioner(::JutulPreconditioner) = true is_right_preconditioner(::JutulPreconditioner) = false -function linear_operator(precond::JutulPreconditioner, float_t = Float64) +function linear_operator(precond::JutulPreconditioner) + return linear_operator(precond, Float64, nothing, nothing, nothing, nothing, nothing) +end + +function linear_operator(precond::JutulPreconditioner, float_t, sys, context, model, storage, recorder) n = operator_nrows(precond) function precond_apply!(res, x, α, β::T) where T if β == zero(T) @@ -36,13 +38,8 @@ function linear_operator(precond::JutulPreconditioner, float_t = Float64) return op end -function apply!(x, p::JutulPreconditioner, y, arg...) +#nead to be spesilized on type not all JutulPreconditioners has get_factor +function apply!(x, p::JutulPreconditioner, y) factor = get_factorization(p) - if is_left_preconditioner(p) - ldiv!(x, factor, y) - elseif is_right_preconditioner(p) - error("Not supported.") - else - error("Neither left or right preconditioner?") - end + ldiv!(x, factor, y) end diff --git a/src/linsolve/precond/various.jl b/src/linsolve/precond/various.jl index 775dfc8d..bb965c59 100644 --- a/src/linsolve/precond/various.jl +++ b/src/linsolve/precond/various.jl @@ -29,11 +29,19 @@ function operator_nrows(lup::LUPreconditioner) return size(f.L, 1) end +#function operator_nrows(prec::TrivialPreconditioner) +# return prec.dim[1] +#end # LU factor as precond for wells? """ Trivial / identity preconditioner with size for use in subsystems. """ +function apply!(x,tp::TrivialPreconditioner,r, args...) + x = copy(r) +end + + # Trivial precond function update_preconditioner!(tp::TrivialPreconditioner, lsys, model, storage, recorder, executor) A = jacobian(lsys) @@ -41,7 +49,9 @@ function update_preconditioner!(tp::TrivialPreconditioner, lsys, model, storage, tp.dim = size(A).*length(b[1]) end -function linear_operator(id::TrivialPreconditioner, ::Symbol) +export linear_operator + +function linear_operator(id::TrivialPreconditioner, args...) return opEye(id.dim...) end diff --git a/src/linsolve/scalar_cpu.jl b/src/linsolve/scalar_cpu.jl index b3379275..2bbe864f 100644 --- a/src/linsolve/scalar_cpu.jl +++ b/src/linsolve/scalar_cpu.jl @@ -15,12 +15,12 @@ mutable struct LUSolver end end -function linear_solve!(sys, solver::LUSolver, arg...; kwarg...) +function linear_solve!(sys, solver::LUSolver, arg...;dx = sys.dx, r = sys.r, kwargs...) if length(sys.dx) > solver.max_size error("System too big for LU solver. You can increase max_size at your own peril.") end J = sys.jac - r = sys.r + #r = sys.r if !solver.reuse_memory F = lu(J) else @@ -32,7 +32,7 @@ function linear_solve!(sys, solver::LUSolver, arg...; kwarg...) F = solver.F end - sys.dx .= -(F\r) + dx .= -(F\r) @assert all(isfinite, sys.dx) "Linear solve resulted in non-finite values." return linear_solve_return() end diff --git a/src/models.jl b/src/models.jl index 31e5a164..1931574e 100644 --- a/src/models.jl +++ b/src/models.jl @@ -884,8 +884,9 @@ end function solve_and_update!(storage, model::JutulModel, dt = nothing; linear_solver = nothing, recorder = nothing, executor = default_executor(), kwarg...) lsys = storage.LinearizedSystem + context = model.context t_solve = @elapsed begin - @tic "linear solve" (ok, n, history) = linear_solve!(lsys, linear_solver, model, storage, dt, recorder, executor) + @tic "linear solve" (ok, n, history) = linear_solve!(lsys, linear_solver, context, model, storage, dt, recorder, executor) end t_update = @elapsed @tic "primary variables" update = update_primary_variables!(storage, model; kwarg...) return (t_solve, t_update, n, history, update) diff --git a/src/multimodel/model.jl b/src/multimodel/model.jl index 5261aa56..cb58ae01 100644 --- a/src/multimodel/model.jl +++ b/src/multimodel/model.jl @@ -130,6 +130,10 @@ function setup_storage!(storage, model::MultiModel; return storage end +mutable struct MutableWrapper + maps +end + function setup_multimodel_maps!(storage, model) groups = model.groups if isnothing(groups) @@ -138,6 +142,7 @@ function setup_multimodel_maps!(storage, model) offset_map = map(g -> get_submodel_offsets(model, g), unique(groups)) end storage[:multi_model_maps] = (offset_map = offset_map, ); + storage[:eq_maps] = MutableWrapper(nothing) end function setup_equations_and_primary_variable_views!(storage, model::MultiModel, lsys)