diff --git a/splipy/SplineModel.py b/splipy/SplineModel.py index 6a723885..943e1ca1 100644 --- a/splipy/SplineModel.py +++ b/splipy/SplineModel.py @@ -4,7 +4,7 @@ from splipy.utils import * import splipy.state as state import numpy as np -from collections import Counter +from collections import Counter, defaultdict from itertools import chain, product, permutations try: @@ -114,7 +114,7 @@ def __init__(self, perm, flip): self.perm_inv = tuple(perm.index(d) for d in range(len(perm))) @classmethod - def compute(cls, cpa, cpb=None): + def compute(cls, cpa, cpb=None, interior=True): """Compute and return a new orientation object representing the mapping between `cpa` (the reference system) and `cpb` (the mapped system). @@ -136,10 +136,23 @@ def compute(cls, cpa, cpb=None): shape_b = cpb.controlpoints.shape - # Deal with the easy cases: dimension mismatch, and - # comparing the shapes as multisets + # Easy error checking if len(shape_a) != len(shape_b): raise OrientationError("Mismatching parametric dimensions") + + cpsa = cpa.controlpoints + cpsb = cpb.controlpoints + if not interior: + # Pick out just the corners + indexes = np.ix_(*[[0,-1] for _ in range(pardim)]) + indexes = list(indexes) + [slice(None)] + cpsa = cpsa[indexes] + cpsb = cpsb[indexes] + shape_a = cpsa.shape + shape_b = cpsb.shape + + # Deal with the rest of the easy cases: dimension mismatch, and + # comparing the shapes as multisets if shape_a[-1] != shape_b[-1]: raise OrientationError("Mismatching physical dimensions") if Counter(shape_a) != Counter(shape_b): @@ -147,17 +160,22 @@ def compute(cls, cpa, cpb=None): # Enumerate all permutations of directions for perm in permutations(range(pardim)): - transposed = cpb.controlpoints.transpose(perm + (pardim,)) + transposed = cpsb.transpose(perm + (pardim,)) if transposed.shape != shape_a: continue # Enumerate all possible direction reversals for flip in product([False, True], repeat=pardim): slices = tuple(slice(None, None, -1) if f else slice(None) for f in flip) test_b = transposed[slices + (slice(None),)] - if np.allclose(cpa.controlpoints, test_b, + if np.allclose(cpsa, test_b, rtol=state.controlpoint_relative_tolerance, atol=state.controlpoint_absolute_tolerance): - if all([cpa.bases[i].matches(cpb.bases[perm[i]], reverse=flip[i]) for i in range(pardim)]): + ok = ( + not interior or + all([cpa.bases[i].matches(cpb.bases[perm[i]], reverse=flip[i]) + for i in range(pardim)]) + ) + if ok: return cls(perm, flip) raise OrientationError("Non-matching objects") @@ -254,16 +272,19 @@ class TopologicalNode(object): of any kind. """ - def __init__(self, obj, lower_nodes): + def __init__(self, catalogue, obj, lower_nodes): """Initialize a `TopologicalNode` object associated with the given `SplineObject` and lower order nodes. + :param ObjectCatalogue catalogue: The catalogue to which this node + belongs :param SplineObject obj: The underlying spline object :param lower_nodes: A nested list of lower order nodes """ + self.catalogue = catalogue self.obj = obj self.lower_nodes = lower_nodes - self.higher_nodes = {} + self.higher_nodes = defaultdict(list) for dim_nodes in self.lower_nodes: for node in dim_nodes: @@ -275,7 +296,7 @@ def pardim(self): def assign_higher(self, node): """Add a link to a node of higher dimension.""" - self.higher_nodes.setdefault(node.pardim, set()).add(node) + self.higher_nodes[node.pardim].append(node) def view(self, other_obj=None): """Return a `NodeView` object of this node. @@ -288,9 +309,9 @@ def view(self, other_obj=None): underlying object """ if other_obj: - orientation = Orientation.compute(self.obj, other_obj) + orientation = self.catalogue.make_orientation(self.obj, other_obj) else: - orientation = Orientation.compute(self.obj) + orientation = self.catalogue.make_orientation(self.obj) return NodeView(self, orientation) @@ -334,7 +355,9 @@ def section(self, *args, **kwargs): # The underlying lower-order node may not have an orientation that # matches the higher-order node, so we need to compose two orientations - ref_ori = Orientation.compute(node.obj, self.node.obj.section(*section, unwrap_points=False)) + ref_ori = self.node.catalogue.make_orientation( + node.obj, self.node.obj.section(*section, unwrap_points=False) + ) my_ori = self.orientation.view_section(section) return NodeView(node, ref_ori * my_ori) @@ -372,19 +395,24 @@ class ObjectCatalogue(object): at most `pardim` parametric directions. """ - def __init__(self, pardim): + def __init__(self, pardim, interior=True): """Initialize a catalogue for objects of parametric dimension `pardim`. """ self.pardim = pardim # Internal mapping from tuples of lower-order nodes to lists of nodes - self.internal = {} + self.internal = defaultdict(list) + + # Function for computing orientations + self.make_orientation = ( + lambda *args, **kwargs: Orientation.compute(*args, interior=interior, **kwargs) + ) # Each catalogue has a catalogue of lower dimension # For points, we use a VertexDict if pardim > 0: - self.lower = ObjectCatalogue(pardim - 1) + self.lower = ObjectCatalogue(pardim - 1, interior=interior) else: self.lower = VertexDict() @@ -408,7 +436,7 @@ def lookup(self, obj, add=False): # Special case for points: self.lower is a mapping from array to node if self.pardim == 0: if add: - node = TopologicalNode(obj, []) + node = TopologicalNode(self, obj, []) return self.lower.setdefault(obj.controlpoints, node).view() return self.lower[obj.controlpoints].view() @@ -430,19 +458,24 @@ def lookup(self, obj, add=False): # identity view on it. try: for candidate_node in self.internal[lower_nodes[-1]]: - return candidate_node.view(obj) - # FIXME: It might be useful to optionally not silence OrientationError, - # since that more often than not indicates a real error - except (KeyError, OrientationError): - if not add: - raise KeyError("No such object found") - node = TopologicalNode(obj, lower_nodes) - # Assign the new node to each possible permutation of lower-order - # nodes. This is slight overkill since some of these permutations - # are invalid, but c'est la vie. - for p in permutations(lower_nodes[-1]): - self.internal.setdefault(p, []).append(node) - return node.view() + try: + return candidate_node.view(obj) + # FIXME: It might be useful to optionally not silence OrientationError, + # since that more often than not indicates a real error + except OrientationError: + pass + except KeyError: + pass + + if not add: + raise KeyError("No such object found") + node = TopologicalNode(self, obj, lower_nodes) + # Assign the new node to each possible permutation of lower-order + # nodes. This is slight overkill since some of these permutations + # are invalid, but c'est la vie. + for p in permutations(lower_nodes[-1]): + self.internal[p].append(node) + return node.view() def add(self, obj): """Add new nodes to the graph to accommodate the given object, then return the @@ -480,11 +513,11 @@ def nodes(self, pardim): class SplineModel(object): - def __init__(self, pardim=3, dimension=3, objs=[]): + def __init__(self, pardim=3, dimension=3, objs=[], interior=True): self.pardim = pardim self.dimension = dimension - self.catalogue = ObjectCatalogue(pardim) + self.catalogue = ObjectCatalogue(pardim, interior=interior) self.add_patches(objs) def add_patch(self, obj):