From c2a16ced0cd63d509fad79678e4f39f65385bf61 Mon Sep 17 00:00:00 2001 From: John Omotani <john.omotani@ukaea.uk> Date: Sat, 26 Oct 2024 20:19:33 +0100 Subject: [PATCH] more fixes for adi preconditioner --- moment_kinetics/src/nonlinear_solvers.jl | 34 ++++++++++++++---------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/moment_kinetics/src/nonlinear_solvers.jl b/moment_kinetics/src/nonlinear_solvers.jl index 14ff6262b0..fe566306f5 100644 --- a/moment_kinetics/src/nonlinear_solvers.jl +++ b/moment_kinetics/src/nonlinear_solvers.jl @@ -186,54 +186,60 @@ function setup_nonlinear_solve(active, input_dict, coords, outer_coords=(); defa v_size = nvperp * nvpa function get_adi_precon_buffers() - v_solve_z_range = loop_ranges_store[(:z)].z + v_solve_z_range = looping.loop_ranges_store[(:z,)].z v_solve_global_inds = [[((iz - 1)*v_size+1 : iz*v_size)..., total_size_coords+iz] for iz ∈ v_solve_z_range] v_solve_nsolve = length(v_solve_z_range) # Plus one for the one point of ppar that is included in the 'v solve'. - v_solve_n = vperp.n * vpa.n + 1 + v_solve_n = nvperp * nvpa + 1 v_solve_implicit_lus = Vector{SparseArrays.UMFPACK.UmfpackLU{mk_float, mk_int}}(undef, v_solve_nsolve) v_solve_explicit_matrices = Vector{SparseMatrixCSC{mk_float, mk_int}}(undef, v_solve_nsolve) # This buffer is not shared-memory, because it will be used for a serial LU solve. v_solve_buffer = allocate_float(v_solve_n) - v_solve_buffer2 = allocate_float(v_solve_n) + v_solve_matrix_buffer = allocate_float(v_solve_n, v_solve_n) - z_solve_vperp_range = loop_ranges_store[(:vperp,:vpa)].vperp - z_solve_vpa_range = loop_ranges_store[(:vperp,:vpa)].vpa - z_solve_global_inds = [[(ivperp-1)*nvpa+ivpa:v_size:nz*v_size+(ivperp-1)*nvpa+ivpa] for ivperp ∈ z_solve_vperp_range, ivpa ∈ z_solve_vpa_range] + z_solve_vperp_range = looping.loop_ranges_store[(:vperp,:vpa)].vperp + z_solve_vpa_range = looping.loop_ranges_store[(:vperp,:vpa)].vpa + z_solve_global_inds = vec([(ivperp-1)*nvpa+ivpa:v_size:(nz-1)*v_size+(ivperp-1)*nvpa+ivpa for ivperp ∈ z_solve_vperp_range, ivpa ∈ z_solve_vpa_range]) z_solve_nsolve = length(z_solve_vperp_range) * length(z_solve_vpa_range) @serial_region begin # Do the solve for ppar on the rank-0 process, which has the fewest grid # points to handle if there are not an exactly equal number of points for each # process. - push!(z_solve_global_inds, total_size_coords+1 : total_size_coords) + push!(z_solve_global_inds, total_size_coords+1 : total_size_coords+nz) z_solve_nsolve += 1 end - z_solve_n = z.n + z_solve_n = nz z_solve_implicit_lus = Vector{SparseArrays.UMFPACK.UmfpackLU{mk_float, mk_int}}(undef, z_solve_nsolve) z_solve_explicit_matrices = Vector{SparseMatrixCSC{mk_float, mk_int}}(undef, z_solve_nsolve) # This buffer is not shared-memory, because it will be used for a serial LU solve. z_solve_buffer = allocate_float(z_solve_n) - z_solve_buffer2 = allocate_float(z_solve_n) + z_solve_matrix_buffer = allocate_float(z_solve_n, z_solve_n) J_buffer = allocate_shared_float(pdf_plus_ppar_size, pdf_plus_ppar_size) input_buffer = allocate_shared_float(pdf_plus_ppar_size) + intermediate_buffer = allocate_shared_float(pdf_plus_ppar_size) output_buffer = allocate_shared_float(pdf_plus_ppar_size) - return (v_solve_global_inds=v_solve_global_inds, + return (v_solve_z_range=v_solve_z_range, + v_solve_global_inds=v_solve_global_inds, v_solve_nsolve=v_solve_nsolve, v_solve_implicit_lus=v_solve_implicit_lus, v_solve_explicit_matrices=v_solve_explicit_matrices, - v_solve_buffer=v_solve_buffer, v_solve_buffer2=v_solve_buffer2, + v_solve_buffer=v_solve_buffer, + v_solve_matrix_buffer=v_solve_matrix_buffer, + z_solve_vperp_range=z_solve_vperp_range, + z_solve_vpa_range=z_solve_vpa_range, z_solve_global_inds=z_solve_global_inds, z_solve_nsolve=z_solve_nsolve, z_solve_implicit_lus=z_solve_implicit_lus, z_solve_explicit_matrices=z_solve_explicit_matrices, - z_solve_buffer=z_solve_buffer, z_solve_buffer2=z_solve_buffer2, - J_buffer=J_buffer, input_buffer=input_buffer, + z_solve_buffer=z_solve_buffer, + z_solve_matrix_buffer=z_solve_matrix_buffer, J_buffer=J_buffer, + input_buffer=input_buffer, intermediate_buffer=intermediate_buffer, output_buffer=output_buffer) end - preconditioners = fill(get_adi_precon_buffers(), reverse(outer_coord_size)) + preconditioners = fill(get_adi_precon_buffers(), reverse(outer_coord_sizes)) elseif preconditioner_type === Val(:none) preconditioners = nothing else