
Dimensional Collapse Tracking #77

Open
zankerx opened this issue Aug 3, 2024 · 2 comments

Comments

zankerx commented Aug 3, 2024

Hello, and first of all, thank you for sharing this model.

I have a question regarding the development of this model, in particular about dimensional collapse. Do you use any indicators besides the loss during training to select hyper-parameters and maximize the amount of information contained in the embedding (k-NN probes, eigenvalues of the embedding covariance, or others)?

Thank you in advance for your response :)
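To make the question concrete, here is a minimal sketch (in NumPy, with hypothetical names and thresholds) of the kind of eigenvalue-based collapse check I have in mind:

```python
import numpy as np

def collapse_report(emb, tol=1e-4):
    """Count how many covariance eigenvalues of an embedding matrix
    (n_samples x n_dims) are effectively zero, i.e. collapsed directions."""
    emb = emb - emb.mean(axis=0)          # center the embeddings
    cov = emb.T @ emb / (len(emb) - 1)    # d x d sample covariance
    eigvals = np.linalg.eigvalsh(cov)     # real, ascending (cov is symmetric)
    collapsed = int((eigvals < tol).sum())
    return collapsed, eigvals

# Toy example: only 3 informative directions out of 8 dimensions.
rng = np.random.default_rng(0)
z = rng.normal(size=(1000, 3)) @ rng.normal(size=(3, 8))
n_collapsed, _ = collapse_report(z)
print(n_collapsed)  # 5: five of the eight dimensions carry no variance
```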

zankerx commented Aug 29, 2024

In the code, I noticed that you also have a regularization term that is not used by default, since its coefficient is set to 0:

```python
def reg_fn(z):
    return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z)

# Step 1. Forward
loss_jepa, loss_reg = 0., 0.
with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision):
    h = forward_target(clips)
    z = forward_context(clips, h)
    loss_jepa = loss_fn(z, h)  # jepa prediction loss
    pstd_z = reg_fn(z)  # predictor variance across patches
    loss_reg += torch.mean(F.relu(1. - pstd_z))
loss = loss_jepa + reg_coeff * loss_reg
```

Have you studied the impact of this regularization on the model?
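For reference, here is roughly what that term computes, re-expressed as a standalone NumPy sketch (the original operates on a list of torch tensors; the shapes and names here are purely illustrative):

```python
import numpy as np

def variance_hinge_reg(z, eps=1e-4):
    """Hinge on the per-dimension std across patches: penalizes any
    embedding dimension whose std falls below 1, mirroring
    relu(1 - sqrt(var + eps)) from the snippet above."""
    std = np.sqrt(z.var(axis=1) + eps)        # std over the patch axis
    return np.maximum(0.0, 1.0 - std).mean()  # hinge, averaged over dims

rng = np.random.default_rng(0)
healthy = rng.normal(scale=2.0, size=(4, 256, 128))     # std well above 1
collapsed = rng.normal(scale=0.01, size=(4, 256, 128))  # near-constant patches
print(variance_hinge_reg(healthy) < variance_hinge_reg(collapsed))  # True
```

The hinge is zero as long as every dimension varies enough across patches, so the term only activates when some dimension starts to collapse.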

@vimalthilak

@zankerx there is a line of research on embedding quality that may be useful for what you are thinking about. Metrics such as α-ReQ, RankMe, and LiDAR (self-promotion alert, as I am one of the authors), as well as CLID, have been shown to be useful for estimating embedding quality in image-based SSL (and transfer learning) methods.
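For illustration, a minimal sketch of one such metric, RankMe, which measures the smooth effective rank of an embedding matrix as the exponential of the entropy of its normalized singular values (an illustrative re-implementation, not the authors' code):

```python
import numpy as np

def rankme(emb, eps=1e-7):
    """RankMe-style effective rank: exp of the Shannon entropy of the
    normalized singular values of an (n_samples x n_dims) embedding matrix."""
    s = np.linalg.svd(emb, compute_uv=False)
    p = s / s.sum() + eps                       # normalized singular values
    return float(np.exp(-(p * np.log(p)).sum()))

rng = np.random.default_rng(0)
full = rng.normal(size=(1000, 64))                           # ~isotropic
low = rng.normal(size=(1000, 4)) @ rng.normal(size=(4, 64))  # rank-4 collapse
print(rankme(full) > rankme(low))  # True: collapse lowers effective rank
```

Unlike a hard rank threshold, the entropy-based form degrades smoothly as the singular-value spectrum concentrates, which makes it usable for monitoring during training.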
