Skip to content

Commit

Permalink
Drop remainder from datasets in text_classifier. This helps deal with…
Browse files Browse the repository at this point in the history
… issues on TPU training that results in NaN loss.

PiperOrigin-RevId: 625114628
  • Loading branch information
MediaPipe Team authored and copybara-github committed Apr 15, 2024
1 parent 2767d84 commit 17fe9c4
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mediapipe/model_maker/python/core/tasks/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def _train_model(
is_training=True,
shuffle=self._shuffle,
preprocess=preprocessor,
drop_remainder=True,
)
if self._hparams.repeat and self._hparams.steps_per_epoch is None:
raise ValueError(
Expand All @@ -96,7 +97,9 @@ def _train_model(
validation_dataset = validation_data.gen_tf_dataset(
batch_size=self._hparams.batch_size,
is_training=False,
preprocess=preprocessor)
preprocess=preprocessor,
drop_remainder=True,
)
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
Expand Down

0 comments on commit 17fe9c4

Please sign in to comment.