@staticmethod
def check_dataset_cfg(config):
    """Reject dataset configs that enable packing.

    Args:
        config: Dataset configuration object. It may or may not expose a
            ``pack`` attribute; a config without one is accepted as-is.

    Raises:
        AssertionError: If ``config.pack`` exists and is truthy.
    """
    # A missing `pack` attribute is treated the same as packing disabled.
    packing_enabled = getattr(config, 'pack', False)

    # Raise explicitly rather than via `assert`: assert statements are
    # stripped under `python -O`, which would silently skip this check.
    # AssertionError is kept so existing callers see the same exception.
    if packing_enabled:
        raise AssertionError(
            "Transformer Engine does not support dataset.packing, please turn it off."
        )