diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py index 0100f64f261..cc70ec94532 100644 --- a/Python/pywarpx/fields.py +++ b/Python/pywarpx/fields.py @@ -317,7 +317,9 @@ def _get_field(self, mfi): ] return device_arr - def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop): + def _get_intersect_slice( + self, mfi, starts, stops, icstart, icstop, with_internal_ghosts + ): """Return the slices where the block intersects with the global slice. If the block does not intersect, return None. This also shifts the block slices by the number of ghost cells in the @@ -344,6 +346,9 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop): The maximum component index of the global slice. These can be negative. + with_internal_ghosts: bool + Whether the internal ghosts are included in the slices + Returns ------- block_slices: @@ -353,11 +358,25 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop): The slice of the intersection relative to the global array where the data from individual block will go """ box = mfi.tilebox() + box_small_end = box.small_end + box_big_end = box.big_end if self.include_ghosts: - box.grow(self.mf.n_grow_vect) + nghosts = self.mf.n_grow_vect + box.grow(nghosts) + if with_internal_ghosts: + box_small_end = box.small_end + box_big_end = box.big_end + else: + min_box = self.mf.box_array().minimal_box() + for i in range(self.dim): + if box_small_end[i] == min_box.small_end[i]: + box_small_end[i] -= nghosts[i] + if box_big_end[i] == min_box.big_end[i]: + box_big_end[i] += nghosts[i] - ilo = self._get_indices(box.small_end, 0) - ihi = self._get_indices(box.big_end, 0) + boxlo = self._get_indices(box.small_end, 0) + ilo = self._get_indices(box_small_end, 0) + ihi = self._get_indices(box_big_end, 0) # Add 1 to the upper end to be consistent with the slicing notation ihi_p1 = [i + 1 for i in ihi] @@ -368,7 +387,7 @@ def _get_intersect_slice(self, mfi, starts, stops, icstart, icstop): block_slices = [] global_slices = [] for i in range(3): - block_slices.append(slice(i1[i] - ilo[i], i2[i] - ilo[i])) + block_slices.append(slice(i1[i] - boxlo[i], i2[i] - boxlo[i])) global_slices.append(slice(i1[i] - starts[i], i2[i] - starts[i])) block_slices.append(slice(icstart, icstop)) @@ -429,7 +448,7 @@ def __getitem__(self, index): datalist = [] for mfi in self.mf: block_slices, global_slices = self._get_intersect_slice( - mfi, starts, stops, icstart, icstop + mfi, starts, stops, icstart, icstop, False ) if global_slices is not None: # Note that the array will always have 4 dimensions. @@ -546,7 +565,7 @@ def __setitem__(self, index, value): stops = [ixstop, iystop, izstop] for mfi in self.mf: block_slices, global_slices = self._get_intersect_slice( - mfi, starts, stops, icstart, icstop + mfi, starts, stops, icstart, icstop, True ) if global_slices is not None: mf_arr = self._get_field(mfi)