diff --git a/init2winit/utils.py b/init2winit/utils.py index 568d579..bd9c3b5 100644 --- a/init2winit/utils.py +++ b/init2winit/utils.py @@ -73,9 +73,10 @@ def array_append(full_array, to_append): def reduce_to_scalar(value): """Helper function to reduce an numpy array to a scalar by extracting the first element.""" - if isinstance(value, np.ndarray): - if value.shape == (1,): - value = value[0] + while isinstance(value, np.ndarray) and value.ndim > 0: + value = value[0] + if isinstance(value, np.ndarray) and value.ndim == 0: + value = value.item() return value @@ -246,9 +247,14 @@ def append_scalar_metrics(self, metrics): if name not in self._measurements: self._measurements[name] = self._xm_work_unit.get_measurement_series( label=name) - self._measurements[name].create_measurement( - objective_value=reduce_to_scalar(value), step=metrics['global_step'] - ) + try: + self._measurements[name].create_measurement( + objective_value=reduce_to_scalar(value), + step=metrics['global_step'], + ) + except TypeError as e: + logging.info('Failed to create measurement for %s: %s', name, value) + raise e if self._tb_metric_writer: self._tb_metric_writer.write_scalars(