Skip to content

Commit

Permalink
Refactor extent* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Mar 8, 2024
1 parent ca8819e commit ea74fc4
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 71 deletions.
15 changes: 13 additions & 2 deletions docs/ref/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,23 @@ Field

lentil.field.Field
lentil.field.boundary
lentil.field.extent
lentil.field.insert
lentil.field.merge
lentil.field.overlap
lentil.field.reduce

Extent
------
.. autosummary::
:toctree: generated/

lentil.extent.array_extent
lentil.extent.intersect
lentil.extent.intersection_extent
lentil.extent.intersection_shape
lentil.extent.intersection_slices
lentil.extent.intersection_shift

Tilt interface
--------------
.. autosummary::
Expand All @@ -45,6 +56,6 @@ Helper functions
:toctree: generated/

lentil.helper.mesh
lentil.helper.gaussian2d
lentil.helper.boundary_slice
lentil.helper.slice_offset
lentil.helper.gaussian2d
101 changes: 101 additions & 0 deletions lentil/extent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Functions for working with extent (rmon, rmax, cmin, cmax) data

import numpy as np

def array_extent(shape, shift, parent_shape=None):
"""Compute the extent of a shifted array.
Parameters
----------
shape : (2,) array_like
shift : (2,) array_like
parent_shape : (2,) array like or None, optional
Notes
-----
To use the values returned by ``extent()`` in a slice,
``rmax`` and ``cmax`` should be increased by 1.
"""
if len(shape) < 2:
shape = (1, 1)

