Question about the mask sampling #50

Open
FriedRonaldo opened this issue Mar 31, 2024 · 3 comments

Comments

@FriedRonaldo

FriedRonaldo commented Mar 31, 2024

Hi, I read the JEPA paper. It presents an effective way to learn temporal information, better than other works such as VideoMAE and UMT.

I have a question about the mask sampling.

To be clear, I do not mean to review or criticize the paper; I just want to reproduce the work exactly.

Question 01) When I instantiate a mask generator and then sample a mask, it sometimes masks only the first N frames.

For example, the code below reproduces the situation:

mg = BlockMaskGenerator(
    aspect_ratio=(0.75, 1.5),
    npred=8,
    spatial_pred_mask_scale=(0.15, 0.15),
    temporal_pred_mask_scale=(1., 1.),
    max_context_frames_ratio=1.0,
    image_size=(64, 64),
    num_frames=2,
    patch_size=(16, 16),
    temporal_stride=1
)
mask_enc, mask_pred = mg(16)
print(mask_enc)

It outputs:

tensor([[ 6,  7,  8, 11, 12, 13, 22, 23],
        [ 2,  3,  4,  5, 10, 11, 15, 18],
        [ 0,  3,  6,  7,  8, 11, 15, 16],
        [ 4,  5, 12, 13, 20, 21, 28, 29],
        [ 2,  3, 10, 11, 12, 13, 14, 15],
        [ 3,  4,  5, 12, 13, 14, 15, 19],
        [ 3,  8, 12, 13, 14, 15, 19, 24],
        [ 0,  3,  8,  9, 15, 16, 19, 24],
        [ 7,  8, 11, 12, 13, 23, 24, 27],
        [ 4,  7,  8, 12, 13, 20, 23, 24],
        [ 2,  3,  4,  8, 18, 19, 20, 24],
        [ 3,  4,  8,  9, 10, 11, 19, 20],
        [ 0,  7,  8,  9, 10, 11, 12, 15],
        [ 0,  1,  4,  7, 11, 12, 13, 16],
        [ 6,  7, 11, 12, 15, 22, 23, 27],
        [ 4,  5,  8,  9, 10, 11, 14, 15]])

In some cases, such as mask_enc[-1] and mask_enc[-4], the mask covers only the first frame.
(There are 2 frames with 16 patches each, so indices 0 to 15 belong to the first frame and 16 to 31 to the second. Every index in [4, 5, 8, 9, 10, 11, 14, 15] is below 16, so this mask covers only the first frame.)

In these cases, for some samples in the batch, the model appears to see context from only part of the clip (e.g., 4 out of 8 frames) and must predict the targets across the entire clip from patches in those first few frames only (e.g., predict 8 frames using context from only 4).

Is my analysis correct? If so, the behavior may differ from the paper's description, which says that the mask is the same for all frames:

3D Multi-Block Masking. We use a simple 3D extension of the block masking strategy employed for images (Bao et al., 2021). Given a video, we sample several (possibly overlapping) spatially continuous blocks with various aspect ratios and take their union to construct a single mask. This spatial mask is then repeated across the entire temporal dimension. Masking a large continuous block that covers the full temporal dimension limits information leakage due to the spatial and temporal redundancy of videos, and results in a harder prediction task (Tong et al., 2022).

In this case, the masking strategy does not limit information leakage as intended.
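
One plausible mechanism for the first-frame-only rows, offered as a hedged sketch rather than a confirmed diagnosis: if each sample's context indices come out in ascending order (as `torch.nonzero` returns them) and every sample is then cut to the batch-wide minimum by keeping only its first `min_keep` entries, the cut falls in the earliest frames. The pure-Python example below assumes a 2-frame, 4x4-patch layout and two hypothetical samples whose spatial mask is already repeated over time; `tube_context` is an illustrative helper, not part of the repository.

```python
# Hypothetical layout: 2 frames, each a 4x4 grid of patches (16 patches per frame).
T, H, W = 2, 4, 4
P = H * W  # patches per frame; flat index = frame * P + spatial index


def tube_context(spatial_visible):
    """Repeat one frame's visible spatial indices across all T frames, sorted ascending."""
    return sorted(t * P + i for t in range(T) for i in spatial_visible)


# Two samples whose context is symmetric across frames, as the paper describes.
sample_a = tube_context([4, 5, 12, 13])                # 8 context patches
sample_b = tube_context([4, 5, 8, 9, 10, 11, 14, 15])  # 16 context patches

# Assumed truncation: keep only the FIRST min_keep indices of every sample.
min_keep = min(len(sample_a), len(sample_b))
truncated = [s[:min_keep] for s in (sample_a, sample_b)]

for s in truncated:
    frames = sorted({idx // P for idx in s})
    print(s, "-> frames", frames)
# [4, 5, 12, 13, 20, 21, 28, 29] -> frames [0, 1]
# [4, 5, 8, 9, 10, 11, 14, 15] -> frames [0]
```

These two rows coincide with mask_enc[3] and mask_enc[-1] in the output above, which is consistent with (though not proof of) this explanation: the sample with the least context keeps both frames, while a sample with more context loses its second frame entirely.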

Question 02) The number of visible (context) patches plus the number of masked (target) patches does not equal the total number of patches.

When I print the shape of each mask, I get the following output:

print(mask_enc.shape)
print(mask_pred.shape)

torch.Size([16, 8])
torch.Size([16, 16])

There are 32 patches in total (2 frames × 16 patches per frame), but the two lengths sum to only 24 (8 + 16), which is less than the total patch count.
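
A small arithmetic sketch of how the two lengths can fall short of 32, assuming the context and target lists are truncated independently, each to its own batch-wide minimum (the update code further down tracks `min_keep_enc` and `min_keep_pred` separately). The per-sample context sizes are hypothetical.

```python
TOTAL = 32  # 2 frames x 16 patches per frame

# Hypothetical context (visible) sizes for three samples in one batch.
context_sizes = [8, 12, 16]
# Before truncation, each sample's context and target partition all patches.
target_sizes = [TOTAL - c for c in context_sizes]  # [24, 20, 16]

# Each list is truncated to its own minimum, set by a *different* sample.
min_keep_enc = min(context_sizes)   # 8: the sample with the least context
min_keep_pred = min(target_sizes)   # 16: the sample with the most context

print(min_keep_enc, min_keep_pred, min_keep_enc + min_keep_pred)  # 8 16 24
```

With these sizes the batch would have shapes [B, 8] and [B, 16], matching the printed shapes, even though every sample's untruncated context and target add up to 32.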

Discussion

The second issue may not be a real problem: each sample uses a subset of its visible patches to reconstruct a subset of the input video, and partial reconstruction in MAE-style training has been shown to be effective in [1].

[1] CrossMAE: Rethinking Patch Dependence for Masked Autoencoders

Approach (if the analysis is correct and the behavior is not intended)

However, the first issue could affect performance, because the masking strategy is meant to block information leakage across frames, i.e., to prevent the model from copying nearby patches from other frames.

To resolve this, I think the block mask should be sampled for a single frame and then repeated along the time axis, offsetting the indices of each subsequent frame by the number of patches per frame.
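
A minimal pure-Python sketch of this proposal, assuming fixed-size rectangular blocks and a flat index layout of `frame * H * W + row * W + col`. `sample_tube_mask` and its parameters are hypothetical names for illustration, and block sizes and aspect ratios are simplified relative to the repository's generator.

```python
import random


def sample_tube_mask(T, H, W, num_blocks, block_h, block_w, rng):
    """Sample one spatial mask (a union of blocks) and repeat it over T frames.

    Returns (context_indices, target_indices) as flat patch indices.
    """
    masked = set()
    for _ in range(num_blocks):
        # Blocks may overlap; their union forms a single spatial mask.
        top = rng.randrange(H - block_h + 1)
        left = rng.randrange(W - block_w + 1)
        for r in range(top, top + block_h):
            for c in range(left, left + block_w):
                masked.add(r * W + c)
    visible = [i for i in range(H * W) if i not in masked]
    # Offset frame t by t * H * W so the same spatial mask appears in every frame.
    context = [t * H * W + i for t in range(T) for i in visible]
    target = [t * H * W + i for t in range(T) for i in sorted(masked)]
    return context, target


context, target = sample_tube_mask(T=2, H=4, W=4, num_blocks=2,
                                   block_h=2, block_w=2, rng=random.Random(0))
print(context)
print(target)
```

Because the context is identical in every frame, any later truncation to a common batch length would need to act on the per-frame index set before repeating; otherwise the same early-frame bias could reappear.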

I hope this discussion helps clarify both the source code and the paper.

Thanks.

Update

The code below shows one possible way to fix the mask sampling:

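        # p_size: block size sampled earlier in __call__ (not shown here)
        collated_masks_pred, collated_masks_enc = [], []
        min_keep_enc = min_keep_pred = self.duration * self.height * self.width
        for _ in range(batch_size):

            empty_context = True
            while empty_context:

                # Sample the union of npred blocks on a single frame only.
                mask_e = torch.ones((1, self.height, self.width), dtype=torch.int32)
                for _ in range(self.npred):
                    # Keep one frame so the in-place product stays (1, H, W); assumes
                    # _sample_block_mask returns (duration, H, W) with blocks spanning all frames.
                    mask_e *= self._sample_block_mask(p_size)[:1]
                mask_e = mask_e.flatten()

                # Per-frame indices: targets are masked (0), context is visible (1).
                # squeeze(1) keeps a 1-D tensor even when only one index is found.
                mask_p = torch.argwhere(mask_e == 0).squeeze(1)
                mask_e = torch.nonzero(mask_e).squeeze(1)

                empty_context = (len(mask_e) == 0)
                if not empty_context:
                    min_keep_pred = min(min_keep_pred, len(mask_p))
                    min_keep_enc = min(min_keep_enc, len(mask_e))
                    collated_masks_pred.append(mask_p)
                    collated_masks_enc.append(mask_e)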

        if self.max_keep is not None:
            min_keep_enc = min(min_keep_enc, self.max_keep)

        # --
        return self._truncate_mask(collated_masks_enc, min_keep_enc), self._truncate_mask(collated_masks_pred, min_keep_pred)
    
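    def _truncate_mask(self, masks, min_keep):
        result = []
        for cm in masks:
            # Choose min_keep per-frame indices at random (not just the first ones).
            idx = torch.randperm(len(cm))[:min_keep]
            cm = cm[idx]
            # Build the per-frame binary mask as a 1-D tensor, then reshape;
            # this avoids relying on flatten() returning a view.
            tmp = torch.zeros(self.height * self.width, dtype=torch.int32)
            tmp[cm] = 1
            # Repeat the same spatial mask across all frames.
            tmp = tmp.view(1, self.height, self.width).repeat(self.duration, 1, 1)
            # Back to flat indices; frame t is offset by t * height * width.
            tmp = torch.nonzero(tmp.flatten()).squeeze(1)
            result.append(tmp)
        return torch.utils.data.default_collate(result)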

As a sanity check, I ran the code without the final `torch.nonzero(...)` line in `_truncate_mask`, so that it returns binary masks instead of indices.

The outputs look like this (image attached in the original issue).
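
The truncate-then-repeat idea behind `_truncate_mask` can also be checked without a model. This is a plain-Python sketch, not the code above: `truncate_then_repeat` is a hypothetical helper, `spatial` is assumed to hold one frame's indices in `[0, P)`, and `random.Random.sample` stands in for `torch.randperm(...)[:min_keep]`.

```python
import random


def truncate_then_repeat(spatial, min_keep, T, P, rng):
    """Randomly keep min_keep spatial indices, then repeat them in every frame."""
    kept = rng.sample(spatial, min_keep)  # random subset, not the first entries
    # Offset frame t by t * P; sorting mirrors torch.nonzero's ascending output.
    return sorted(t * P + i for t in range(T) for i in kept)


rng = random.Random(0)
out = truncate_then_repeat([4, 5, 8, 9, 10, 11, 14, 15], min_keep=4, T=2, P=16, rng=rng)
print(out)
```

Whatever subset is drawn, each frame receives the same spatial positions, so no sample ends up with context in only the first frame.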
@FriedRonaldo
Author

Hi @MidoAssran, it would be great if one of the authors could reply to this issue to help future readers understand the work better. Thanks!

@rozgo

rozgo commented Apr 4, 2024

If there's substantial information leakage due to this unintended mask sampling behavior, it could compromise the model's temporal learning capabilities by simplifying the learning task. Definitely interested in learning more.

@gozy0

gozy0 commented Nov 20, 2024

Was anyone able to successfully train and evaluate with the vjepa code?

Are there any updates on this issue? I am trying to understand the vjepa masking strategy: (i) the number of non-masked patches plus the number of masked patches does not match the total number of patches; (ii) of the two mask types (short-range and long-range), the predictor's forward pass seems to use only the first.

Any help / pointers are appreciated.

3 participants