Skip to content

Commit

Permalink
more fixes for adi preconditioner
Browse files Browse the repository at this point in the history
  • Loading branch information
johnomotani committed Oct 26, 2024
1 parent 0048959 commit c2a16ce
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions moment_kinetics/src/nonlinear_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,54 +186,60 @@ function setup_nonlinear_solve(active, input_dict, coords, outer_coords=(); defa
v_size = nvperp * nvpa

# Allocate the index lists, LU/matrix storage and work buffers needed by the ADI
# preconditioner, and return them together as a NamedTuple.
#
# Takes no arguments; reads `v_size`, `nvperp`, `nvpa`, `nz`, `total_size_coords` and
# `pdf_plus_ppar_size` from the enclosing scope.
#
# The `*_implicit_lus` and `*_explicit_matrices` vectors are created with `undef` entries,
# so every element must be assigned before it is read.
function get_adi_precon_buffers()
    # --- 'v solve': one solve per z-index in the `(:z,)` loop range -------------------
    # Note the key is the one-element tuple `(:z,)`; `(:z)` would just be the Symbol `:z`.
    v_solve_z_range = looping.loop_ranges_store[(:z,)].z
    # For each iz: the `v_size` contiguous entries of block iz, followed by the single
    # entry `total_size_coords+iz`.
    v_solve_global_inds = [[((iz - 1)*v_size+1 : iz*v_size)..., total_size_coords+iz]
                           for iz in v_solve_z_range]
    v_solve_nsolve = length(v_solve_z_range)
    # Plus one for the one point of ppar that is included in the 'v solve'.
    v_solve_n = nvperp * nvpa + 1
    v_solve_implicit_lus =
        Vector{SparseArrays.UMFPACK.UmfpackLU{mk_float, mk_int}}(undef, v_solve_nsolve)
    v_solve_explicit_matrices =
        Vector{SparseMatrixCSC{mk_float, mk_int}}(undef, v_solve_nsolve)
    # This buffer is not shared-memory, because it will be used for a serial LU solve.
    v_solve_buffer = allocate_float(v_solve_n)
    v_solve_matrix_buffer = allocate_float(v_solve_n, v_solve_n)

    # --- 'z solve': one solve per (ivperp, ivpa) pair in the `(:vperp,:vpa)` range ----
    z_solve_vperp_range = looping.loop_ranges_store[(:vperp,:vpa)].vperp
    z_solve_vpa_range = looping.loop_ranges_store[(:vperp,:vpa)].vpa
    # Each entry is a strided range with exactly `nz` elements: it starts at the
    # (ivperp, ivpa) offset within the first block and steps by `v_size`, so the final
    # element lies in block `nz` (hence the `(nz-1)*v_size` upper offset).
    z_solve_global_inds = vec([(ivperp-1)*nvpa+ivpa:v_size:(nz-1)*v_size+(ivperp-1)*nvpa+ivpa
                               for ivperp in z_solve_vperp_range, ivpa in z_solve_vpa_range])
    z_solve_nsolve = length(z_solve_vperp_range) * length(z_solve_vpa_range)
    @serial_region begin
        # Do the solve for ppar on the rank-0 process, which has the fewest grid
        # points to handle if there are not an exactly equal number of points for each
        # process.
        # The ppar entries are the `nz` indices immediately after `total_size_coords`.
        push!(z_solve_global_inds, total_size_coords+1 : total_size_coords+nz)
        z_solve_nsolve += 1
    end
    z_solve_n = nz
    z_solve_implicit_lus =
        Vector{SparseArrays.UMFPACK.UmfpackLU{mk_float, mk_int}}(undef, z_solve_nsolve)
    z_solve_explicit_matrices =
        Vector{SparseMatrixCSC{mk_float, mk_int}}(undef, z_solve_nsolve)
    # This buffer is not shared-memory, because it will be used for a serial LU solve.
    z_solve_buffer = allocate_float(z_solve_n)
    z_solve_matrix_buffer = allocate_float(z_solve_n, z_solve_n)

    # --- Buffers spanning the full (pdf + ppar) system, from `allocate_shared_float` ---
    J_buffer = allocate_shared_float(pdf_plus_ppar_size, pdf_plus_ppar_size)
    input_buffer = allocate_shared_float(pdf_plus_ppar_size)
    intermediate_buffer = allocate_shared_float(pdf_plus_ppar_size)
    output_buffer = allocate_shared_float(pdf_plus_ppar_size)

    return (v_solve_z_range=v_solve_z_range,
            v_solve_global_inds=v_solve_global_inds,
            v_solve_nsolve=v_solve_nsolve,
            v_solve_implicit_lus=v_solve_implicit_lus,
            v_solve_explicit_matrices=v_solve_explicit_matrices,
            v_solve_buffer=v_solve_buffer,
            v_solve_matrix_buffer=v_solve_matrix_buffer,
            z_solve_vperp_range=z_solve_vperp_range,
            z_solve_vpa_range=z_solve_vpa_range,
            z_solve_global_inds=z_solve_global_inds,
            z_solve_nsolve=z_solve_nsolve,
            z_solve_implicit_lus=z_solve_implicit_lus,
            z_solve_explicit_matrices=z_solve_explicit_matrices,
            z_solve_buffer=z_solve_buffer,
            z_solve_matrix_buffer=z_solve_matrix_buffer, J_buffer=J_buffer,
            input_buffer=input_buffer, intermediate_buffer=intermediate_buffer,
            output_buffer=output_buffer)
end

preconditioners = fill(get_adi_precon_buffers(), reverse(outer_coord_size))
preconditioners = fill(get_adi_precon_buffers(), reverse(outer_coord_sizes))
elseif preconditioner_type === Val(:none)
preconditioners = nothing
else
Expand Down

0 comments on commit c2a16ce

Please sign in to comment.