diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index d3ad74ba2..7643676de 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -16,6 +16,10 @@ def s3_test(): from torchdata._torchdata import S3Handler +def stateful_dataloader_test(): + from torchdata.stateful_dataloader import StatefulDataLoader + + if __name__ == "__main__": r""" TorchData Smoke Test @@ -26,3 +30,7 @@ def s3_test(): options = parser.parse_args() if options.s3: s3_test() + + if not torchdata.__version__.startswith("0.8.0"): + raise Exception(f"TorchData version is not 0.8.0, found {torchdata.__version__}") + stateful_dataloader_test()