From 518211b976b2c3b5e44fd4bfce8f0adad126a7e5 Mon Sep 17 00:00:00 2001
From: Michael Hardman <michael.hardman@tokamakenergy.co.uk>
Date: Wed, 18 Oct 2023 10:06:34 +0000
Subject: [PATCH] Working version of the elliptic solvers for the Rosenbluth
 potentials, using boundary data provided by the Green's functions and
 numerical integration. Parallelisation is used to speed up the calculation of
 the boundary data.

---
 2D_FEM_assembly_test.jl |  82 +++++++++++++++++++++-----
 src/fokker_planck.jl    | 127 +++++++++++++++++++++++++++++++---------
 2 files changed, 165 insertions(+), 44 deletions(-)

diff --git a/2D_FEM_assembly_test.jl b/2D_FEM_assembly_test.jl
index 29a7211b9..791306371 100644
--- a/2D_FEM_assembly_test.jl
+++ b/2D_FEM_assembly_test.jl
@@ -12,9 +12,11 @@ using moment_kinetics.gauss_legendre: setup_gausslegendre_pseudospectral, get_QQ
 using moment_kinetics.type_definitions: mk_float, mk_int
 using moment_kinetics.fokker_planck: F_Maxwellian, H_Maxwellian, G_Maxwellian
 using moment_kinetics.fokker_planck: d2Gdvpa2, d2Gdvperp2, dGdvperp, d2Gdvperpdvpa, dHdvpa, dHdvperp
-using moment_kinetics.fokker_planck: init_fokker_planck_collisions, fokkerplanck_arrays_struct
+using moment_kinetics.fokker_planck: init_fokker_planck_collisions, fokkerplanck_arrays_struct, fokkerplanck_boundary_data_arrays_struct
+using moment_kinetics.fokker_planck: init_fokker_planck_collisions_new, boundary_integration_weights_struct
 using moment_kinetics.calculus: derivative!
 using moment_kinetics.communication
+using moment_kinetics.communication: MPISharedArray
 using moment_kinetics.looping
 using SparseArrays: sparse
 using LinearAlgebra: mul!, lu, cholesky
@@ -98,9 +100,9 @@ if abspath(PROGRAM_FILE) == @__FILE__
     end
     
     struct vpa_vperp_boundary_data
-        lower_boundary_vpa::Array{mk_float,1}
-        upper_boundary_vpa::Array{mk_float,1}
-        upper_boundary_vperp::Array{mk_float,1}
+        lower_boundary_vpa::MPISharedArray{mk_float,1}
+        upper_boundary_vpa::MPISharedArray{mk_float,1}
+        upper_boundary_vperp::MPISharedArray{mk_float,1}
     end
     
     struct rosenbluth_potential_boundary_data
@@ -167,10 +169,12 @@ if abspath(PROGRAM_FILE) == @__FILE__
     
     
     function calculate_boundary_data!(func_data::vpa_vperp_boundary_data,
-                                            weight,func_input,vpa,vperp)
+                                            weight::MPISharedArray{mk_float,4},func_input,vpa,vperp)
         nvpa = vpa.n
         nvperp = vperp.n
-        for ivperp in 1:nvperp
+        #for ivperp in 1:nvperp
+        begin_vperp_region()
+        @loop_vperp ivperp begin
             func_data.lower_boundary_vpa[ivperp] = 0.0
             func_data.upper_boundary_vpa[ivperp] = 0.0
             for ivperpp in 1:nvperp
@@ -180,7 +184,9 @@ if abspath(PROGRAM_FILE) == @__FILE__
                 end
             end
         end
-        for ivpa in 1:nvpa
+        #for ivpa in 1:nvpa
+        begin_vpa_region()
+        @loop_vpa ivpa begin
             func_data.upper_boundary_vperp[ivpa] = 0.0
             for ivperpp in 1:nvperp
                 for ivpap in 1:nvpa
@@ -188,26 +194,65 @@ if abspath(PROGRAM_FILE) == @__FILE__
                 end
             end
         end
