From 17fe9c40f24a33cee8de78f2f1e5e32e2f36993a Mon Sep 17 00:00:00 2001 From: MediaPipe Team <mediapipe-team@google.com> Date: Mon, 15 Apr 2024 16:10:47 -0700 Subject: [PATCH] Drop remainder from datasets in text_classifier. This helps deal with issues on TPU training that results in NaN loss. PiperOrigin-RevId: 625114628 --- mediapipe/model_maker/python/core/tasks/classifier.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 786b9427b4..31c42fab00 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -81,6 +81,7 @@ def _train_model( is_training=True, shuffle=self._shuffle, preprocess=preprocessor, + drop_remainder=True, ) if self._hparams.repeat and self._hparams.steps_per_epoch is None: raise ValueError( @@ -96,7 +97,9 @@ def _train_model( validation_dataset = validation_data.gen_tf_dataset( batch_size=self._hparams.batch_size, is_training=False, - preprocess=preprocessor) + preprocess=preprocessor, + drop_remainder=True, + ) self._model.compile( optimizer=self._optimizer, loss=self._loss_function,