diff --git a/timesformer/datasets/multigrid_helper.py b/timesformer/datasets/multigrid_helper.py index de9ba0c..291bcd4 100755 --- a/timesformer/datasets/multigrid_helper.py +++ b/timesformer/datasets/multigrid_helper.py @@ -3,9 +3,16 @@ """Helper functions for multigrid training.""" import numpy as np -from torch._six import int_classes as _int_classes +import torch from torch.utils.data.sampler import Sampler +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +if TORCH_MAJOR == 1 and TORCH_MINOR < 8: + from torch._six import int_classes as _int_classes +else: + _int_classes = int class ShortCycleBatchSampler(Sampler): """ diff --git a/timesformer/models/resnet_helper.py b/timesformer/models/resnet_helper.py index 082d318..773602b 100755 --- a/timesformer/models/resnet_helper.py +++ b/timesformer/models/resnet_helper.py @@ -12,7 +12,6 @@ from einops import rearrange, reduce, repeat import torch.nn.functional as F from torch.nn.modules.module import Module -from torch.nn.modules.linear import _LinearWithBias from torch.nn.modules.activation import MultiheadAttention import numpy as np diff --git a/timesformer/models/vit_utils.py b/timesformer/models/vit_utils.py index 9ce6a93..83251e5 100755 --- a/timesformer/models/vit_utils.py +++ b/timesformer/models/vit_utils.py @@ -11,7 +11,14 @@ from timesformer.models.helpers import load_pretrained from .build import MODEL_REGISTRY from itertools import repeat -from torch._six import container_abcs + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +if TORCH_MAJOR == 1 and TORCH_MINOR < 8: + from torch._six import container_abcs +else: + import collections.abc as container_abcs DEFAULT_CROP_PCT = 0.875 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)