diff --git a/core/corr.py b/core/corr.py index 3f60ce0b..14b98dd4 100644 --- a/core/corr.py +++ b/core/corr.py @@ -2,12 +2,6 @@ import torch.nn.functional as F from utils.utils import bilinear_sampler, coords_grid -try: - import alt_cuda_corr -except: - # alt_cuda_corr is not compiled - pass - class CorrBlock: def __init__(self, fmap1, fmap2, num_levels=4, radius=4): @@ -62,6 +56,9 @@ def corr(fmap1, fmap2): class AlternateCorrBlock: def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + import alt_cuda_corr + self.alt_corr_fwd = alt_cuda_corr.forward + self.num_levels = num_levels self.radius = radius