diff --git a/torchdata/datapipes/iter/transform/bucketbatcher.py b/torchdata/datapipes/iter/transform/bucketbatcher.py
index 867068627..fb2a7f617 100644
--- a/torchdata/datapipes/iter/transform/bucketbatcher.py
+++ b/torchdata/datapipes/iter/transform/bucketbatcher.py
@@ -9,7 +9,7 @@
 
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Callable, Generic, Iterator, List, Optional, TypeVar
+from typing import Callable, final, Generic, Iterator, List, Optional, TypeVar
 
 import torch
 
@@ -63,6 +63,7 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]:
                 new_batch = self._rng.sample(batch, len(batch))
                 yield DataChunk(new_batch)
 
+    @final
     def reset(self) -> None:
         if self._enabled:
             if self._seed is None:
diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index bd9d70769..73a959b37 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -7,7 +7,7 @@
 import warnings
 
 from collections import OrderedDict
-from typing import Callable, Iterator, List, Optional, Sequence, TypeVar
+from typing import Callable, final, Iterator, List, Optional, Sequence, TypeVar
 
 from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
 from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _DemultiplexerIterDataPipe, _ForkerIterDataPipe
@@ -125,6 +125,7 @@ def __iter__(self) -> Iterator:
     def __len__(self) -> int:
         return len(self.source_datapipe)
 
+    @final
     def reset(self) -> None:
         self.buffer = OrderedDict()
 
diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py
index 2e6fe813c..ebccf59be 100644
--- a/torchdata/datapipes/iter/util/distributed.py
+++ b/torchdata/datapipes/iter/util/distributed.py
@@ -10,7 +10,7 @@
 from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
 from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Deque, Iterator, Optional, TypeVar
+from typing import Callable, Deque, final, Iterator, Optional, TypeVar
 
 import torch
 import torch.distributed as dist
@@ -197,6 +197,7 @@ def __iter__(self) -> Iterator[T_co]:
                 break
             yield data
 
+    @final
     def reset(self):
         if self._executor is not None:
             self._executor.shutdown()
@@ -236,6 +237,7 @@ def pause(self):
         # self._executor.shutdown()
         # self._executor = None
 
+    @final
     def resume(self):
         raise RuntimeError("`resume` is not supported for FullSync at the moment.")
         # self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
diff --git a/torchdata/datapipes/iter/util/paragraphaggregator.py b/torchdata/datapipes/iter/util/paragraphaggregator.py
index cb4e9839d..f2ec7bacd 100644
--- a/torchdata/datapipes/iter/util/paragraphaggregator.py
+++ b/torchdata/datapipes/iter/util/paragraphaggregator.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Iterator, List, Tuple, TypeVar
+from typing import Callable, final, Iterator, List, Tuple, TypeVar
 
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
@@ -66,6 +66,7 @@ def __iter__(self) -> Iterator[Tuple[str, str]]:
         if self.buffer:
             yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
 
+    @final
     def reset(self) -> None:
         self.buffer = []
 
diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py
index 0fb1ff8fe..92da91956 100644
--- a/torchdata/datapipes/iter/util/prefetcher.py
+++ b/torchdata/datapipes/iter/util/prefetcher.py
@@ -8,7 +8,7 @@
 import time
 
 from collections import deque
-from typing import Deque, Optional
+from typing import Deque, final, Optional
 
 import torch
 
@@ -125,6 +125,7 @@ def __setstate__(self, state):
         self.buffer_size = state["buffer_size"]
         self.thread = None
 
+    @final
     def reset(self):
         if self.thread is not None:
             self.prefetch_data.run_prefetcher = False
@@ -142,6 +143,7 @@ def pause(self):
             while not self.prefetch_data.paused:
                 time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
 
+    @final
     def resume(self):
         if self.thread is not None and (
             not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0
diff --git a/torchdata/datapipes/iter/util/randomsplitter.py b/torchdata/datapipes/iter/util/randomsplitter.py
index 27732314f..2608f93c8 100644
--- a/torchdata/datapipes/iter/util/randomsplitter.py
+++ b/torchdata/datapipes/iter/util/randomsplitter.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import random
-from typing import Dict, List, Optional, TypeVar, Union
+from typing import Dict, final, List, Optional, TypeVar, Union
 
 from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
@@ -119,6 +119,7 @@ def normalize_weights(weights: List[float], total_length: int) -> List[float]:
         total_weight = sum(weights)
         return [float(w) * total_length / total_weight for w in weights]
 
+    @final
     def reset(self) -> None:
         self._rng = random.Random(self._seed)
         self.weights = self.norm_weights.copy()
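For context on the effect of these changes, here is a minimal sketch (not part of this PR) of what decorating `reset` with `typing.final` buys: it is a static-analysis marker, so runtime behavior is unchanged, but a type checker such as mypy will flag any subclass that overrides a `@final` method. The `_Base` and `_Derived` class names below are hypothetical, chosen only for illustration.

from typing import final


class _Base:
    @final
    def reset(self) -> None:
        # Marked @final: subclasses are not expected to override this method.
        print("base reset")


class _Derived(_Base):
    # Running mypy on this file reports an error here, because "reset" is
    # declared final in _Base. Python itself does not prevent the override.
    def reset(self) -> None:
        print("derived reset")


if __name__ == "__main__":
    # At runtime the override still takes effect; only static checking is affected.
    _Derived().reset()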