+        # return to serial parallelisation
+        begin_serial_region()
+        return nothing
+    end
+    
+    function calculate_boundary_data!(func_data::vpa_vperp_boundary_data,
+                                      weight::boundary_integration_weights_struct,
+                                      func_input,vpa,vperp)
+        nvpa = vpa.n
+        nvperp = vperp.n
+        #for ivperp in 1:nvperp
+        begin_vperp_region()
+        @loop_vperp ivperp begin
+            func_data.lower_boundary_vpa[ivperp] = 0.0
+            func_data.upper_boundary_vpa[ivperp] = 0.0
+            for ivperpp in 1:nvperp
+                for ivpap in 1:nvpa
+                    func_data.lower_boundary_vpa[ivperp] += weight.lower_vpa_boundary[ivpap,ivperpp,ivperp]*func_input[ivpap,ivperpp]
+                    func_data.upper_boundary_vpa[ivperp] += weight.upper_vpa_boundary[ivpap,ivperpp,ivperp]*func_input[ivpap,ivperpp]
+                end
+            end
+        end
+        #for ivpa in 1:nvpa
+        begin_vpa_region()
+        @loop_vpa ivpa begin
+            func_data.upper_boundary_vperp[ivpa] = 0.0
+            for ivperpp in 1:nvperp
+                for ivpap in 1:nvpa
+                    func_data.upper_boundary_vperp[ivpa] += weight.upper_vperp_boundary[ivpap,ivperpp,ivpa]*func_input[ivpap,ivperpp]
+                end
+            end
+        end
+        # return to serial parallelisation
+        begin_serial_region()
         return nothing
     end
     
     function calculate_rosenbluth_potential_boundary_data!(rpbd::rosenbluth_potential_boundary_data,
-        fkpl::fokkerplanck_arrays_struct,pdf)
+        fkpl::Union{fokkerplanck_arrays_struct,fokkerplanck_boundary_data_arrays_struct},pdf)
         # get derivatives of pdf
         dfdvperp = fkpl.dfdvperp
         dfdvpa = fkpl.dfdvpa
         d2fdvperpdvpa = fkpl.d2fdvperpdvpa
-        for ivpa in 1:vpa.n
+        #for ivpa in 1:vpa.n
+        begin_vpa_region()
+        @loop_vpa ivpa begin
             @views derivative!(vperp.scratch, pdf[ivpa,:], vperp, vperp_spectral)
             @. dfdvperp[ivpa,:] = vperp.scratch
         end
-        for ivperp in 1:vperp.n
+        begin_vperp_region()
+        @loop_vperp ivperp begin
+        #for ivperp in 1:vperp.n
             @views derivative!(vpa.scratch, pdf[:,ivperp], vpa, vpa_spectral)
             @. dfdvpa[:,ivperp] = vpa.scratch
             @views derivative!(vpa.scratch, dfdvperp[:,ivperp], vpa, vpa_spectral)
             @. d2fdvperpdvpa[:,ivperp] = vpa.scratch
         end
-        
+        # ensure data is synchronized
+        begin_serial_region()
         # carry out the numerical integration 
         calculate_boundary_data!(rpbd.H_data,fkpl.H0_weights,pdf,vpa,vperp)
         calculate_boundary_data!(rpbd.dHdvpa_data,fkpl.H0_weights,dfdvpa,vpa,vperp)
@@ -381,13 +426,13 @@ if abspath(PROGRAM_FILE) == @__FILE__
     
     # define inputs needed for the test
 	plot_test_output = false#true
-    ngrid = 3 #number of points per element 
+    ngrid = 9 #number of points per element 
 	nelement_local_vpa = 16 # number of elements per rank
 	nelement_global_vpa = nelement_local_vpa # total number of elements 
 	nelement_local_vperp = 8 # number of elements per rank
 	nelement_global_vperp = nelement_local_vperp # total number of elements 
-	Lvpa = 6.0 #physical box size in reference units 
-	Lvperp = 3.0 #physical box size in reference units 
+	Lvpa = 12.0 #physical box size in reference units 
+	Lvperp = 6.0 #physical box size in reference units 
 	bc = "" #not required to take a particular value, not used 
 	# fd_option and adv_input not actually used so given values unimportant
 	#discretization = "chebyshev_pseudospectral"
@@ -769,10 +814,17 @@ if abspath(PROGRAM_FILE) == @__FILE__
       d2Gdvperpdvpa_M_exact,d2Gdvpa2_M_exact,vpa,vperp)
     # use numerical integration to find the boundary data
     # initialise the weights
-    fkpl_arrays = init_fokker_planck_collisions(vperp,vpa; precompute_weights=true)
+    #fkpl_arrays = init_fokker_planck_collisions(vperp,vpa; precompute_weights=true)
+    fkpl_arrays = init_fokker_planck_collisions_new(vpa,vperp; precompute_weights=true)
     begin_serial_region()
     # do the numerical integration at the boundaries (N.B. G not supported)
+    @serial_region begin 
+        println("begin boundary data calculation   ", Dates.format(now(), dateformat"H:MM:SS"))
+    end
     calculate_rosenbluth_potential_boundary_data!(rpbd,fkpl_arrays,F_M)
