diff --git a/src/lumo/data/builder.py b/src/lumo/data/builder.py index 066e01a..63de7f7 100644 --- a/src/lumo/data/builder.py +++ b/src/lumo/data/builder.py @@ -27,16 +27,6 @@ def __init__(self): self._iter_cache = {} - def copy(self): - db = DatasetBuilder() - db._prop = copy.copy(self._prop) - db._idx_keys = copy.copy(self._idx_keys) - db._data = copy.copy(self._data) - db._outs = copy.deepcopy(self._outs) - db._transforms = copy.copy(self._transforms) - db._outkeys = copy.copy(self._outkeys) - return db - def __repr__(self): if self.sized: @@ -129,6 +119,23 @@ def _update_len(self): self._prop['__clen__'] = res return res + @property + def inputs(self): + return self._data + + @property + def outputs(self): + mapping = {} + for key, outkeys in self._outs.items(): + if key == '::idx::': + source = range(len(self)) + else: + source = self._data[key] + + for outkey in outkeys: + mapping[outkey] = source + return mapping + @property def mode(self): return self._prop.get('mode', 'zip') @@ -149,6 +156,16 @@ def pseudo_length(self) -> int: def pseudo_repeat(self) -> int: return self._prop.get('pseudo_repeat', None) + def copy(self): + db = DatasetBuilder() + db._prop = copy.copy(self._prop) + db._idx_keys = copy.copy(self._idx_keys) + db._data = copy.copy(self._data) + db._outs = copy.deepcopy(self._outs) + db._transforms = copy.copy(self._transforms) + db._outkeys = copy.copy(self._outkeys) + return db + def subset(self, indices: Sequence[int]): self._prop['subindices'] = np.array(indices) self._update_len()