3D input tensors and feature reduction #252

Open · wmloh opened this issue Nov 18, 2024 · 5 comments · May be fixed by #266

wmloh commented Nov 18, 2024

@wiseodd
TL;DR: an issue with tensors of size (B, L, D) passing through a Linear last layer.

Here's the minimal reproducible example:

import torch
import torch.nn as nn
from laplace import Laplace
from torch.utils.data import DataLoader
from tensordict import TensorDict

BATCH_SIZE = 4  # B
SEQ_LENGTH = 6  # L
EMBED_DIM = 8  # D
INPUT_KEY = "input"
OUTPUT_KEY = "output"


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(EMBED_DIM, num_heads=1)
        self.final_layer = nn.Linear(EMBED_DIM, 1)

    def forward(self, x):
        x = x[INPUT_KEY].view(-1, SEQ_LENGTH, EMBED_DIM)  # (B, L, D) 
        out = self.attn(x, x, x, need_weights=False)[0]  # (B, L, D)
        return self.final_layer(out).squeeze(dim=-1)  # (B, L)


ds = TensorDict({INPUT_KEY: torch.randn((100, SEQ_LENGTH * EMBED_DIM)),
                 OUTPUT_KEY: torch.randn((100, SEQ_LENGTH * 1))},
                batch_size=[100])  # simulates a dataset
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda x: x)

model = Model()
la = Laplace(model, "regression", dict_key_x=INPUT_KEY, dict_key_y=OUTPUT_KEY,
             last_layer_name="final_layer", feature_reduction="average")
la.fit(dl)

data = next(iter(dl))  # data[INPUT_KEY].shape = (B, L * D)
pred_map = model(data)  # (B, L)
pred_lap = la(data)  # TODO: error! (shape '[4, 6, 54]' is invalid for input of size 216)

My goal is to obtain some measure of epistemic uncertainty on each "token" of the output. The difference from the tutorials I reviewed is the 3D input tensor, which I need for attention. Using the feature_reduction parameter seemed to help a little while I was debugging, but I'm not very familiar with this functionality.
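
For context, my rough understanding of feature_reduction="average" is that it collapses the sequence dimension of the last-layer features so the Linear head sees one feature vector per sequence, along the lines of this plain-PyTorch sketch (just my mental model, not the library's actual code):

import torch

B, L, D = 4, 6, 8
feats = torch.randn(B, L, D)  # penultimate features before the Linear last layer, (B, L, D)
reduced = feats.mean(dim=1)   # "average"-style reduction over the sequence dim -> (B, D)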

I find it surprising that la.fit(dl) works but the forward call la(data) doesn't. How do you recommend using this library for this use case? Thanks in advance.

wiseodd (Collaborator) commented Nov 20, 2024

  1. For better flexibility, use Laplace(..., subset_of_weights="all", ...) and switch off the gradients of the parameters you don't need. E.g., if you want last-layer Laplace, keep subset_of_weights="all" but set requires_grad=False for all but the last-layer parameters.
  2. Currently, the GLM predictive doesn't work well with multi-dim batch outputs, but the MC approximation works well: set pred_type="nn", link_approx="mc" when making predictions.
  3. Don't squeeze the output dimensions, especially the class dimension, in your model; otherwise laplace-torch won't know how many outputs/classes you have.

I will try to make (2) above work for GLM.

For now, this script works:

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.utils.data import DataLoader

from laplace import Laplace
from laplace.curvature.asdl import AsdlGGN
from laplace.utils.enums import LinkApprox, PredType

BATCH_SIZE = 4  # B
SEQ_LENGTH = 6  # L
EMBED_DIM = 8  # D
INPUT_KEY = "input"
OUTPUT_KEY = "output"


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(EMBED_DIM, num_heads=1)
        self.final_layer = nn.Linear(EMBED_DIM, 1)

    def forward(self, x):
        x = x[INPUT_KEY].view(-1, SEQ_LENGTH, EMBED_DIM)  # (B, L, D)
        out = self.attn(x, x, x, need_weights=False)[0]  # (B, L, D)
        return self.final_layer(out)  # (B, L, 1)


ds = TensorDict(
    {
        INPUT_KEY: torch.randn((100, SEQ_LENGTH, EMBED_DIM)),
        OUTPUT_KEY: torch.randn((100, SEQ_LENGTH, 1)),
    },
    batch_size=[100],
)  # simulates a dataset
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda x: x)

model = Model()

for mod_name, mod in model.named_modules():
    if mod_name == "final_layer":
        for p in mod.parameters():
            p.requires_grad = True
    else:
        for p in mod.parameters():
            p.requires_grad = False

la = Laplace(
    model,
    "regression",
    hessian_structure="diag",
    subset_of_weights="all",
    backend=AsdlGGN,
    dict_key_x=INPUT_KEY,
    dict_key_y=OUTPUT_KEY,
)
la.fit(dl)

data = next(iter(dl))  # data[INPUT_KEY].shape = (B, L, D)
pred_map = model(data)  # (B, L, 1)
pred_la_mean, pred_la_var = la(
    data, pred_type=PredType.NN, link_approx=LinkApprox.MC, n_samples=10
)

# torch.Size([4, 6, 1]) torch.Size([4, 6, 1])
print(pred_la_mean.shape, pred_la_var.shape)
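
If the end goal is a per-token measure of epistemic uncertainty, then (assuming, as the shapes above suggest, that the second return value is the per-output predictive variance) a per-token standard deviation is just:

# per-token epistemic uncertainty (standard deviation), shape (4, 6, 1)
token_std = pred_la_var.sqrt()
print(token_std.shape)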

wiseodd (Collaborator) commented Nov 20, 2024

Addendum: If you check out the glm-multidim branch, you can use the GLM predictive (better than MC) with a couple of caveats:

  • hessian_structure must be in ["full", "diag"]
  • You must pass enable_backprop=True to Laplace(...); once you are done, just detach the outputs (see the sketch below).

See example: https://github.com/aleximmer/Laplace/blob/glm-multidim/examples/lm_example.py
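
Roughly, usage on that branch would look like the sketch below; the linked lm_example.py is the authoritative reference, and this assumes the same Laplace(...) signature as in the script above:

la = Laplace(
    model,
    "regression",
    hessian_structure="diag",  # or "full"
    subset_of_weights="all",
    enable_backprop=True,  # required for the multi-dim GLM predictive
    backend=AsdlGGN,
    dict_key_x=INPUT_KEY,
    dict_key_y=OUTPUT_KEY,
)
la.fit(dl)

pred_mean, pred_var = la(data, pred_type=PredType.GLM)
pred_mean, pred_var = pred_mean.detach(), pred_var.detach()  # detach once you are done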

wmloh (Author) commented Nov 20, 2024

@wiseodd

Thanks for correcting the code. It works on my end, and I've successfully transferred the correction to my complete use-case (at least with respect to the mentioned issue).

Regarding point (1), oddly enough, if I stubbornly insist on setting subset_of_weights="last_layer", it does not work; it needs to be "all". Regardless, it's nothing major.

wiseodd (Collaborator) commented Nov 20, 2024

As I said before, subset_of_weights="last_layer" is much less flexible. Just set it to "all" and switch off the gradients.
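
For reference, a compact equivalent of the gradient-freezing loop in the script above:

model.requires_grad_(False)  # freeze all parameters
model.final_layer.requires_grad_(True)  # re-enable only the last layer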

wmloh closed this as completed Nov 20, 2024
wiseodd (Collaborator) commented Nov 20, 2024

I'll keep this open until the aforementioned branch is merged. Thanks for opening the issue!

wiseodd reopened this Nov 20, 2024
wiseodd linked a pull request Dec 6, 2024 that will close this issue