+    @serial_region begin 
+        println("finished boundary data calculation   ", Dates.format(now(), dateformat"H:MM:SS"))
+    end
     # test the boundary data calculation
     test_rosenbluth_potential_boundary_data(rpbd,rpbd_exact,vpa,vperp)
     #rpbd = rpbd_exact
diff --git a/src/fokker_planck.jl b/src/fokker_planck.jl
index 2cdee57e9..fe0727732 100644
--- a/src/fokker_planck.jl
+++ b/src/fokker_planck.jl
@@ -5,6 +5,7 @@ module fokker_planck
 
 
 export init_fokker_planck_collisions, fokkerplanck_arrays_struct
+export init_fokker_planck_collisions_new
 export explicit_fokker_planck_collisions!
 export calculate_Rosenbluth_potentials!
 export calculate_collisional_fluxes, calculate_Maxwellian_Rosenbluth_coefficients
@@ -16,7 +17,7 @@ export dHdvpa, dHdvperp, Cssp_Maxwellian_inputs
 export F_Maxwellian, dFdvpa_Maxwellian, dFdvperp_Maxwellian
 export d2Fdvpa2_Maxwellian, d2Fdvperpdvpa_Maxwellian, d2Fdvperp2_Maxwellian
 export H_Maxwellian, G_Maxwellian
-
+export boundary_integration_weights_struct, fokkerplanck_boundary_data_arrays_struct
 export Cssp_fully_expanded_form, get_local_Cssp_coefficients!, init_fokker_planck_collisions
 # testing
 export symmetric_matrix_inverse
@@ -126,6 +127,9 @@ struct fokkerplanck_boundary_data_arrays_struct
     H1_weights::boundary_integration_weights_struct
     H2_weights::boundary_integration_weights_struct
     H3_weights::boundary_integration_weights_struct
+    dfdvpa::MPISharedArray{mk_float,2}
+    d2fdvperpdvpa::MPISharedArray{mk_float,2}
+    dfdvperp::MPISharedArray{mk_float,2}    
 end
 
 
@@ -135,7 +139,8 @@ function allocate_boundary_integration_weight(vpa,vperp)
     lower_vpa_boundary = allocate_shared_float(nvpa,nvperp,nvperp)
     upper_vpa_boundary = allocate_shared_float(nvpa,nvperp,nvperp)
     upper_vperp_boundary = allocate_shared_float(nvpa,nvperp,nvpa)
-    return boundary_integration_weights_struct()
+    return boundary_integration_weights_struct(lower_vpa_boundary,
+            upper_vpa_boundary, upper_vperp_boundary)
 end
 
 function allocate_boundary_integration_weights(vpa,vperp)
@@ -145,8 +150,14 @@ function allocate_boundary_integration_weights(vpa,vperp)
     H1_weights = allocate_boundary_integration_weight(vpa,vperp)
     H2_weights = allocate_boundary_integration_weight(vpa,vperp)
     H3_weights = allocate_boundary_integration_weight(vpa,vperp)
+    nvpa = vpa.n
+    nvperp = vperp.n
+    dfdvpa = allocate_shared_float(nvpa,nvperp)
+    d2fdvperpdvpa = allocate_shared_float(nvpa,nvperp)
+    dfdvperp = allocate_shared_float(nvpa,nvperp)
     return fokkerplanck_boundary_data_arrays_struct(G0_weights,
-            G1_weights,H0_weights,H1_weights,H2_weights,H3_weights)
+            G1_weights,H0_weights,H1_weights,H2_weights,H3_weights,
+            dfdvpa,d2fdvperpdvpa,dfdvperp)
 end
 
 # initialise the elliptic integral factor arrays 
@@ -206,7 +217,7 @@ function init_fokker_planck_collisions_new(vpa,vperp; precompute_weights=false)
         @views init_Rosenbluth_potential_boundary_integration_weights!(bwgt.G0_weights, bwgt.G1_weights, bwgt.H0_weights, bwgt.H1_weights,
                                         bwgt.H2_weights, bwgt.H3_weights, vpa, vperp)
     end
-    return fka
+    return bwgt
 end
 
 """
@@ -322,33 +333,17 @@ only along the velocity space boundaries
 """
 function init_Rosenbluth_potential_boundary_integration_weights!(G0_weights,
       G1_weights,H0_weights,H1_weights,H2_weights,H3_weights,vpa,vperp)
