diff --git a/weaver/utils/data/config.py b/weaver/utils/data/config.py index bfb6a24a..e8966ab4 100644 --- a/weaver/utils/data/config.py +++ b/weaver/utils/data/config.py @@ -56,6 +56,11 @@ def __init__(self, print_info=True, **kwargs): if print_info: _logger.debug(opts) + self.train_load_branches = set() + self.train_aux_branches = set() + self.test_load_branches = set() + self.test_aux_branches = set() + self.selection = opts['selection'] self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection self.var_funcs = copy.deepcopy(opts['new_variables']) @@ -101,26 +106,27 @@ def _get(idx, default): assert (isinstance(self.label_value, list)) self.label_names = ('_label_',) label_exprs = ['ak.to_numpy(%s)' % k for k in self.label_value] - self.var_funcs['_label_'] = 'np.argmax(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs)) - self.var_funcs['_labelcheck_'] = 'np.sum(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs)) + self.register('_label_', 'np.argmax(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs))) + self.register('_labelcheck_', 'np.sum(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs)), 'train') else: self.label_names = tuple(self.label_value.keys()) - self.var_funcs.update(self.label_value) + self.register(self.label_value) self.basewgt_name = '_basewgt_' self.weight_name = None if opts['weights'] is not None: - self.weight_name = 'weight_' + self.weight_name = '_weight_' self.use_precomputed_weights = opts['weights']['use_precomputed_weights'] if self.use_precomputed_weights: - self.var_funcs[self.weight_name] = '*'.join(opts['weights']['weight_branches']) + self.register(self.weight_name, '*'.join(opts['weights']['weight_branches']), 'train') else: self.reweight_method = opts['weights']['reweight_method'] self.reweight_basewgt = opts['weights'].get('reweight_basewgt', None) if self.reweight_basewgt: - self.var_funcs[self.basewgt_name] = self.reweight_basewgt + self.register(self.basewgt_name, self.reweight_basewgt, 'train') self.reweight_branches = tuple(opts['weights']['reweight_vars'].keys()) self.reweight_bins = tuple(opts['weights']['reweight_vars'].values()) self.reweight_classes = tuple(opts['weights']['reweight_classes']) + self.register(self.reweight_branches + self.reweight_classes, to='train') self.class_weights = opts['weights'].get('class_weights', None) if self.class_weights is None: self.class_weights = np.ones(len(self.reweight_classes)) @@ -167,44 +173,56 @@ def _log(msg, *args, **kwargs): 'reweight_discard_under_overflow']: _log('%s: %s' % (k, getattr(self, k))) - # parse config - self.keep_branches = set() - aux_branches = set() # selection if self.selection: - aux_branches.update(_get_variable_names(self.selection)) + self.register(_get_variable_names(self.selection), to='train') # test time selection if self.test_time_selection: - aux_branches.update(_get_variable_names(self.test_time_selection)) - # var_funcs - self.keep_branches.update(self.var_funcs.keys()) - for expr in self.var_funcs.values(): - aux_branches.update(_get_variable_names(expr)) + self.register(_get_variable_names(self.test_time_selection), to='test') # inputs for names in self.input_dicts.values(): - self.keep_branches.update(names) - # labels - self.keep_branches.update(self.label_names) - # weight - if self.weight_name: - self.keep_branches.add(self.weight_name) - if not self.use_precomputed_weights: - aux_branches.update(self.reweight_branches) - aux_branches.update(self.reweight_classes) + self.register(names) # observers - self.keep_branches.update(self.observer_names) + self.register(self.observer_names, to='test') # monitor variables - self.keep_branches.update(self.monitor_variables) - # keep and drop - self.drop_branches = (aux_branches - self.keep_branches) - self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys()) - {self.weight_name, } + self.register(self.monitor_variables) + # resolve dependencies + func_vars = set(self.var_funcs.keys()) + for (load_branches, aux_branches) in (self.train_load_branches, self.train_aux_branches), (self.test_load_branches, self.test_aux_branches): + while (load_branches & func_vars): + for k in (load_branches & func_vars): + aux_branches.add(k) + load_branches.remove(k) + load_branches.update(_get_variable_names(self.var_funcs[k])) if print_info: - _logger.debug('drop_branches:\n %s', ','.join(self.drop_branches)) - _logger.debug('load_branches:\n %s', ','.join(self.load_branches)) + _logger.debug('train_load_branches:\n %s', ', '.join(sorted(self.train_load_branches))) + _logger.debug('train_aux_branches:\n %s', ', '.join(sorted(self.train_aux_branches))) + _logger.debug('test_load_branches:\n %s', ', '.join(sorted(self.test_load_branches))) + _logger.debug('test_aux_branches:\n %s', ', '.join(sorted(self.test_aux_branches))) def __getattr__(self, name): return self.options[name] + def register(self, name, expr=None, to='both'): + assert to in ('train', 'test', 'both') + if isinstance(name, dict): + for k, v in name.items(): + self.register(k, v, to) + elif isinstance(name, (list, tuple)): + for k in name: + self.register(k, None, to) + else: + if to in ('train', 'both'): + self.train_load_branches.add(name) + if to in ('test', 'both'): + self.test_load_branches.add(name) + if expr: + self.var_funcs[name] = expr + if to in ('train', 'both'): + self.train_aux_branches.add(name) + if to in ('test', 'both'): + self.test_aux_branches.add(name) + def dump(self, fp): with open(fp, 'w') as f: yaml.safe_dump(self.options, f, sort_keys=False) diff --git a/weaver/utils/data/preprocess.py b/weaver/utils/data/preprocess.py index c2d8efed..a0a86a60 100644 --- a/weaver/utils/data/preprocess.py +++ b/weaver/utils/data/preprocess.py @@ -9,11 +9,12 @@ from .fileio import _read_files -def _apply_selection(table, selection, funcs={}): +def _apply_selection(table, selection, funcs=None): if selection is None: return table - new_vars = {k: funcs[k] for k in _get_variable_names(selection) if k not in table.fields and k in funcs} - _build_new_variables(table, new_vars) + if funcs: + new_vars = {k: funcs[k] for k in _get_variable_names(selection) if k not in table.fields and k in funcs} + _build_new_variables(table, new_vars) selected = ak.values_astype(_eval_expr(selection, table), 'bool') return table[selected] @@ -28,11 +29,6 @@ def _build_new_variables(table, funcs): return table -def _clean_up(table, drop_branches): - columns = [k for k in table.fields if k not in drop_branches] - return table[columns] - - def _build_weights(table, data_config, reweight_hists=None): if data_config.weight_name is None: raise RuntimeError('Error when building weights: `weight_name` is None!') @@ -92,27 +88,33 @@ def __init__(self, filelist, data_config): self.load_range = (0, data_config.preprocess.get('data_fraction', 0.1)) def read_file(self, filelist): - self.keep_branches = set() - self.load_branches = set() + keep_branches = set() + aux_branches = set() + load_branches = set() for k, params in self._data_config.preprocess_params.items(): if params['center'] == 'auto': - self.keep_branches.add(k) - if k in self._data_config.var_funcs: - expr = self._data_config.var_funcs[k] - self.load_branches.update(_get_variable_names(expr)) - else: - self.load_branches.add(k) + keep_branches.add(k) + load_branches.add(k) if self._data_config.selection: - self.load_branches.update(_get_variable_names(self._data_config.selection)) - _logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(self.keep_branches)) - _logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(self.load_branches)) - table = _read_files(filelist, self.load_branches, self.load_range, show_progressbar=True, + load_branches.update(_get_variable_names(self._data_config.selection)) + + func_vars = set(self._data_config.var_funcs.keys()) + while (load_branches & func_vars): + for k in (load_branches & func_vars): + aux_branches.add(k) + load_branches.remove(k) + load_branches.update(_get_variable_names(self._data_config.var_funcs[k])) + + _logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(keep_branches)) + _logger.debug('[AutoStandardizer] aux_branches:\n %s', ','.join(aux_branches)) + _logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(load_branches)) + + table = _read_files(filelist, load_branches, self.load_range, show_progressbar=True, treename=self._data_config.treename, branch_magic=self._data_config.branch_magic, file_magic=self._data_config.file_magic) table = _apply_selection(table, self._data_config.selection, funcs=self._data_config.var_funcs) - table = _build_new_variables( - table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches}) - table = _clean_up(table, self.load_branches - self.keep_branches) + table = _build_new_variables(table, {k: v for k, v in self._data_config.var_funcs.items() if k in aux_branches}) + table = table[keep_branches] return table def make_preprocess_params(self, table): @@ -142,7 +144,7 @@ def produce(self, output=None): table = self.read_file(self._filelist) preprocess_params = self.make_preprocess_params(table) self._data_config.preprocess_params = preprocess_params - # must also propogate the changes to `data_config.options` so it can be persisted + # must also propagate the changes to `data_config.options` so it can be persisted self._data_config.options['preprocess']['params'] = preprocess_params if output: _logger.info( @@ -168,26 +170,31 @@ def __init__(self, filelist, data_config): self._data_config = data_config.copy() def read_file(self, filelist): - self.keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes + - (self._data_config.basewgt_name,)) - self.load_branches = set() - for k in self.keep_branches: - if k in self._data_config.var_funcs: - expr = self._data_config.var_funcs[k] - self.load_branches.update(_get_variable_names(expr)) - else: - self.load_branches.add(k) + keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes) + if self._data_config.reweight_basewgt: + keep_branches.add(self._data_config.basewgt_name) + aux_branches = set() + load_branches = keep_branches.copy() if self._data_config.selection: - self.load_branches.update(_get_variable_names(self._data_config.selection)) - _logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(self.keep_branches)) - _logger.debug('[WeightMaker] load_branches:\n %s', ','.join(self.load_branches)) - table = _read_files(filelist, self.load_branches, show_progressbar=True, + load_branches.update(_get_variable_names(self._data_config.selection)) + + func_vars = set(self._data_config.var_funcs.keys()) + while (load_branches & func_vars): + for k in (load_branches & func_vars): + aux_branches.add(k) + load_branches.remove(k) + load_branches.update(_get_variable_names(self._data_config.var_funcs[k])) + + _logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(keep_branches)) + _logger.debug('[WeightMaker] aux_branches:\n %s', ','.join(aux_branches)) + _logger.debug('[WeightMaker] load_branches:\n %s', ','.join(load_branches)) + + table = _read_files(filelist, load_branches, show_progressbar=True, treename=self._data_config.treename, branch_magic=self._data_config.branch_magic, file_magic=self._data_config.file_magic) table = _apply_selection(table, self._data_config.selection, funcs=self._data_config.var_funcs) - table = _build_new_variables( - table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches}) - table = _clean_up(table, self.load_branches - self.keep_branches) + table = _build_new_variables(table, {k: v for k, v in self._data_config.var_funcs.items() if k in aux_branches}) + table = table[keep_branches] return table def make_weights(self, table): @@ -284,7 +291,7 @@ def produce(self, output=None): table = self.read_file(self._filelist) wgts = self.make_weights(table) self._data_config.reweight_hists = wgts - # must also propogate the changes to `data_config.options` so it can be persisted + # must also propagate the changes to `data_config.options` so it can be persisted self._data_config.options['weights']['reweight_hists'] = {k: v.tolist() for k, v in wgts.items()} if output: _logger.info('Writing YAML file w/ reweighting info to %s' % output) diff --git a/weaver/utils/dataset.py b/weaver/utils/dataset.py index 5d48fb46..7fbdf5f1 100644 --- a/weaver/utils/dataset.py +++ b/weaver/utils/dataset.py @@ -88,7 +88,8 @@ def _preprocess(table, data_config, options): if len(table) == 0: return [] # define new variables - table = _build_new_variables(table, data_config.var_funcs) + aux_branches = data_config.train_aux_branches if options['training'] else data_config.test_aux_branches + table = _build_new_variables(table, {k: v for k, v in data_config.var_funcs.items() if k in aux_branches}) # check labels if data_config.label_type == 'simple' and options['training']: _check_labels(table) @@ -108,7 +109,8 @@ def _preprocess(table, data_config, options): def _load_next(data_config, filelist, load_range, options): - table = _read_files(filelist, data_config.load_branches, load_range, treename=data_config.treename, + load_branches = data_config.train_load_branches if options['training'] else data_config.test_load_branches + table = _read_files(filelist, load_branches, load_range, treename=data_config.treename, branch_magic=data_config.branch_magic, file_magic=data_config.file_magic) table, indices = _preprocess(table, data_config, options) return table, indices