diff --git a/tests/test_nested_tensor.py b/tests/test_nested_tensor.py index c55171397..22e0ae267 100644 --- a/tests/test_nested_tensor.py +++ b/tests/test_nested_tensor.py @@ -7,7 +7,7 @@ import unittest import torch -from vissl.data.collators.collator_helper import MultiDimensionalTensor +from vissl.utils.multi_dimensional_tensor import MultiDimensionalTensor logger = logging.getLogger("__name__") diff --git a/vissl/data/collators/multicrop_collator.py b/vissl/data/collators/multicrop_collator.py index 40efb3d94..f5d8571af 100644 --- a/vissl/data/collators/multicrop_collator.py +++ b/vissl/data/collators/multicrop_collator.py @@ -5,7 +5,7 @@ import torch from vissl.data.collators import register_collator -from vissl.data.collators.collator_helper import MultiDimensionalTensor +from vissl.utils.multi_dimensional_tensor import MultiDimensionalTensor @register_collator("multicrop_collator") diff --git a/vissl/models/base_ssl_model.py b/vissl/models/base_ssl_model.py index 02e73ad74..6432420c7 100644 --- a/vissl/models/base_ssl_model.py +++ b/vissl/models/base_ssl_model.py @@ -12,7 +12,6 @@ from classy_vision.models import ClassyModel, register_model from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP from vissl.config import AttrDict -from vissl.data.collators.collator_helper import MultiDimensionalTensor from vissl.models.heads import SwAVPrototypesHead, get_model_head from vissl.models.model_helpers import ( get_trunk_output_feature_names, @@ -27,6 +26,7 @@ from vissl.utils.env import get_machine_local_and_dist_rank from vissl.utils.fsdp_utils import fsdp_recursive_reset_lazy_init from vissl.utils.misc import set_torch_seed +from vissl.utils.multi_dimensional_tensor import MultiDimensionalTensor @register_model("multi_input_output_model") diff --git a/vissl/models/model_helpers.py b/vissl/models/model_helpers.py index e595bea8a..5f94d83a4 100644 --- a/vissl/models/model_helpers.py +++ b/vissl/models/model_helpers.py @@ -14,9 +14,9 @@ import torch.nn.functional as F from torch.nn.modules.utils import _ntuple from torch.utils.checkpoint import checkpoint -from vissl.data.collators.collator_helper import MultiDimensionalTensor from vissl.utils.activation_checkpointing import checkpoint_trunk from vissl.utils.misc import is_apex_available +from vissl.utils.multi_dimensional_tensor import MultiDimensionalTensor # Tuple of classes of BN layers. diff --git a/vissl/models/trunks/regnet.py b/vissl/models/trunks/regnet.py index a58da0817..1d009cf6c 100644 --- a/vissl/models/trunks/regnet.py +++ b/vissl/models/trunks/regnet.py @@ -10,7 +10,6 @@ import torch.nn as nn from classy_vision.models import RegNet as ClassyRegNet, build_model from vissl.config import AttrDict -from vissl.data.collators.collator_helper import MultiDimensionalTensor from vissl.models.model_helpers import ( Flatten, get_trunk_forward_outputs, @@ -18,6 +17,7 @@ transform_model_input_data_type, ) from vissl.models.trunks import register_model_trunk +from vissl.utils.multi_dimensional_tensor import MultiDimensionalTensor @register_model_trunk("regnet") diff --git a/vissl/models/trunks/regnet_fsdp.py b/vissl/models/trunks/regnet_fsdp.py index 222dc3922..a70e8d424 100644 --- a/vissl/models/trunks/regnet_fsdp.py +++ b/vissl/models/trunks/regnet_fsdp.py @@ -39,7 +39,6 @@ from classy_vision.models.regnet import RegNetParams from fairscale.nn import checkpoint_wrapper from vissl.config import AttrDict -from vissl.data.collators.collator_helper import MultiDimensionalTensor from vissl.models.model_helpers import ( Flatten, get_trunk_forward_outputs, @@ -49,6 +48,7 @@ from vissl.models.trunks import register_model_trunk from vissl.utils.fsdp_utils import auto_wrap_big_layers, fsdp_auto_wrap_bn, fsdp_wrapper from vissl.utils.misc import set_torch_seed +from vissl.utils.multi_dimensional_tensor import MultiDimensionalTensor def init_weights(module): diff --git a/vissl/models/trunks/resnext.py b/vissl/models/trunks/resnext.py index d9eb1487c..76a0246fb 100644 --- a/vissl/models/trunks/resnext.py +++ b/vissl/models/trunks/resnext.py @@ -12,7 +12,6 @@ import torchvision.models as models from torchvision.models.resnet import Bottleneck from vissl.config import AttrDict -from vissl.data.collators.collator_helper import MultiDimensionalTensor from vissl.models.model_helpers import ( Flatten, _get_norm, @@ -21,6 +20,7 @@ transform_model_input_data_type, ) from vissl.models.trunks import register_model_trunk +from vissl.utils.multi_dimensional_tensor import MultiDimensionalTensor # For more depths, add the block config here diff --git a/vissl/data/collators/collator_helper.py b/vissl/utils/multi_dimensional_tensor.py similarity index 100% rename from vissl/data/collators/collator_helper.py rename to vissl/utils/multi_dimensional_tensor.py