diff --git a/docs/ref/internals.rst b/docs/ref/internals.rst index 9d5fa2c..0fcf633 100644 --- a/docs/ref/internals.rst +++ b/docs/ref/internals.rst @@ -19,12 +19,23 @@ Field lentil.field.Field lentil.field.boundary - lentil.field.extent lentil.field.insert lentil.field.merge lentil.field.overlap lentil.field.reduce +Extent +------ +.. autosummary:: + :toctree: generated/ + + lentil.extent.array_extent + lentil.extent.intersect + lentil.extent.intersection_extent + lentil.extent.intersection_shape + lentil.extent.intersection_slices + lentil.extent.intersection_shift + Tilt interface -------------- .. autosummary:: @@ -45,6 +56,6 @@ Helper functions :toctree: generated/ lentil.helper.mesh - lentil.helper.gaussian2d lentil.helper.boundary_slice lentil.helper.slice_offset + lentil.helper.gaussian2d diff --git a/lentil/extent.py b/lentil/extent.py new file mode 100644 index 0000000..67f37ff --- /dev/null +++ b/lentil/extent.py @@ -0,0 +1,101 @@ +# Functions for working with extent (rmon, rmax, cmin, cmax) data + +import numpy as np + +def array_extent(shape, shift, parent_shape=None): + """Compute the extent of a shifted array. + + Parameters + ---------- + shape : (2,) array_like + + shift : (2,) array_like + + parent_shape : (2,) array like or None, optional + + Notes + ----- + To use the values returned by ``extent()`` in a slice, + ``rmax`` and ``cmax`` should be increased by 1. + """ + if len(shape) < 2: + shape = (1, 1) + + rmin = int(-(shape[0]//2) + shift[0]) + cmin = int(-(shape[1]//2) + shift[1]) + rmax = int(rmin + shape[0] - 1) + cmax = int(cmin + shape[1] - 1) + + if parent_shape is not None: + parent_center = np.asarray(parent_shape)//2 + rmin += parent_center[0] + rmax += parent_center[0] + cmin += parent_center[1] + cmax += parent_center[1] + + return rmin, rmax, cmin, cmax + + +def intersect(a, b): + """Return True if two extents intersect, otherwise False + + Parameters + ---------- + a, b : (4,) array like + Two array extents (rmin, rmax, cmin, cmax) + + Returns + ------- + bool + """ + + armin, armax, acmin, acmax = a + brmin, brmax, bcmin, bcmax = b + return armin <= brmax and armax >= brmin and acmin <= bcmax and acmax >= bcmin + + +def intersection_extent(a, b): + # bounding array indices to be multiplied + armin, armax, acmin, acmax = a + brmin, brmax, bcmin, bcmax = b + + rmin, rmax = max(armin, brmin), min(armax, brmax) + cmin, cmax = max(acmin, bcmin), min(acmax, bcmax) + + return rmin, rmax, cmin, cmax + + +def intersection_shape(a, b): + """Compute the shape + """ + + rmin, rmax, cmin, cmax = intersection_extent(a, b) + nr, nc = rmax - rmin + 1, cmax - cmin + 1 + + if nr < 0 or nc < 0: + shape = () + else: + shape = (nr, nc) + + return shape + + +def intersection_slices(a, b): + rmin, rmax, cmin, cmax = intersection_extent(a, b) + + armin, armax, acmin, acmax = a + brmin, brmax, bcmin, bcmax = b + + arow = slice(rmin-armin, rmax-armin+1) + acol = slice(cmin-acmin, cmax-acmin+1) + brow = slice(rmin-brmin, rmax-brmin+1) + bcol = slice(cmin-bcmin, cmax-bcmin+1) + + return (arow, acol), (brow, bcol) + + +def intersection_shift(a, b): + rmin, rmax, cmin, cmax = intersection_extent(a, b) + nrow = rmax - rmin + 1 + ncol = cmax - cmin + 1 + return rmin + nrow//2, cmin + ncol//2 \ No newline at end of file diff --git a/lentil/field.py b/lentil/field.py index 980423c..96709b2 100644 --- a/lentil/field.py +++ b/lentil/field.py @@ -2,6 +2,8 @@ from itertools import combinations import numpy as np +import lentil.extent + class Field: """ Two-dimensional discretely sampled complex field. @@ -51,7 +53,7 @@ def __init__(self, data, pixelscale=None, offset=None, tilt=None): self.tilt = tilt if tilt else [] #: tuple of ints : Extent of ``data`` - self.extent = extent(self.shape, self.offset) + self.extent = lentil.extent.array_extent(self.shape, self.offset) @property def shape(self): @@ -132,11 +134,11 @@ def _mul_array(self, other): self_data, self_offset, other_data, other_offset = _mul_broadcast( self.data, self.offset, other.data, other.offset ) - self_extent = extent(self_data.shape, self_offset) - other_extent = extent(other_data.shape, other_offset) + self_extent = lentil.extent.array_extent(self_data.shape, self_offset) + other_extent = lentil.extent.array_extent(other_data.shape, other_offset) - self_slice, other_slice = _mul_slices(self_extent, other_extent) - offset = _mul_offset(self_extent, other_extent) + self_slice, other_slice = lentil.extent.intersection_slices(self_extent, other_extent) + offset = lentil.extent.intersection_shift(self_extent, other_extent) data = self_data[self_slice] * other_data[other_slice] if data.size == 0: @@ -226,25 +228,6 @@ def boundary(fields): return rmin, rmax, cmin, cmax - -def extent(shape, offset): - """ - Compute the extent of a shifted array. - - Note: To use the values returned by ``extent()`` in a slice, - ``rmax`` and ``cmax`` should be increased by 1. - """ - if len(shape) < 2: - shape = (1, 1) - - rmin = int(-(shape[0]//2) + offset[0]) - cmin = int(-(shape[1]//2) + offset[1]) - rmax = int(rmin + shape[0] - 1) - cmax = int(cmin + shape[1] - 1) - - return rmin, rmax, cmin, cmax - - def insert(field, out, intensity=False, weight=1): """Insert a field into an array. @@ -417,10 +400,8 @@ def overlap(fields): overlap : bool """ - #return _overlap(a.extent, b.extent) - if len(fields) == 2: - return _overlap(fields[0].extent, fields[1].extent) + return lentil.extent.intersect(fields[0].extent, fields[1].extent) else: fields = _reduce(fields) if len(fields) > 1: @@ -472,7 +453,7 @@ def _disjoint(fields): Return fields as a disjoint set. """ for m, n in combinations(range(len(fields)), 2): - if _overlap(fields[m]['extent'], fields[n]['extent']): + if lentil.extent.intersect(fields[m]['extent'], fields[n]['extent']): fields[m]['field'].extend(fields[n]['field']) fields[m]['extent'] = boundary(fields[m]['field']) fields.pop(n) @@ -480,15 +461,6 @@ def _disjoint(fields): return fields -def _overlap(a_extent, b_extent): - """ - Return True if two extents overlap, otherwise False - """ - armin, armax, acmin, acmax = a_extent - brmin, brmax, bcmin, bcmax = b_extent - return armin <= brmax and armax >= brmin and acmin <= bcmax and acmax >= bcmin - - def _mul_broadcast(a_data, a_offset, b_data, b_offset): """ Broadcast for multiplication. @@ -510,35 +482,3 @@ def _mul_broadcast(a_data, a_offset, b_data, b_offset): b_data = np.broadcast_to(b_data, a_data.shape) b_offset = a_offset return a_data, a_offset, b_data, b_offset - - -def _mul_boundary(a_extent, b_extent): - # bounding array indices to be multiplied - armin, armax, acmin, acmax = a_extent - brmin, brmax, bcmin, bcmax = b_extent - - rmin, rmax = max(armin, brmin), min(armax, brmax) - cmin, cmax = max(acmin, bcmin), min(acmax, bcmax) - - return rmin, rmax, cmin, cmax - - -def _mul_slices(a_extent, b_extent): - rmin, rmax, cmin, cmax = _mul_boundary(a_extent, b_extent) - - armin, armax, acmin, acmax = a_extent - brmin, brmax, bcmin, bcmax = b_extent - - arow = slice(rmin-armin, rmax-armin+1) - acol = slice(cmin-acmin, cmax-acmin+1) - brow = slice(rmin-brmin, rmax-brmin+1) - bcol = slice(cmin-bcmin, cmax-bcmin+1) - - return (arow, acol), (brow, bcol) - - -def _mul_offset(a_extent, b_extent): - rmin, rmax, cmin, cmax = _mul_boundary(a_extent, b_extent) - nrow = rmax - rmin + 1 - ncol = cmax - cmin + 1 - return rmin + nrow//2, cmin + ncol//2