Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed Dec 23, 2019
1 parent 8ef641e commit 574a460
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
12 changes: 12 additions & 0 deletions nutils/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ def edges(self):
def children(self):
return list(zip(self.child_transforms, self.child_refs))

def get_from_trans(self, trans):
if not trans:
return self
if trans[0].todims != self.ndims:
raise ValueError('Expected a transform chain that maps to {} dims but got {}.'.format(self.ndims, trans[0].todims))
if trans[0].fromdims == self.ndims:
return self.child_refs[self.child_transforms.index(trans[0])].get_from_trans(trans[1:])
elif trans[0].fromdims == self.ndims-1:
return self.edge_refs[self.edge_transforms.index(trans[0])].get_from_trans(trans[1:])
else:
raise ValueError

@property
def connectivity(self):
# Nested tuple with connectivity information about edges of children:
Expand Down
30 changes: 16 additions & 14 deletions nutils/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -2043,8 +2043,8 @@ def getitem(self, item):

class PartitionedTopology(Topology):

__slots__ = 'basetopo', 'refs', 'names', 'nparts', '_partreferences', '_partboundaries', '_partopposites'
__cache__ = 'boundary', 'interfaces'
__slots__ = 'basetopo', 'refs', 'names', 'nparts', '_partreferences', '_parttransforms', '_partopposites'
__cache__ = 'boundary'

@types.apply_annotations
def __init__(self, basetopo:stricttopology, refs:types.tuple[types.tuple[element.strictreference]], names:types.tuple[types.strictstr]):
Expand All @@ -2064,10 +2064,10 @@ def __init__(self, basetopo:stricttopology, refs:types.tuple[types.tuple[element
self.names = names

indices = tuple(types.frozenarray(numpy.where(list(map(bool, prefs)))[0]) for prefs in zip(*refs))
self._partreferences = tuple(elemenseq.asreferences((ref for ref in prefs if ref), self.basetopo.ndims) for prefs in zip(*refs))
self._parttransforms = tuple(WithIdentifierTransforms(self.basetopo.transforms[i], name) for name, i in zip(names, indices))
self._partopposites = tuple(WithIdentifierTransforms(self.basetopo.opposites[i], name) for name, i in zip(names, indices))
super().__init__(elemenseq.chain(self._partreferences, self.basetopo.ndims),
self._partreferences = tuple(elementseq.asreferences((ref for ref in prefs if ref), self.basetopo.ndims) for prefs in zip(*refs))
self._parttransforms = tuple(transformseq.WithIdentifierTransforms(self.basetopo.transforms[i], name) for name, i in zip(names, indices))
self._partopposites = tuple(transformseq.WithIdentifierTransforms(self.basetopo.opposites[i], name) for name, i in zip(names, indices))
super().__init__(elementseq.chain(self._partreferences, self.basetopo.ndims),
transformseq.chain(self._parttransforms, self.basetopo.ndims),
transformseq.chain(self._partopposites, self.basetopo.ndims))

Expand All @@ -2084,8 +2084,7 @@ def boundary(self):
brefs = []
for bref, btrans in zip(baseboundary.references, baseboundary.transforms):
ielem, etrans = self.basetopo.transforms.index_with_tail(btrans)
iedge = self.basetopo.references[ielem].edge_index(etrans)
brefs.append(pref.edge_refs[iedge] for pref in self.refs[ielem])
brefs.append(pref.get_from_trans(etrans) for pref in self.refs[ielem])
return PartitionedTopology(baseboundary, brefs, self.names)

# @property
Expand Down Expand Up @@ -2200,14 +2199,17 @@ class _SubsetOfPartitionedTopology(Topology):

__slots__ = '_partition', '_names'

def __init__(self, partition: stricttopology, names: types.frozenset(types.strictstr)):
@types.apply_annotations
def __init__(self, partition: stricttopology, names: frozenset):
self._partition = partition
if not names < partition.names:
raise ValueError('Not a (strict) subset of the partition.')
if not names <= frozenset(partition.names):
raise ValueError('Not a subset of the partition.')
if not all(isinstance(name, str) for name in names):
raise ValueError('All names should be str objects.')
self._names = tuple(sorted(names, key=partition.names.index))
super().__init__(elemenseq.chain((self._partition._partreferences[name] for name in self._names), partition.ndims),
transformseq.chain((self._partition._parttransforms[name] for name in self._names), partition.ndims),
transformseq.chain((self._partition._partopposites[name] for name in self._names), partition.ndims))
super().__init__(elementseq.chain((self._partition._partreferences[partition.names.index(name)] for name in self._names), partition.ndims),
transformseq.chain((self._partition._parttransforms[partition.names.index(name)] for name in self._names), partition.ndims),
transformseq.chain((self._partition._partopposites[partition.names.index(name)] for name in self._names), partition.ndims))

def boundary(self):
topos = []
Expand Down

0 comments on commit 574a460

Please sign in to comment.