From c2a16ced0cd63d509fad79678e4f39f65385bf61 Mon Sep 17 00:00:00 2001
From: John Omotani <john.omotani@ukaea.uk>
Date: Sat, 26 Oct 2024 20:19:33 +0100
Subject: [PATCH] More fixes for ADI preconditioner setup

---
 moment_kinetics/src/nonlinear_solvers.jl | 34 ++++++++++++++----------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/moment_kinetics/src/nonlinear_solvers.jl b/moment_kinetics/src/nonlinear_solvers.jl
index 14ff6262b0..fe566306f5 100644
--- a/moment_kinetics/src/nonlinear_solvers.jl
+++ b/moment_kinetics/src/nonlinear_solvers.jl
@@ -186,54 +186,60 @@ function setup_nonlinear_solve(active, input_dict, coords, outer_coords=(); defa
         v_size = nvperp * nvpa
 
         function get_adi_precon_buffers()
-            v_solve_z_range = loop_ranges_store[(:z)].z
+            v_solve_z_range = looping.loop_ranges_store[(:z,)].z
             v_solve_global_inds = [[((iz - 1)*v_size+1 : iz*v_size)..., total_size_coords+iz] for iz ∈ v_solve_z_range]
             v_solve_nsolve = length(v_solve_z_range)
             # Plus one for the one point of ppar that is included in the 'v solve'.
-            v_solve_n = vperp.n * vpa.n + 1
+            v_solve_n = nvperp * nvpa + 1
             v_solve_implicit_lus = Vector{SparseArrays.UMFPACK.UmfpackLU{mk_float, mk_int}}(undef, v_solve_nsolve)
             v_solve_explicit_matrices = Vector{SparseMatrixCSC{mk_float, mk_int}}(undef, v_solve_nsolve)
             # This buffer is not shared-memory, because it will be used for a serial LU solve.
             v_solve_buffer = allocate_float(v_solve_n)
-            v_solve_buffer2 = allocate_float(v_solve_n)
+            v_solve_matrix_buffer = allocate_float(v_solve_n, v_solve_n)
 
-            z_solve_vperp_range = loop_ranges_store[(:vperp,:vpa)].vperp
-            z_solve_vpa_range = loop_ranges_store[(:vperp,:vpa)].vpa
-            z_solve_global_inds = [[(ivperp-1)*nvpa+ivpa:v_size:nz*v_size+(ivperp-1)*nvpa+ivpa] for ivperp ∈ z_solve_vperp_range, ivpa ∈ z_solve_vpa_range]
+            z_solve_vperp_range = looping.loop_ranges_store[(:vperp,:vpa)].vperp
+            z_solve_vpa_range = looping.loop_ranges_store[(:vperp,:vpa)].vpa
+            z_solve_global_inds = vec([(ivperp-1)*nvpa+ivpa:v_size:(nz-1)*v_size+(ivperp-1)*nvpa+ivpa for ivperp ∈ z_solve_vperp_range, ivpa ∈ z_solve_vpa_range])
             z_solve_nsolve = length(z_solve_vperp_range) * length(z_solve_vpa_range)
             @serial_region begin
                 # Do the solve for ppar on the rank-0 process, which has the fewest grid
                 # points to handle if there are not an exactly equal number of points for each
                 # process.
-                push!(z_solve_global_inds, total_size_coords+1 : total_size_coords)
+                push!(z_solve_global_inds, total_size_coords+1 : total_size_coords+nz)
                 z_solve_nsolve += 1
             end
-            z_solve_n = z.n
+            z_solve_n = nz
             z_solve_implicit_lus = Vector{SparseArrays.UMFPACK.UmfpackLU{mk_float, mk_int}}(undef, z_solve_nsolve)
             z_solve_explicit_matrices = Vector{SparseMatrixCSC{mk_float, mk_int}}(undef, z_solve_nsolve)
             # This buffer is not shared-memory, because it will be used for a serial LU solve.
             z_solve_buffer = allocate_float(z_solve_n)
-            z_solve_buffer2 = allocate_float(z_solve_n)
+            z_solve_matrix_buffer = allocate_float(z_solve_n, z_solve_n)
 
             J_buffer = allocate_shared_float(pdf_plus_ppar_size, pdf_plus_ppar_size)
             input_buffer = allocate_shared_float(pdf_plus_ppar_size)
+            intermediate_buffer = allocate_shared_float(pdf_plus_ppar_size)
             output_buffer = allocate_shared_float(pdf_plus_ppar_size)
 
-            return (v_solve_global_inds=v_solve_global_inds,
+            return (v_solve_z_range=v_solve_z_range,
+                    v_solve_global_inds=v_solve_global_inds,
                     v_solve_nsolve=v_solve_nsolve,
                     v_solve_implicit_lus=v_solve_implicit_lus,
                     v_solve_explicit_matrices=v_solve_explicit_matrices,
-                    v_solve_buffer=v_solve_buffer, v_solve_buffer2=v_solve_buffer2,
+                    v_solve_buffer=v_solve_buffer,
+                    v_solve_matrix_buffer=v_solve_matrix_buffer,
+                    z_solve_vperp_range=z_solve_vperp_range,
+                    z_solve_vpa_range=z_solve_vpa_range,
                     z_solve_global_inds=z_solve_global_inds,
                     z_solve_nsolve=z_solve_nsolve,
                     z_solve_implicit_lus=z_solve_implicit_lus,
                     z_solve_explicit_matrices=z_solve_explicit_matrices,
-                    z_solve_buffer=z_solve_buffer, z_solve_buffer2=z_solve_buffer2,
-                    J_buffer=J_buffer, input_buffer=input_buffer,
+                    z_solve_buffer=z_solve_buffer,
+                    z_solve_matrix_buffer=z_solve_matrix_buffer, J_buffer=J_buffer,
+                    input_buffer=input_buffer, intermediate_buffer=intermediate_buffer,
                     output_buffer=output_buffer)
         end
 
-        preconditioners = fill(get_adi_precon_buffers(), reverse(outer_coord_size))
+        preconditioners = fill(get_adi_precon_buffers(), reverse(outer_coord_sizes))
     elseif preconditioner_type === Val(:none)
         preconditioners = nothing
     else