Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: use valid cells for fields getitem #4370

33 changes: 26 additions & 7 deletions Python/pywarpx/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ def _get_field(self, mfi):
]
return device_arr

def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
def _get_intersect_slice(
self, mfi, starts, stops, icstart, icstop, with_internal_ghosts
):
"""Return the slices where the block intersects with the global slice.
If the block does not intersect, return None.
This also shifts the block slices by the number of ghost cells in the
Expand All @@ -344,6 +346,9 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
The maximum component index of the global slice.
These can be negative.

with_internal_ghosts: bool
Whether the internal ghosts are included in the slices

Returns
-------
block_slices:
Expand All @@ -353,11 +358,25 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
The slice of the intersection relative to the global array where the data from individual block will go
"""
box = mfi.tilebox()
box_small_end = box.small_end
box_big_end = box.big_end
if self.include_ghosts:
box.grow(self.mf.n_grow_vect)
nghosts = self.mf.n_grow_vect
box.grow(nghosts)
if with_internal_ghosts:
box_small_end = box.small_end
box_big_end = box.big_end
else:
min_box = self.mf.box_array().minimal_box()
for i in range(self.dim):
if box_small_end[i] == min_box.small_end[i]:
box_small_end[i] -= nghosts[i]
if box_big_end[i] == min_box.big_end[i]:
box_big_end[i] += nghosts[i]

ilo = self._get_indices(box.small_end, 0)
ihi = self._get_indices(box.big_end, 0)
boxlo = self._get_indices(box.small_end, 0)
ilo = self._get_indices(box_small_end, 0)
ihi = self._get_indices(box_big_end, 0)

# Add 1 to the upper end to be consistent with the slicing notation
ihi_p1 = [i + 1 for i in ihi]
Expand All @@ -368,7 +387,7 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop):
block_slices = []
global_slices = []
for i in range(3):
block_slices.append(slice(i1[i] - ilo[i], i2[i] - ilo[i]))
block_slices.append(slice(i1[i] - boxlo[i], i2[i] - boxlo[i]))
global_slices.append(slice(i1[i] - starts[i], i2[i] - starts[i]))

block_slices.append(slice(icstart, icstop))
Expand Down Expand Up @@ -429,7 +448,7 @@ def __getitem__(self, index):
datalist = []
for mfi in self.mf:
block_slices, global_slices = self._get_intersect_slice(
mfi, starts, stops, icstart, icstop
mfi, starts, stops, icstart, icstop, False
)
if global_slices is not None:
# Note that the array will always have 4 dimensions.
Expand Down Expand Up @@ -546,7 +565,7 @@ def __setitem__(self, index, value):
stops = [ixstop, iystop, izstop]
for mfi in self.mf:
block_slices, global_slices = self._get_intersect_slice(
mfi, starts, stops, icstart, icstop
mfi, starts, stops, icstart, icstop, True
)
if global_slices is not None:
mf_arr = self._get_field(mfi)
Expand Down
Loading