-    @serial_region begin
-        println("setting up GL quadrature   ", Dates.format(now(), dateformat"H:MM:SS"))
-    end
-    
-    nelement_vpa, ngrid_vpa = vpa.nelement_local, vpa.ngrid
-    nelement_vperp, ngrid_vperp = vperp.nelement_local, vperp.ngrid
-    ngrid = max(ngrid_vpa,ngrid_vperp)
-    
-    # get Gauss-Legendre points and weights on (-1,1)
-    nquad = 2*ngrid
-    x_legendre, w_legendre = gausslegendre(nquad)
-    #nlaguerre = min(9,nquad) # to prevent points to close to the boundaries
-    nlaguerre = nquad
-    x_laguerre, w_laguerre = gausslaguerre(nlaguerre)
     
-    #x_hlaguerre, w_hlaguerre = gausslaguerre(halfnquad)
-    x_vpa, w_vpa = Array{mk_float,1}(undef,4*nquad), Array{mk_float,1}(undef,4*nquad)
-    x_vperp, w_vperp = Array{mk_float,1}(undef,4*nquad), Array{mk_float,1}(undef,4*nquad)
+    x_vpa, w_vpa, x_vperp, w_vperp, x_legendre, w_legendre, x_laguerre, w_laguerre = setup_basic_quadratures(vpa,vperp)
     
     @serial_region begin
-        println("beginning weights calculation   ", Dates.format(now(), dateformat"H:MM:SS"))
+        println("beginning (boundary) weights calculation   ", Dates.format(now(), dateformat"H:MM:SS"))
     end
 
     # precalculate weights, integrating over Lagrange polynomials
-    # first compute weights along vpa boundaries
+    # first compute weights along lower vpa boundary
     begin_vperp_region()
-    ivpa = 1 #
+    ivpa = 1 # lower_vpa_boundary
     @loop_vperp ivperp begin
         #limits where checks required to determine which divergence-safe grid is needed
         igrid_vpa, ielement_vpa, ielement_vpa_low, ielement_vpa_hi, igrid_vperp, ielement_vperp, ielement_vperp_low, ielement_vperp_hi = get_element_limit_indices(ivpa,ivperp,vpa,vperp)
