From 17fe9c40f24a33cee8de78f2f1e5e32e2f36993a Mon Sep 17 00:00:00 2001
From: MediaPipe Team <mediapipe-team@google.com>
Date: Mon, 15 Apr 2024 16:10:47 -0700
Subject: [PATCH] Drop remainder from datasets in text_classifier. This helps
 deal with issues on TPU training that results in NaN loss.

PiperOrigin-RevId: 625114628
---
 mediapipe/model_maker/python/core/tasks/classifier.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py
index 786b9427b4..31c42fab00 100644
--- a/mediapipe/model_maker/python/core/tasks/classifier.py
+++ b/mediapipe/model_maker/python/core/tasks/classifier.py
@@ -81,6 +81,7 @@ def _train_model(
         is_training=True,
         shuffle=self._shuffle,
         preprocess=preprocessor,
+        drop_remainder=True,
     )
     if self._hparams.repeat and self._hparams.steps_per_epoch is None:
       raise ValueError(
@@ -96,7 +97,9 @@ def _train_model(
     validation_dataset = validation_data.gen_tf_dataset(
         batch_size=self._hparams.batch_size,
         is_training=False,
-        preprocess=preprocessor)
+        preprocess=preprocessor,
+        drop_remainder=True,
+    )
     self._model.compile(
         optimizer=self._optimizer,
         loss=self._loss_function,