Prevent overwrite (#1076)
Summary:
Fixes #886

### Changes
Add `final` decorators to the `reset` and `resume` methods in datapipes. Mypy will flag any subclass that overrides these methods (a type-check-time hint only; nothing is enforced at runtime).
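
As an illustration (not part of this diff), a minimal sketch of the pattern with hypothetical class names; `typing.final` only affects static checkers such as mypy:

```python
from typing import final

class BaseDataPipe:  # hypothetical base class, for illustration only
    @final
    def reset(self) -> None:
        # Subclasses are expected to inherit this behavior unchanged.
        self._buffer = []

class CustomDataPipe(BaseDataPipe):
    # mypy reports: error: Cannot override final attribute "reset"
    #               (previously declared in base class "BaseDataPipe")  [misc]
    def reset(self) -> None:
        pass
```

Python itself raises nothing at runtime here; the protection comes entirely from running mypy over the code.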

CLA is in progress.

In collaboration with thomasdick

Pull Request resolved: #1076

Reviewed By: ejguan

Differential Revision: D44109506

Pulled By: NivekT

fbshipit-source-id: f215a6e473a44c527dc53c9ee6a026b7126b9825
Carlos Schmidt Muniz authored and facebook-github-bot committed Mar 27, 2023
1 parent e78ab6c commit aeda987
Showing 6 changed files with 14 additions and 6 deletions.
3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/transform/bucketbatcher.py
@@ -9,7 +9,7 @@
 
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Callable, Generic, Iterator, List, Optional, TypeVar
+from typing import Callable, final, Generic, Iterator, List, Optional, TypeVar
 
 import torch
 
@@ -63,6 +63,7 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]:
             new_batch = self._rng.sample(batch, len(batch))
             yield DataChunk(new_batch)
 
+    @final
     def reset(self) -> None:
         if self._enabled:
             if self._seed is None:
3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/util/combining.py
@@ -7,7 +7,7 @@
 import warnings
 
 from collections import OrderedDict
-from typing import Callable, Iterator, List, Optional, Sequence, TypeVar
+from typing import Callable, final, Iterator, List, Optional, Sequence, TypeVar
 
 from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
 from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _DemultiplexerIterDataPipe, _ForkerIterDataPipe
@@ -125,6 +125,7 @@ def __iter__(self) -> Iterator:
     def __len__(self) -> int:
         return len(self.source_datapipe)
 
+    @final
     def reset(self) -> None:
         self.buffer = OrderedDict()
 
4 changes: 3 additions & 1 deletion torchdata/datapipes/iter/util/distributed.py
@@ -10,7 +10,7 @@
 from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
 from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Deque, Iterator, Optional, TypeVar
+from typing import Callable, Deque, final, Iterator, Optional, TypeVar
 
 import torch
 import torch.distributed as dist
@@ -197,6 +197,7 @@ def __iter__(self) -> Iterator[T_co]:
                 break
             yield data
 
+    @final
     def reset(self):
         if self._executor is not None:
             self._executor.shutdown()
@@ -236,6 +237,7 @@ def pause(self):
         # self._executor.shutdown()
         # self._executor = None
 
+    @final
     def resume(self):
         raise RuntimeError("`resume` is not supported for FullSync at the moment.")
         # self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/util/paragraphaggregator.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Iterator, List, Tuple, TypeVar
+from typing import Callable, final, Iterator, List, Tuple, TypeVar
 
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
 
@@ -66,6 +66,7 @@ def __iter__(self) -> Iterator[Tuple[str, str]]:
         if self.buffer:
             yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
 
+    @final
     def reset(self) -> None:
         self.buffer = []
 
4 changes: 3 additions & 1 deletion torchdata/datapipes/iter/util/prefetcher.py
@@ -8,7 +8,7 @@
 import time
 
 from collections import deque
-from typing import Deque, Optional
+from typing import Deque, final, Optional
 
 import torch
 
@@ -125,6 +125,7 @@ def __setstate__(self, state):
         self.buffer_size = state["buffer_size"]
         self.thread = None
 
+    @final
     def reset(self):
         if self.thread is not None:
             self.prefetch_data.run_prefetcher = False
@@ -142,6 +143,7 @@ def pause(self):
         while not self.prefetch_data.paused:
             time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
 
+    @final
     def resume(self):
         if self.thread is not None and (
             not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0
3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/util/randomsplitter.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import random
-from typing import Dict, List, Optional, TypeVar, Union
+from typing import Dict, final, List, Optional, TypeVar, Union
 
 from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
@@ -119,6 +119,7 @@ def normalize_weights(weights: List[float], total_length: int) -> List[float]:
         total_weight = sum(weights)
         return [float(w) * total_length / total_weight for w in weights]
 
+    @final
     def reset(self) -> None:
         self._rng = random.Random(self._seed)
         self.weights = self.norm_weights.copy()
