Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: use valid cells for fields getitem #4370

16 changes: 8 additions & 8 deletions Examples/Tests/python_wrappers/PICMI_inputs_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,29 +273,29 @@ def check_values(benchmark, data, rtol, atol):
check_values(3.0188584528062377, F[:,:], rtol, atol)
check_values(1013672631.8764204, G[:,:], rtol, atol)
# E in PML
check_values(364287936.1526477 , Expml[:,:,0], rtol, atol)
check_values(31311885.4695534 , Expml[:,:,0], rtol, atol)
check_values(183582352.20753333, Expml[:,:,1], rtol, atol)
check_values(190065766.41491824, Expml[:,:,2], rtol, atol)
check_values(440581907.0828975 , Eypml[:,:,0], rtol, atol)
check_values(219685045.6889916 , Eypml[:,:,0], rtol, atol)
check_values(178117294.05871135, Eypml[:,:,1], rtol, atol)
check_values(0.0 , Eypml[:,:,2], rtol, atol)
check_values(430277101.26568377, Ezpml[:,:,0], rtol, atol)
check_values(205162994.3170565 , Ezpml[:,:,0], rtol, atol)
check_values(0.0 , Ezpml[:,:,1], rtol, atol)
check_values(190919663.2167449 , Ezpml[:,:,2], rtol, atol)
# B in PML
check_values(1.0565189315366146 , Bxpml[:,:,0], rtol, atol)
check_values(0.08431013332510595, Bxpml[:,:,0], rtol, atol)
check_values(0.46181913800643065, Bxpml[:,:,1], rtol, atol)
check_values(0.6849858305343736 , Bxpml[:,:,2], rtol, atol)
check_values(1.7228584190213505 , Bypml[:,:,0], rtol, atol)
check_values(0.7303366134901944 , Bypml[:,:,0], rtol, atol)
check_values(0.47697332248020935, Bypml[:,:,1], rtol, atol)
check_values(0.0 , Bypml[:,:,2], rtol, atol)
check_values(1.518338068658267 , Bzpml[:,:,0], rtol, atol)
check_values(0.5461292699653959 , Bzpml[:,:,0], rtol, atol)
check_values(0.0 , Bzpml[:,:,1], rtol, atol)
check_values(0.6849858291863835 , Bzpml[:,:,2], rtol, atol)
# F and G in PML
check_values(1.7808748509425263, Fpml[:,:,0], rtol, atol)
check_values(0.8275886762933928, Fpml[:,:,0], rtol, atol)
check_values(0.0 , Fpml[:,:,1], rtol, atol)
check_values(0.4307845604625681, Fpml[:,:,2], rtol, atol)
check_values(536552745.42701197, Gpml[:,:,0], rtol, atol)
check_values(196016270.8729728 , Gpml[:,:,0], rtol, atol)
check_values(0.0 , Gpml[:,:,1], rtol, atol)
check_values(196016270.97767758, Gpml[:,:,2], rtol, atol)
34 changes: 26 additions & 8 deletions Python/pywarpx/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@
device_arr = device_arr[tuple([slice(ng, -ng) for ng in nghosts[:self.dim]])]
return device_arr

def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop, with_internal_ghosts):
"""Return the slices where the block intersects with the global slice.
If the block does not intersect, return None.
This also shifts the block slices by the number of ghost cells in the
Expand All @@ -331,6 +331,9 @@
The maximum component index of the global slice.
These can be negative.

with_internal_ghosts: bool
Whether the internal ghosts are included in the slices

Returns
-------
block_slices:
Expand All @@ -340,11 +343,26 @@
The slice of the intersection relative to the global array where the data from individual block will go
"""
box = mfi.tilebox()
box_small_end = box.small_end
box_big_end = box.big_end
if self.include_ghosts:
box.grow(self.mf.n_grow_vect)

ilo = self._get_indices(box.small_end, 0)
ihi = self._get_indices(box.big_end, 0)
nghosts = self.mf.n_grow_vect
box.grow(nghosts)
if with_internal_ghosts:
box_small_end = box.small_end
box_big_end = box.big_end
else:
min_box = self.mf.box_array().minimal_box()
for i in range(self.dim):
if box_small_end[i] == min_box.small_end[i]:
box_small_end[i] -= nghosts[i]
if box_big_end[i] == min_box.big_end[i]:
box_big_end[i] += nghosts[i]

boxlo = self._get_indices(box.small_end, 0)
boxhi = self._get_indices(box.big_end, 0)
Fixed Show fixed Hide fixed
ilo = self._get_indices(box_small_end, 0)
ihi = self._get_indices(box_big_end, 0)

# Add 1 to the upper end to be consistent with the slicing notation
ihi_p1 = [i + 1 for i in ihi]
Expand All @@ -356,7 +374,7 @@
block_slices = []
global_slices = []
for i in range(3):
block_slices.append(slice(i1[i] - ilo[i], i2[i] - ilo[i]))
block_slices.append(slice(i1[i] - boxlo[i], i2[i] - boxlo[i]))
global_slices.append(slice(i1[i] - starts[i], i2[i] - starts[i]))

block_slices.append(slice(icstart, icstop))
Expand Down Expand Up @@ -416,7 +434,7 @@
stops = [ixstop, iystop, izstop]
datalist = []
for mfi in self.mf:
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop)
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop, False)
if global_slices is not None:
# Note that the array will always have 4 dimensions.
device_arr = self._get_field(mfi)
Expand Down Expand Up @@ -519,7 +537,7 @@
starts = [ixstart, iystart, izstart]
stops = [ixstop, iystop, izstop]
for mfi in self.mf:
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop)
block_slices, global_slices = self._get_intersect_slice(mfi, starts, stops, icstart, icstop, True)
if global_slices is not None:
mf_arr = self._get_field(mfi)
if isinstance(value, np.ndarray):
Expand Down
Loading