rmin = int(-(shape[0]//2) + shift[0])
cmin = int(-(shape[1]//2) + shift[1])
rmax = int(rmin + shape[0] - 1)
cmax = int(cmin + shape[1] - 1)

if parent_shape is not None:
parent_center = np.asarray(parent_shape)//2
rmin += parent_center[0]
rmax += parent_center[0]
cmin += parent_center[1]
cmax += parent_center[1]

return rmin, rmax, cmin, cmax


def intersect(a, b):
"""Return True if two extents intersect, otherwise False
Parameters
----------
a, b : (4,) array like
Two array extents (rmin, rmax, cmin, cmax)
Returns
-------
bool
"""

armin, armax, acmin, acmax = a
brmin, brmax, bcmin, bcmax = b
return armin <= brmax and armax >= brmin and acmin <= bcmax and acmax >= bcmin


def intersection_extent(a, b):
# bounding array indices to be multiplied
armin, armax, acmin, acmax = a
brmin, brmax, bcmin, bcmax = b

rmin, rmax = max(armin, brmin), min(armax, brmax)
cmin, cmax = max(acmin, bcmin), min(acmax, bcmax)

return rmin, rmax, cmin, cmax


def intersection_shape(a, b):
"""Compute the shape
"""

rmin, rmax, cmin, cmax = intersection_extent(a, b)
nr, nc = rmax - rmin + 1, cmax - cmin + 1

if nr < 0 or nc < 0:
shape = ()
else:
shape = (nr, nc)

return shape


def intersection_slices(a, b):
rmin, rmax, cmin, cmax = intersection_extent(a, b)

armin, armax, acmin, acmax = a
brmin, brmax, bcmin, bcmax = b

arow = slice(rmin-armin, rmax-armin+1)
acol = slice(cmin-acmin, cmax-acmin+1)
brow = slice(rmin-brmin, rmax-brmin+1)
bcol = slice(cmin-bcmin, cmax-bcmin+1)

return (arow, acol), (brow, bcol)


def intersection_shift(a, b):
rmin, rmax, cmin, cmax = intersection_extent(a, b)
nrow = rmax - rmin + 1
ncol = cmax - cmin + 1
return rmin + nrow//2, cmin + ncol//2
78 changes: 9 additions & 69 deletions lentil/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from itertools import combinations
import numpy as np

import lentil.extent

class Field:
"""
Two-dimensional discretely sampled complex field.
Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(self, data, pixelscale=None, offset=None, tilt=None):
self.tilt = tilt if tilt else []

#: tuple of ints : Extent of ``data``
self.extent = extent(self.shape, self.offset)
self.extent = lentil.extent.array_extent(self.shape, self.offset)

@property
def shape(self):
Expand Down Expand Up @@ -132,11 +134,11 @@ def _mul_array(self, other):
self_data, self_offset, other_data, other_offset = _mul_broadcast(
self.data, self.offset, other.data, other.offset
)
self_extent = extent(self_data.shape, self_offset)
other_extent = extent(other_data.shape, other_offset)
self_extent = lentil.extent.array_extent(self_data.shape, self_offset)
other_extent = lentil.extent.array_extent(other_data.shape, other_offset)

self_slice, other_slice = _mul_slices(self_extent, other_extent)
offset = _mul_offset(self_extent, other_extent)
self_slice, other_slice = lentil.extent.intersection_slices(self_extent, other_extent)
offset = lentil.extent.intersection_shift(self_extent, other_extent)
data = self_data[self_slice] * other_data[other_slice]

if data.size == 0:
Expand Down Expand Up @@ -226,25 +228,6 @@ def boundary(fields):

return rmin, rmax, cmin, cmax


def extent(shape, offset):
"""
Compute the extent of a shifted array.
Note: To use the values returned by ``extent()`` in a slice,
``rmax`` and ``cmax`` should be increased by 1.
"""
if len(shape) < 2:
shape = (1, 1)

rmin = int(-(shape[0]//2) + offset[0])
cmin = int(-(shape[1]//2) + offset[1])
rmax = int(rmin + shape[0] - 1)
cmax = int(cmin + shape[1] - 1)

return rmin, rmax, cmin, cmax


def insert(field, out, intensity=False, weight=1):
"""Insert a field into an array.
Expand Down Expand Up @@ -417,10 +400,8 @@ def overlap(fields):
overlap : bool
"""
#return _overlap(a.extent, b.extent)

if len(fields) == 2:
return _overlap(fields[0].extent, fields[1].extent)
return lentil.extent.intersect(fields[0].extent, fields[1].extent)
else:
fields = _reduce(fields)
if len(fields) > 1:
Expand Down Expand Up @@ -472,23 +453,14 @@ def _disjoint(fields):
Return fields as a disjoint set.
"""
for m, n in combinations(range(len(fields)), 2):
if _overlap(fields[m]['extent'], fields[n]['extent']):
if lentil.extent.intersect(fields[m]['extent'], fields[n]['extent']):
fields[m]['field'].extend(fields[n]['field'])
fields[m]['extent'] = boundary(fields[m]['field'])
fields.pop(n)
return _disjoint(fields)
return fields


def _overlap(a_extent, b_extent):
"""
Return True if two extents overlap, otherwise False
"""
armin, armax, acmin, acmax = a_extent
brmin, brmax, bcmin, bcmax = b_extent
return armin <= brmax and armax >= brmin and acmin <= bcmax and acmax >= bcmin


def _mul_broadcast(a_data, a_offset, b_data, b_offset):
"""
Broadcast for multiplication.
Expand All @@ -510,35 +482,3 @@ def _mul_broadcast(a_data, a_offset, b_data, b_offset):
b_data = np.broadcast_to(b_data, a_data.shape)
b_offset = a_offset
return a_data, a_offset, b_data, b_offset


def _mul_boundary(a_extent, b_extent):
# bounding array indices to be multiplied
armin, armax, acmin, acmax = a_extent
brmin, brmax, bcmin, bcmax = b_extent

rmin, rmax = max(armin, brmin), min(armax, brmax)
cmin, cmax = max(acmin, bcmin), min(acmax, bcmax)

return rmin, rmax, cmin, cmax


def _mul_slices(a_extent, b_extent):
rmin, rmax, cmin, cmax = _mul_boundary(a_extent, b_extent)

armin, armax, acmin, acmax = a_extent
brmin, brmax, bcmin, bcmax = b_extent

arow = slice(rmin-armin, rmax-armin+1)
acol = slice(cmin-acmin, cmax-acmin+1)
brow = slice(rmin-brmin, rmax-brmin+1)
bcol = slice(cmin-bcmin, cmax-bcmin+1)

return (arow, acol), (brow, bcol)


def _mul_offset(a_extent, b_extent):
rmin, rmax, cmin, cmax = _mul_boundary(a_extent, b_extent)
nrow = rmax - rmin + 1
ncol = cmax - cmin + 1
return rmin + nrow//2, cmin + ncol//2

0 comments on commit ea74fc4

Please sign in to comment.