Skip to content

Commit

Permalink
Fix to convert metrics from numpy array to ints or floats by calling …
Browse files Browse the repository at this point in the history
…item().

Without this, some runs may still break for the metric 'preemption_count'.

PiperOrigin-RevId: 664813014
  • Loading branch information
priyakasimbeg authored and copybara-github committed Aug 19, 2024
1 parent 689a815 commit 6ebdb02
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions init2winit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ def array_append(full_array, to_append):

def reduce_to_scalar(value):
"""Helper function to reduce an numpy array to a scalar by extracting the first element."""
if isinstance(value, np.ndarray):
if value.shape == (1,):
value = value[0]
while isinstance(value, np.ndarray) and value.ndim > 0:
value = value[0]
if isinstance(value, np.ndarray) and value.ndim == 0:
value = value.item()
return value


Expand Down Expand Up @@ -246,9 +247,14 @@ def append_scalar_metrics(self, metrics):
if name not in self._measurements:
self._measurements[name] = self._xm_work_unit.get_measurement_series(
label=name)
self._measurements[name].create_measurement(
objective_value=reduce_to_scalar(value), step=metrics['global_step']
)
try:
self._measurements[name].create_measurement(
objective_value=reduce_to_scalar(value),
step=metrics['global_step'],
)
except TypeError as e:
logging.info('Failed to create measurement for %s: %s', name, value)
raise e

if self._tb_metric_writer:
self._tb_metric_writer.write_scalars(
Expand Down

0 comments on commit 6ebdb02

Please sign in to comment.