@@ -369,17 +364,91 @@ function init_Rosenbluth_potential_boundary_integration_weights!(G0_weights,
             end
         end
         # loop over elements and grid points within elements on primed coordinate
-        @views loop_over_vperp_vpa_elements!(G1_weights,H0_weights,H1_weights,H2_weights,H3_weights,
+        @views loop_over_vperp_vpa_elements!(G0_weights.lower_vpa_boundary[:,:,ivperp],
+                G1_weights.lower_vpa_boundary[:,:,ivperp],
+                H0_weights.lower_vpa_boundary[:,:,ivperp],
+                H1_weights.lower_vpa_boundary[:,:,ivperp],
+                H2_weights.lower_vpa_boundary[:,:,ivperp],
+                H3_weights.lower_vpa_boundary[:,:,ivperp],
                 vpa,ielement_vpa_low,ielement_vpa_hi, # info about primed vpa grids
                 vperp,ielement_vperp_low,ielement_vperp_hi, # info about primed vperp grids
                 x_vpa, w_vpa, x_vperp, w_vperp, # arrays to store points and weights for primed (source) grids
                 x_legendre,w_legendre,x_laguerre,w_laguerre,
-                igrid_vpa, igrid_vperp, vpa_val, vperp_val, ivpa, ivperp)
+                igrid_vpa, igrid_vperp, vpa_val, vperp_val)
     end
-    
-    # now compute the weights along the vperp boundary
+    # second compute weights along upper vpa boundary
+    ivpa = vpa.n # upper_vpa_boundary
+    @loop_vperp ivperp begin
+        #limits where checks required to determine which divergence-safe grid is needed
+        igrid_vpa, ielement_vpa, ielement_vpa_low, ielement_vpa_hi, igrid_vperp, ielement_vperp, ielement_vperp_low, ielement_vperp_hi = get_element_limit_indices(ivpa,ivperp,vpa,vperp)
+        
+        vperp_val = vperp.grid[ivperp]
+        vpa_val = vpa.grid[ivpa]
+        for ivperpp in 1:vperp.n
+            for ivpap in 1:vpa.n
+                G0_weights.upper_vpa_boundary[ivpap,ivperpp,ivperp] = 0.0  
+                G1_weights.upper_vpa_boundary[ivpap,ivperpp,ivperp] = 0.0  
+                # G2_weights[ivpap,ivperpp,ivpa,ivperp] = 0.0  
+                # G3_weights[ivpap,ivperpp,ivpa,ivperp] = 0.0  
+                H0_weights.upper_vpa_boundary[ivpap,ivperpp,ivperp] = 0.0  
+                H1_weights.upper_vpa_boundary[ivpap,ivperpp,ivperp] = 0.0  
+                H2_weights.upper_vpa_boundary[ivpap,ivperpp,ivperp] = 0.0  
+                H3_weights.upper_vpa_boundary[ivpap,ivperpp,ivperp] = 0.0  
+                #@. n_weights[ivpap,ivperpp,ivpa,ivperp] = 0.0  
+            end
+        end
+        # loop over elements and grid points within elements on primed coordinate
+        @views loop_over_vperp_vpa_elements!(G0_weights.upper_vpa_boundary[:,:,ivperp],
+                G1_weights.upper_vpa_boundary[:,:,ivperp],
+                H0_weights.upper_vpa_boundary[:,:,ivperp],
+                H1_weights.upper_vpa_boundary[:,:,ivperp],
+                H2_weights.upper_vpa_boundary[:,:,ivperp],
+                H3_weights.upper_vpa_boundary[:,:,ivperp],
+                vpa,ielement_vpa_low,ielement_vpa_hi, # info about primed vpa grids
+                vperp,ielement_vperp_low,ielement_vperp_hi, # info about primed vperp grids
+                x_vpa, w_vpa, x_vperp, w_vperp, # arrays to store points and weights for primed (source) grids
+                x_legendre,w_legendre,x_laguerre,w_laguerre,
+                igrid_vpa, igrid_vperp, vpa_val, vperp_val)
+    end
+    # finally compute weight along upper vperp boundary
+    begin_vpa_region()
+    ivperp = vperp.n # upper_vperp_boundary
+    @loop_vpa ivpa begin
+        #limits where checks required to determine which divergence-safe grid is needed
+        igrid_vpa, ielement_vpa, ielement_vpa_low, ielement_vpa_hi, igrid_vperp, ielement_vperp, ielement_vperp_low, ielement_vperp_hi = get_element_limit_indices(ivpa,ivperp,vpa,vperp)
+        
+        vperp_val = vperp.grid[ivperp]
+        vpa_val = vpa.grid[ivpa]
+        for ivperpp in 1:vperp.n
+            for ivpap in 1:vpa.n
+                G0_weights.upper_vperp_boundary[ivpap,ivperpp,ivpa] = 0.0  
+                G1_weights.upper_vperp_boundary[ivpap,ivperpp,ivpa] = 0.0  
+                # G2_weights[ivpap,ivperpp,ivpa,ivperp] = 0.0  
+                # G3_weights[ivpap,ivperpp,ivpa,ivperp] = 0.0  
+                H0_weights.upper_vperp_boundary[ivpap,ivperpp,ivpa] = 0.0  
+                H1_weights.upper_vperp_boundary[ivpap,ivperpp,ivpa] = 0.0  
+                H2_weights.upper_vperp_boundary[ivpap,ivperpp,ivpa] = 0.0  
+                H3_weights.upper_vperp_boundary[ivpap,ivperpp,ivpa] = 0.0  
+                #@. n_weights[ivpap,ivperpp,ivpa,ivperp] = 0.0  
+            end
+        end
+        # loop over elements and grid points within elements on primed coordinate
+        @views loop_over_vperp_vpa_elements!(G0_weights.upper_vperp_boundary[:,:,ivpa],
+                G1_weights.upper_vperp_boundary[:,:,ivpa],
+                H0_weights.upper_vperp_boundary[:,:,ivpa],
+                H1_weights.upper_vperp_boundary[:,:,ivpa],
+                H2_weights.upper_vperp_boundary[:,:,ivpa],
+                H3_weights.upper_vperp_boundary[:,:,ivpa],
+                vpa,ielement_vpa_low,ielement_vpa_hi, # info about primed vpa grids
+                vperp,ielement_vperp_low,ielement_vperp_hi, # info about primed vperp grids
+                x_vpa, w_vpa, x_vperp, w_vperp, # arrays to store points and weights for primed (source) grids
+                x_legendre,w_legendre,x_laguerre,w_laguerre,
+                igrid_vpa, igrid_vperp, vpa_val, vperp_val)
+    end
+    # return the parallelisation status to serial
+    begin_serial_region()
     @serial_region begin
-        println("finished weights calculation   ", Dates.format(now(), dateformat"H:MM:SS"))
+        println("finished (boundary) weights calculation   ", Dates.format(now(), dateformat"H:MM:SS"))
     end
     return nothing
 end