From dbfaa5f9f882eb7bad276e1307f4af7cf36ec516 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Tue, 23 Apr 2024 16:37:33 -0400 Subject: [PATCH 1/3] Make the dict keys general for model with dict-based inputs --- laplace/baselaplace.py | 45 ++++++++- laplace/curvature/asdl.py | 159 +++++++++++++++++++++++--------- laplace/curvature/backpack.py | 78 +++++++++++----- laplace/curvature/curvature.py | 46 ++++++++- laplace/curvature/curvlinops.py | 66 +++++++++---- laplace/lllaplace.py | 18 +++- tests/test_baselaplace.py | 109 ++++++++++++++++------ 7 files changed, 398 insertions(+), 123 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index ff9aed28..aa1ef11e 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -55,6 +55,14 @@ class BaseLaplace: enable_backprop: bool, default=False whether to enable backprop to the input `x` through the Laplace predictive. Useful for e.g. Bayesian optimization. + dict_key_x: str, default='input_ids' + The dictionary key under which the input tensor `x` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. + dict_key_y: str, default='labels' + The dictionary key under which the target tensor `y` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. backend : subclasses of `laplace.curvature.CurvatureInterface` backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None. backend_kwargs : dict, default=None @@ -73,6 +81,8 @@ def __init__( prior_mean=0.0, temperature=1.0, enable_backprop=False, + dict_key_x='input_ids', + dict_key_y='labels', backend=None, backend_kwargs=None, asdl_fisher_kwargs=None, @@ -109,6 +119,10 @@ def __init__( self.temperature = temperature self.enable_backprop = enable_backprop + # For models with dict-like inputs (e.g. Huggingface LLMs) + self.dict_key_x = dict_key_x + self.dict_key_y = dict_key_y + if backend is None: backend = CurvlinopsGGN else: @@ -137,7 +151,11 @@ def _device(self): def backend(self): if self._backend is None: self._backend = self._backend_cls( - self.model, self.likelihood, **self._backend_kwargs + self.model, + self.likelihood, + dict_key_x=self.dict_key_x, + dict_key_y=self.dict_key_y, + **self._backend_kwargs, ) return self._backend @@ -481,6 +499,8 @@ def __init__( prior_mean=0.0, temperature=1.0, enable_backprop=False, + dict_key_x='inputs_id', + dict_key_y='labels', backend=None, backend_kwargs=None, asdl_fisher_kwargs=None, @@ -493,6 +513,8 @@ def __init__( prior_mean, temperature, enable_backprop, + dict_key_x, + dict_key_y, backend, backend_kwargs, asdl_fisher_kwargs, @@ -538,12 +560,15 @@ def fit(self, train_loader, override=True, progress_bar=False): data = next(iter(train_loader)) with torch.no_grad(): if isinstance(data, MutableMapping): # To support Huggingface dataset - if isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF: + if 'backpack' in self._backend_cls.__name__.lower() or ( + isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF + ): raise ValueError( 'Currently DiagEF is not supported under CurvlinopsEF backend ' + 'for custom models with non-tensor inputs ' + '(https://github.com/pytorch/functorch/issues/159). Consider ' - + 'using AsdlEF backend instead.' + + 'using AsdlEF backend instead. 
The same limitation applies ' + + 'to all BackPACK backend' ) out = self.model(data) @@ -565,7 +590,7 @@ def fit(self, train_loader, override=True, progress_bar=False): for data in pbar: if isinstance(data, MutableMapping): # To support Huggingface dataset - X, y = data, data['labels'].to(self._device) + X, y = data, data[self.dict_key_y].to(self._device) else: X, y = data X, y = X.to(self._device), y.to(self._device) @@ -1106,6 +1131,8 @@ def __init__( prior_mean=0.0, temperature=1.0, enable_backprop=False, + dict_key_x='input_ids', + dict_key_y='labels', backend=None, backend_kwargs=None, ): @@ -1117,6 +1144,8 @@ def __init__( prior_mean, temperature, enable_backprop, + dict_key_x, + dict_key_y, backend, backend_kwargs, ) @@ -1223,6 +1252,8 @@ def __init__( prior_mean=0.0, temperature=1.0, enable_backprop=False, + dict_key_x='inputs_id', + dict_key_y='labels', backend=None, damping=False, backend_kwargs=None, @@ -1238,6 +1269,8 @@ def __init__( prior_mean, temperature, enable_backprop, + dict_key_x, + dict_key_y, backend, backend_kwargs, asdl_fisher_kwargs, @@ -1370,6 +1403,8 @@ def __init__( prior_mean=0, temperature=1, enable_backprop=False, + dict_key_x='inputs_id', + dict_key_y='labels', backend=AsdlHessian, backend_kwargs=None, ): @@ -1381,6 +1416,8 @@ def __init__( prior_mean=prior_mean, temperature=temperature, enable_backprop=enable_backprop, + dict_key_x=dict_key_x, + dict_key_y=dict_key_y, backend=backend, backend_kwargs=backend_kwargs, ) diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py index cbd1ccc7..a2a297ae 100644 --- a/laplace/curvature/asdl.py +++ b/laplace/curvature/asdl.py @@ -1,10 +1,16 @@ +from collections.abc import MutableMapping import warnings import numpy as np import torch from asdl.matrices import ( - FISHER_EXACT, FISHER_MC, FISHER_EMP, SHAPE_KRON, SHAPE_DIAG, SHAPE_FULL + FISHER_EXACT, + FISHER_MC, + FISHER_EMP, + SHAPE_KRON, + SHAPE_DIAG, + SHAPE_FULL, ) from asdl.grad_maker import LOSS_MSE, LOSS_CROSS_ENTROPY from asdl.fisher import FisherConfig, get_fisher_maker @@ -14,16 +20,24 @@ from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface from laplace.utils import Kron, _is_batchnorm -from collections import UserDict - EPS = 1e-6 class AsdlInterface(CurvatureInterface): - """Interface for asdfghjkl backend. - """ - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): - super().__init__(model, likelihood, last_layer, subnetwork_indices) + """Interface for asdfghjkl backend.""" + + def __init__( + self, + model, + likelihood, + last_layer=False, + subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', + ): + super().__init__( + model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y + ) @property def loss_type(self): @@ -35,9 +49,9 @@ def jacobians(self, x, enable_backprop=False): Parameters ---------- - x : torch.Tensor or UserDict + x : torch.Tensor or MutableMapping (e.g. dict, UserDict) input data `(batch, input_shape)` on compatible device with model if torch.Tensor. - If UserDict, then at least contains key ['input_ids'] or ['input_ids_0', 'input_ids_1']. + If MutableMapping, then at least contains `self.dict_key_x`. The latter is specific for reward modeling. enable_backprop : bool, default = False whether to enable backprop through the Js and f w.r.t. 
x @@ -51,14 +65,22 @@ def jacobians(self, x, enable_backprop=False): """ Js = list() for i in range(self.model.output_size): + def closure(): self.model.zero_grad() f = self.model(x) loss = f[:, i].sum() - loss.backward(create_graph=enable_backprop, retain_graph=enable_backprop) + loss.backward( + create_graph=enable_backprop, retain_graph=enable_backprop + ) return f - Ji, f = batch_gradient(self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x)) + Ji, f = batch_gradient( + self.model, + closure, + return_outputs=True, + batch_size=self._get_batch_size(x), + ) if self.subnetwork_indices is not None: Ji = Ji[:, self.subnetwork_indices] Js.append(Ji) @@ -81,13 +103,16 @@ def gradients(self, x, y): Gs : torch.Tensor gradients `(batch, parameters)` """ + def closure(): self.model.zero_grad() loss = self.lossfunc(self.model(x), y) loss.backward() return loss - Gs, loss = batch_gradient(self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x)) + Gs, loss = batch_gradient( + self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x) + ) if self.subnetwork_indices is not None: Gs = Gs[:, self.subnetwork_indices] return Gs, loss @@ -125,20 +150,27 @@ def _get_kron_factors(self, M): def _rescale_kron_factors(kron, N): for F in kron.kfacs: if len(F) == 2: - F[1] *= 1/N + F[1] *= 1 / N return kron def diag(self, X, y, N=None, **kwargs): del N if self.last_layer: _, X = self.model.forward_with_features(X) - cfg = FisherConfig(fisher_type=self._ggn_type, loss_type=self.loss_type, - fisher_shapes=[SHAPE_DIAG], data_size=1, **kwargs) + cfg = FisherConfig( + fisher_type=self._ggn_type, + loss_type=self.loss_type, + fisher_shapes=[SHAPE_DIAG], + data_size=1, + **kwargs, + ) fisher_maker = get_fisher_maker(self.model, cfg) y = y if self.loss_type == LOSS_MSE else y.view(-1) if 'emp' in self._ggn_type: dummy = fisher_maker.setup_model_call(self._model, X) - dummy = dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1)) + dummy = ( + dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1)) + ) fisher_maker.setup_loss_call(self.lossfunc, dummy, y) else: fisher_maker.setup_model_call(self._model, X) @@ -158,19 +190,26 @@ def diag(self, X, y, N=None, **kwargs): if type(self) is AsdlEF and self.likelihood == 'regression': curv_factor = 0.5 # correct scaling for diag ef else: - curv_factor = 1.0 # ASDL uses proper 1/2 * MSELoss + curv_factor = 1.0 # ASDL uses proper 1/2 * MSELoss return self.factor * loss, curv_factor * diag_ggn def kron(self, X, y, N, **kwargs): if self.last_layer: _, X = self.model.forward_with_features(X) - cfg = FisherConfig(fisher_type=self._ggn_type, loss_type=self.loss_type, - fisher_shapes=[SHAPE_KRON], data_size=1, **kwargs) + cfg = FisherConfig( + fisher_type=self._ggn_type, + loss_type=self.loss_type, + fisher_shapes=[SHAPE_KRON], + data_size=1, + **kwargs, + ) fisher_maker = get_fisher_maker(self.model, cfg) y = y if self.loss_type == LOSS_MSE else y.view(-1) if 'emp' in self._ggn_type: dummy = fisher_maker.setup_model_call(self._model, X) - dummy = dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1)) + dummy = ( + dummy if self.loss_type == LOSS_MSE else dummy.view(-1, dummy.size(-1)) + ) fisher_maker.setup_loss_call(self.lossfunc, dummy, y) else: fisher_maker.setup_model_call(self._model, X) @@ -184,31 +223,39 @@ def kron(self, X, y, N, **kwargs): if type(self) is AsdlEF and self.likelihood == 'regression': curv_factor = 0.5 # correct scaling for diag ef else: - curv_factor 
= 1.0 # ASDL uses proper 1/2 * MSELoss + curv_factor = 1.0 # ASDL uses proper 1/2 * MSELoss return self.factor * loss, curv_factor * kron - @staticmethod - def _get_batch_size(x): + def _get_batch_size(self, x): """ ASDL assumes that all leading dimensions are the batch size by default (batch_size = None). Here, we want to specify that only the first dimension is the actual batch size. This is the case for LLMs. """ - # If x is UserDict, then it has weight-sharing dimension (from Huggingface datasets) - if isinstance(x, UserDict) or isinstance(x, dict): - try: - return x['input_ids'].shape[0] - except KeyError: - # The case of reward modeling; the UserDict contains ['input_ids_0', 'input_ids_1'] - return x['input_ids_0'].shape[0] + if isinstance(x, MutableMapping): + return x[self.dict_key_x].shape[0] else: return None # Use ASDL default behavior class AsdlHessian(AsdlInterface): - - def __init__(self, model, likelihood, last_layer=False, low_rank=10): - super().__init__(model, likelihood, last_layer) + def __init__( + self, + model, + likelihood, + last_layer=False, + dict_key_x='input_ids', + dict_key_y='labels', + low_rank=10, + ): + super().__init__( + model, + likelihood, + last_layer, + subnetwork_indices=None, + dict_key_x=dict_key_x, + dict_key_y=dict_key_y, + ) self.low_rank = low_rank @property @@ -243,22 +290,45 @@ def eig_lowrank(self, data_loader): hess_maker.setup_loss_call(self.lossfunc, dummy, y) # iteratively go through data loader and average eigendecomposition # previously: - eigvals, eigvecs = hessian_eig(self.model, self.lossfunc, data_loader=data_loader, - top_n=self.low_rank, max_iters=self.low_rank*10) + eigvals, eigvecs = hessian_eig( + self.model, + self.lossfunc, + data_loader=data_loader, + top_n=self.low_rank, + max_iters=self.low_rank * 10, + ) eigvals = torch.from_numpy(np.array(eigvals)) - mask = (eigvals > EPS) - eigvecs = torch.stack([vec.get_flatten_vector() for vec in eigvecs], dim=1)[:, mask] + mask = eigvals > EPS + eigvecs = torch.stack([vec.get_flatten_vector() for vec in eigvecs], dim=1)[ + :, mask + ] device = eigvecs.device eigvals = eigvals[mask].to(eigvecs.dtype).to(device) - loss = sum([self.lossfunc(self.model(x.to(device)).detach(), y.to(device)) for x, y in data_loader]) + loss = sum( + [ + self.lossfunc(self.model(x.to(device)).detach(), y.to(device)) + for x, y in data_loader + ] + ) return eigvecs, self.factor * eigvals, self.factor * loss class AsdlGGN(AsdlInterface, GGNInterface): - """Implementation of the `GGNInterface` using asdfghjkl. - """ - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): - super().__init__(model, likelihood, last_layer, subnetwork_indices) + """Implementation of the `GGNInterface` using asdfghjkl.""" + + def __init__( + self, + model, + likelihood, + last_layer=False, + subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', + stochastic=False, + ): + super().__init__( + model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y + ) self.stochastic = stochastic @property @@ -267,10 +337,7 @@ def _ggn_type(self): class AsdlEF(AsdlInterface, EFInterface): - """Implementation of the `EFInterface` using asdfghjkl. 
- """ - def __init__(self, model, likelihood, last_layer=False): - super().__init__(model, likelihood, last_layer) + """Implementation of the `EFInterface` using asdfghjkl.""" @property def _ggn_type(self): diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py index 71a7b327..fafb42d8 100644 --- a/laplace/curvature/backpack.py +++ b/laplace/curvature/backpack.py @@ -2,7 +2,14 @@ import torch from backpack import backpack, extend, memory_cleanup -from backpack.extensions import DiagGGNExact, DiagGGNMC, KFAC, KFLR, SumGradSquared, BatchGrad +from backpack.extensions import ( + DiagGGNExact, + DiagGGNMC, + KFAC, + KFLR, + SumGradSquared, + BatchGrad, +) from backpack.context import CTX from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface @@ -10,10 +17,20 @@ class BackPackInterface(CurvatureInterface): - """Interface for Backpack backend. - """ - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): - super().__init__(model, likelihood, last_layer, subnetwork_indices) + """Interface for Backpack backend.""" + + def __init__( + self, + model, + likelihood, + last_layer=False, + subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', + ): + super().__init__( + model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y + ) extend(self._model) extend(self.lossfunc) @@ -44,13 +61,11 @@ def jacobians(self, x, enable_backprop=False): with backpack(BatchGrad()): if model.output_size > 1: out[:, i].sum().backward( - create_graph=enable_backprop, - retain_graph=enable_backprop + create_graph=enable_backprop, retain_graph=enable_backprop ) else: out.sum().backward( - create_graph=enable_backprop, - retain_graph=enable_backprop + create_graph=enable_backprop, retain_graph=enable_backprop ) to_cat = [] for param in model.parameters(): @@ -92,25 +107,42 @@ def gradients(self, x, y): loss = self.lossfunc(f, y) with backpack(BatchGrad()): loss.backward() - Gs = torch.cat([p.grad_batch.data.flatten(start_dim=1) - for p in self._model.parameters()], dim=1) + Gs = torch.cat( + [p.grad_batch.data.flatten(start_dim=1) for p in self._model.parameters()], + dim=1, + ) if self.subnetwork_indices is not None: Gs = Gs[:, self.subnetwork_indices] return Gs, loss class BackPackGGN(BackPackInterface, GGNInterface): - """Implementation of the `GGNInterface` using Backpack. 
- """ - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): - super().__init__(model, likelihood, last_layer, subnetwork_indices) + """Implementation of the `GGNInterface` using Backpack.""" + + def __init__( + self, + model, + likelihood, + last_layer=False, + subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', + stochastic=False, + ): + super().__init__( + model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y + ) self.stochastic = stochastic def _get_diag_ggn(self): if self.stochastic: - return torch.cat([p.diag_ggn_mc.data.flatten() for p in self._model.parameters()]) + return torch.cat( + [p.diag_ggn_mc.data.flatten() for p in self._model.parameters()] + ) else: - return torch.cat([p.diag_ggn_exact.data.flatten() for p in self._model.parameters()]) + return torch.cat( + [p.diag_ggn_exact.data.flatten() for p in self._model.parameters()] + ) def _get_kron_factors(self): if self.stochastic: @@ -124,7 +156,7 @@ def _rescale_kron_factors(kron, M, N): # for M=N (full-batch) just M/N=1 for F in kron.kfacs: if len(F) == 2: - F[1] *= M/N + F[1] *= M / N return kron def diag(self, X, y, **kwargs): @@ -158,8 +190,7 @@ def kron(self, X, y, N, **kwargs) -> Tuple[torch.Tensor, Kron]: class BackPackEF(BackPackInterface, EFInterface): - """Implementation of `EFInterface` using Backpack. - """ + """Implementation of `EFInterface` using Backpack.""" def diag(self, X, y, **kwargs): f = self.model(X) @@ -169,8 +200,9 @@ def diag(self, X, y, **kwargs): loss = self.lossfunc(f, y) with backpack(SumGradSquared()): loss.backward() - diag_EF = torch.cat([p.sum_grad_squared.data.flatten() - for p in self._model.parameters()]) + diag_EF = torch.cat( + [p.sum_grad_squared.data.flatten() for p in self._model.parameters()] + ) if self.subnetwork_indices is not None: diag_EF = diag_EF[self.subnetwork_indices] @@ -184,5 +216,5 @@ def _cleanup(module): for child in module.children(): _cleanup(child) - setattr(module, "_backpack_extend", False) + setattr(module, '_backpack_extend', False) memory_cleanup(module) diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py index 0dde26d3..7e8bde99 100644 --- a/laplace/curvature/curvature.py +++ b/laplace/curvature/curvature.py @@ -20,6 +20,14 @@ class CurvatureInterface: subnetwork_indices : torch.Tensor, default=None indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over + dict_key_x: str, default='input_ids' + The dictionary key under which the input tensor `x` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. + dict_key_y: str, default='labels' + The dictionary key under which the target tensor `y` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. Attributes ---------- @@ -29,18 +37,30 @@ class CurvatureInterface: For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss. 
""" - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): + def __init__( + self, + model, + likelihood, + last_layer=False, + subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', + ): assert likelihood in ['regression', 'classification'] self.likelihood = likelihood self.model = model self.last_layer = last_layer self.subnetwork_indices = subnetwork_indices + self.dict_key_x = dict_key_x + self.dict_key_y = dict_key_y + if likelihood == 'regression': self.lossfunc = MSELoss(reduction='sum') self.factor = 0.5 else: self.lossfunc = CrossEntropyLoss(reduction='sum') self.factor = 1.0 + self.params = [p for p in self._model.parameters() if p.requires_grad] self.params_dict = { k: v for k, v in self._model.named_parameters() if v.requires_grad @@ -281,9 +301,17 @@ class GGNInterface(CurvatureInterface): subnetwork_indices : torch.Tensor, default=None indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over + dict_key_x: str, default='input_ids' + The dictionary key under which the input tensor `x` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. + dict_key_y: str, default='labels' + The dictionary key under which the target tensor `y` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. stochastic : bool, default=False Fisher if stochastic else GGN - num_samples: int, default=100 + num_samples: int, default=1 Number of samples used to approximate the stochastic Fisher """ @@ -293,12 +321,16 @@ def __init__( likelihood, last_layer=False, subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', stochastic=False, num_samples=1, ): self.stochastic = stochastic self.num_samples = num_samples - super().__init__(model, likelihood, last_layer, subnetwork_indices) + super().__init__( + model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y + ) def _get_mc_functional_fisher(self, f): """Approximate the Fisher's middle matrix (expected outer product of the functional gradient) @@ -398,6 +430,14 @@ class EFInterface(CurvatureInterface): subnetwork_indices : torch.Tensor, default=None indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over + dict_key_x: str, default='input_ids' + The dictionary key under which the input tensor `x` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. + dict_key_y: str, default='labels' + The dictionary key under which the target tensor `y` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. Attributes ---------- diff --git a/laplace/curvature/curvlinops.py b/laplace/curvature/curvlinops.py index 4489f760..d7bb9d7b 100644 --- a/laplace/curvature/curvlinops.py +++ b/laplace/curvature/curvlinops.py @@ -2,8 +2,11 @@ import numpy as np from curvlinops import ( - HessianLinearOperator, GGNLinearOperator, FisherMCLinearOperator, EFLinearOperator, - KFACLinearOperator + HessianLinearOperator, + GGNLinearOperator, + FisherMCLinearOperator, + EFLinearOperator, + KFACLinearOperator, ) from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface @@ -13,10 +16,20 @@ class CurvlinopsInterface(CurvatureInterface): - """Interface for Curvlinops backend. 
- """ - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None): - super().__init__(model, likelihood, last_layer, subnetwork_indices) + """Interface for Curvlinops backend. """ + + def __init__( + self, + model, + likelihood, + last_layer=False, + subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', + ): + super().__init__( + model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y + ) @property def _kron_fisher_type(self): @@ -32,7 +45,7 @@ def _rescale_kron_factors(kron, M, N): # for M=N (full-batch) just M/N=1 for F in kron.kfacs: if len(F) == 2: - F[1] *= M/N + F[1] *= M / N return kron def _get_kron_factors(self, linop): @@ -59,9 +72,12 @@ def _get_kron_factors(self, linop): def kron(self, X, y, N, **kwargs): if isinstance(X, (dict, UserDict)): - kwargs['batch_size_fn'] = lambda x: x['input_ids'].shape[0] + kwargs['batch_size_fn'] = lambda x: x[self.dict_key_x].shape[0] linop = KFACLinearOperator( - self.model, self.lossfunc, self.params, [(X, y)], + self.model, + self.lossfunc, + self.params, + [(X, y)], fisher_type=self._kron_fisher_type, loss_average=None, # Since self.lossfunc is sum separate_weight_and_bias=True, @@ -88,13 +104,19 @@ def full(self, X, y, **kwargs): curvlinops_kwargs = {k: v for k, v in kwargs.items() if k != 'N'} if isinstance(X, (dict, UserDict)): - curvlinops_kwargs['batch_size_fn'] = lambda x: x['input_ids'].shape[0] - - linop = self._linop_context(self.model, self.lossfunc, self.params, [(X, y)], - check_deterministic=False, **curvlinops_kwargs) + curvlinops_kwargs['batch_size_fn'] = lambda x: x[self.dict_key_x].shape[0] + + linop = self._linop_context( + self.model, + self.lossfunc, + self.params, + [(X, y)], + check_deterministic=False, + **curvlinops_kwargs, + ) H = torch.as_tensor( linop @ torch.eye(linop.shape[0]), - device=next(self.model.parameters()).device + device=next(self.model.parameters()).device, ) f = self.model(X) @@ -105,8 +127,20 @@ def full(self, X, y, **kwargs): class CurvlinopsGGN(CurvlinopsInterface, GGNInterface): """Implementation of the `GGNInterface` using Curvlinops.""" - def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False): - super().__init__(model, likelihood, last_layer, subnetwork_indices) + + def __init__( + self, + model, + likelihood, + last_layer=False, + subnetwork_indices=None, + dict_key_x='input_ids', + dict_key_y='labels', + stochastic=False, + ): + super().__init__( + model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y + ) self.stochastic = stochastic @property diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py index c08e0bb5..3325f989 100644 --- a/laplace/lllaplace.py +++ b/laplace/lllaplace.py @@ -52,6 +52,14 @@ class LLLaplace(ParametricLaplace): enable_backprop: bool, default=False whether to enable backprop to the input `x` through the Laplace predictive. Useful for e.g. Bayesian optimization. + dict_key_x: str, default='input_ids' + The dictionary key under which the input tensor `x` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. + dict_key_y: str, default='labels' + The dictionary key under which the target tensor `y` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. 
backend : subclasses of `laplace.curvature.CurvatureInterface` backend for access to curvature/Hessian approximations last_layer_name: str, default=None @@ -70,6 +78,8 @@ def __init__( prior_mean=0.0, temperature=1.0, enable_backprop=False, + dict_key_x='inputs_id', + dict_key_y='labels', backend=None, last_layer_name=None, backend_kwargs=None, @@ -86,6 +96,8 @@ def __init__( prior_mean=0.0, temperature=temperature, enable_backprop=enable_backprop, + dict_key_x=dict_key_x, + dict_key_y=dict_key_y, backend=backend, backend_kwargs=backend_kwargs, ) @@ -276,10 +288,12 @@ def __init__( prior_mean=0.0, temperature=1.0, enable_backprop=False, + dict_key_x='inputs_id', + dict_key_y='labels', backend=None, last_layer_name=None, damping=False, - **backend_kwargs + **backend_kwargs, ): self.damping = damping super().__init__( @@ -290,6 +304,8 @@ def __init__( prior_mean, temperature, enable_backprop, + dict_key_x, + dict_key_y, backend, last_layer_name, backend_kwargs, diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py index 24defee8..1eb622ad 100644 --- a/tests/test_baselaplace.py +++ b/tests/test_baselaplace.py @@ -10,6 +10,7 @@ from torch.nn.utils import parameters_to_vector from torch.utils.data import DataLoader, TensorDataset from torch.distributions import Normal, Categorical +from laplace.curvature.backpack import BackPackEF from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN from torchvision.models import wide_resnet50_2 @@ -55,7 +56,7 @@ def __init__(self): def forward(self, data: MutableMapping | torch.Tensor): if isinstance(data, MutableMapping): - x = data['input_ids'].to(next(self.parameters()).device) + x = data['test_input_key'].to(next(self.parameters()).device) else: x = data @@ -111,10 +112,25 @@ def reg_loader(): @pytest.fixture -def custom_loader(): +def custom_loader_clf(): data = [] for _ in range(10): - datum = {'input_ids': torch.randn(5), 'labels': torch.randint(2, (1,))} + datum = { + 'test_input_key': torch.randn(5), + 'test_label_key': torch.randint(2, (1,)), + } + data.append(datum) + return DataLoader(ListDataset(data), batch_size=3, collate_fn=dict_data_collator) + + +@pytest.fixture +def custom_loader_reg(): + data = [] + for _ in range(10): + datum = { + 'test_input_key': torch.randn(5), + 'test_label_key': torch.randn(2), + } data.append(datum) return DataLoader(ListDataset(data), batch_size=3, collate_fn=dict_data_collator) @@ -578,33 +594,74 @@ def test_reward_modeling(laplace, reward_model, reward_loader, reward_test_X): @pytest.mark.parametrize('laplace', [KronLaplace, DiagLaplace]) -@pytest.mark.parametrize('backend', [AsdlEF, AsdlGGN, CurvlinopsEF, CurvlinopsGGN]) -def test_dict_data(laplace, backend, custom_model, custom_loader): - if laplace == DiagLaplace and backend == CurvlinopsEF: - pytest.skip( - 'DiagEF is unsupported with Curvlinops when the input is non-tensor.' 
- ) +@pytest.mark.parametrize( + 'backend', [AsdlEF, AsdlGGN, CurvlinopsEF, CurvlinopsGGN, BackPackGGN, BackPackEF] +) +@pytest.mark.parametrize( + 'lik,custom_loader', + [ + ('classification', 'custom_loader_clf'), + ('regression', 'custom_loader_reg'), + ('reward_modeling', 'custom_loader_clf'), + ], +) +def test_dict_data(laplace, backend, lik, custom_loader, custom_model, request): + custom_loader = request.getfixturevalue(custom_loader) + + if ( + 'backpack' not in backend.__name__.lower() + and laplace != DiagLaplace + and laplace != CurvlinopsEF + ): + with pytest.raises(KeyError): + # Raises an error since custom_loader's input is under the key 'test_input_key' + # but the default is 'input_ids' + lap = laplace(custom_model, lik, backend=backend) + lap.fit(custom_loader) + + lap = laplace( + custom_model, + lik, + backend=backend, + dict_key_x='test_input_key', + dict_key_y='test_label_key', + ) + + if ('backpack' in backend.__name__.lower()) or ( + laplace == DiagLaplace and backend == CurvlinopsEF + ): + # Unsupported, thus raises an exception + with pytest.raises(ValueError): + lap.fit(custom_loader) - for data in custom_loader: - print(data) + return - lap = laplace(custom_model, 'classification', backend=backend) lap.fit(custom_loader) test_data = next(iter(custom_loader)) f = custom_model(test_data) - # GLM predictive - f_pred = lap(test_data, pred_type='glm') - assert f_pred.shape == f.shape - assert torch.allclose( - f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) - ) # sum up to 1 + if lik == 'classification': + f_pred = lap(test_data, pred_type='glm') + assert f_pred.shape == f.shape + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) # sum up to 1 + + f_pred = lap(test_data, pred_type='nn', link_approx='mc') + assert f_pred.shape == f.shape + assert torch.allclose( + f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double) + ) + else: + f_pred, f_var = lap(test_data, pred_type='glm') + assert f_pred.shape == f.shape + assert torch.allclose(f_pred, f) + assert f_var.shape == (f_pred.shape[0], f_pred.shape[1], f_pred.shape[1]) - # NN predictive - f_pred = lap(test_data, pred_type='nn', link_approx='mc') - assert f_pred.shape == f.shape - assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) + f_pred, f_var = lap(test_data, pred_type='nn', link_approx='mc') + assert f_pred.shape == f.shape + assert f_var.shape == (f_pred.shape[0], f_pred.shape[1]) # TODO: Add LowRankLaplace @@ -689,11 +746,3 @@ def test_backprop_nn(laplace, model, reg_loader, backend): assert grad_X_var.shape == X.shape except ValueError: assert False - - -@pytest.mark.parametrize('likelihood', ['classification', 'regression']) -def test_dict_data_diagEF_curvlinops_fails(custom_model, custom_loader, likelihood): - lap = DiagLaplace(custom_model, likelihood=likelihood, backend=CurvlinopsEF) - - with pytest.raises(ValueError): - lap.fit(custom_loader) From b61d20c11b64c06ab24d5434d3bc985273a59e61 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 25 Apr 2024 18:39:54 -0400 Subject: [PATCH 2/3] Update `utils.py` and `marglik_training.py` --- laplace/baselaplace.py | 1 + laplace/marglik_training.py | 20 ++++++++-- laplace/utils/utils.py | 79 +++++++++++++++++++++++++------------ tests/test_utils.py | 22 +++++++++-- 4 files changed, 90 insertions(+), 32 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index aa1ef11e..7d16575f 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ 
-424,6 +424,7 @@ def _gridsearch( link_approx=link_approx, n_samples=n_samples, loss_with_var=loss_with_var, + dict_key_y=self.dict_key_y, ) except RuntimeError: result = np.inf diff --git a/laplace/marglik_training.py b/laplace/marglik_training.py index b57ac520..c968823c 100644 --- a/laplace/marglik_training.py +++ b/laplace/marglik_training.py @@ -1,3 +1,4 @@ +from collections.abc import MutableMapping from copy import deepcopy import numpy as np import torch @@ -6,7 +7,6 @@ from torch.nn.utils import parameters_to_vector import warnings import logging -from collections import UserDict import tqdm from laplace import Laplace @@ -36,6 +36,8 @@ def marglik_training( fix_sigma_noise=False, progress_bar=False, enable_backprop=False, + dict_key_x='input_ids', + dict_key_y='labels', ): """Marginal-likelihood based training (Algorithm 1 in [1]). Optimize model parameters and hyperparameters jointly. @@ -115,6 +117,14 @@ def marglik_training( whether to show a progress bar (updated per epoch) or not enable_backprop : bool, default=False make the returned Laplace instance backpropable---useful for e.g. Bayesian optimization. + dict_key_x: str, default='input_ids' + The dictionary key under which the input tensor `x` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. + dict_key_y: str, default='labels' + The dictionary key under which the target tensor `y` is stored. Only has effect + when the model takes a `MutableMapping` as the input. Useful for Huggingface + LLM models. Returns ------- @@ -194,8 +204,8 @@ def marglik_training( # standard NN training per batch for data in train_loader: - if isinstance(data, UserDict) or isinstance(data, dict): - X, y = data, data['labels'] + if isinstance(data, MutableMapping): + X, y = data, data[dict_key_y] y = y.to(device, non_blocking=True) else: X, y = data @@ -257,6 +267,8 @@ def marglik_training( temperature=temperature, backend=backend, subset_of_weights='all', + dict_key_x=dict_key_x, + dict_key_y=dict_key_y, ) lap.fit(train_loader) @@ -311,6 +323,8 @@ def marglik_training( backend=backend, subset_of_weights='all', enable_backprop=enable_backprop, + dict_key_x=dict_key_x, + dict_key_y=dict_key_y, ) lap.fit(train_loader) return lap, model, margliks, losses diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py index 315784fb..82c93fef 100644 --- a/laplace/utils/utils.py +++ b/laplace/utils/utils.py @@ -1,3 +1,4 @@ +from collections.abc import MutableMapping import logging from typing import Union import numpy as np @@ -7,12 +8,19 @@ from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.distributions.multivariate_normal import _precision_to_scale_tril from torchmetrics import Metric -from collections import UserDict -import math -__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron', - 'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision'] +__all__ = [ + 'get_nll', + 'validate', + 'parameters_per_layer', + 'invsqrt_precision', + 'kron', + 'diagonal_add_scalar', + 'symeig', + 'block_diag', + 'expand_prior_precision', +] def get_nll(out_dist, targets): @@ -20,7 +28,16 @@ def get_nll(out_dist, targets): @torch.no_grad() -def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100, loss_with_var=False) -> float: +def validate( + laplace, + val_loader, + loss, + pred_type='glm', + link_approx='probit', + n_samples=100, + loss_with_var=False, + dict_key_y='labels', +) -> float: 
laplace.model.eval() assert callable(loss) or isinstance(loss, Metric) is_offline = not isinstance(loss, Metric) @@ -30,17 +47,15 @@ def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n targets = list() for data in val_loader: - # If x is UserDict, then it is a from Huggingface dataset - if isinstance(data, UserDict) or isinstance(data, dict): - X, y = data, data['labels'] + if isinstance(data, MutableMapping): + X, y = data, data[dict_key_y] else: X, y = data X = X.to(laplace._device) y = y.to(laplace._device) out = laplace( - X, pred_type=pred_type, - link_approx=link_approx, - n_samples=n_samples) + X, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples + ) if type(out) == tuple: if is_offline: @@ -97,9 +112,11 @@ def invsqrt_precision(M): def _is_batchnorm(module): - if isinstance(module, BatchNorm1d) or \ - isinstance(module, BatchNorm2d) or \ - isinstance(module, BatchNorm3d): + if ( + isinstance(module, BatchNorm1d) + or isinstance(module, BatchNorm2d) + or isinstance(module, BatchNorm3d) + ): return True return False @@ -134,9 +151,9 @@ def kron(t1, t2): tiled_t2 = t2.repeat(t1_height, t1_width) expanded_t1 = ( t1.unsqueeze(2) - .unsqueeze(3) - .repeat(1, t2_height, t2_width, 1) - .view(out_height, out_width) + .unsqueeze(3) + .repeat(1, t2_height, t2_width, 1) + .view(out_height, out_width) ) return expanded_t1 * tiled_t2 @@ -185,7 +202,7 @@ def symeig(M): M = M + torch.eye(M.shape[0], device=M.device) try: L, W = torch.linalg.eigh(M, UPLO='U') - L -= 1. + L -= 1.0 except RuntimeError: stats = f'diag: {M.diagonal()}, max: {M.abs().max()}, ' stats = stats + f'min: {M.abs().min()}, mean: {M.abs().mean()}' @@ -214,7 +231,7 @@ def block_diag(blocks): p_cur = 0 for block in blocks: p_block = block.shape[0] - M[p_cur:p_cur+p_block, p_cur:p_cur+p_block] = block + M[p_cur : p_cur + p_block, p_cur : p_cur + p_block] = block p_cur += p_block return M @@ -243,11 +260,17 @@ def expand_prior_precision(prior_prec, model): elif len(prior_prec) == P: # full diagonal return prior_prec.to(device) else: - return torch.cat([delta * torch.ones_like(m).flatten() for delta, m - in zip(prior_prec, trainable_params)]) + return torch.cat( + [ + delta * torch.ones_like(m).flatten() + for delta, m in zip(prior_prec, trainable_params) + ] + ) -def fix_prior_prec_structure(prior_prec_init, prior_structure, n_layers, n_params, device): +def fix_prior_prec_structure( + prior_prec_init, prior_structure, n_layers, n_params, device +): if prior_structure == 'scalar': prior_prec_init = torch.full((1,), prior_prec_init, device=device) elif prior_structure == 'layerwise': @@ -275,8 +298,12 @@ def normal_samples(mean, var, n_samples, generator=None): """ assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.' 
_, output_dim = mean.shape - randn_samples = torch.randn((output_dim, n_samples), device=mean.device, - dtype=mean.dtype, generator=generator) + randn_samples = torch.randn( + (output_dim, n_samples), + device=mean.device, + dtype=mean.dtype, + generator=generator, + ) if mean.shape == var.shape: # diagonal covariance @@ -285,7 +312,9 @@ def normal_samples(mean, var, n_samples, generator=None): elif mean.shape == var.shape[:2] and var.shape[-1] == mean.shape[1]: # full covariance scale = torch.linalg.cholesky(var) - scaled_samples = torch.matmul(scale, randn_samples.unsqueeze(0)) # expand batch dim + scaled_samples = torch.matmul( + scale, randn_samples.unsqueeze(0) + ) # expand batch dim return (mean.unsqueeze(-1) + scaled_samples).permute((2, 0, 1)) else: raise ValueError('Invalid input shapes.') diff --git a/tests/test_utils.py b/tests/test_utils.py index d6f9165c..29c142bc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,15 @@ import torch from torch.utils.data import TensorDataset, DataLoader from laplace import Laplace -from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig, normal_samples, validate, get_nll, RunningNLLMetric +from laplace.utils import ( + invsqrt_precision, + diagonal_add_scalar, + symeig, + normal_samples, + validate, + get_nll, + RunningNLLMetric, +) import math @@ -71,7 +79,9 @@ def test_validate(): y = torch.randint(3, size=(50,)) dataloader = DataLoader(TensorDataset(X, y), batch_size=10) - model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3)) + model = torch.nn.Sequential( + torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3) + ) la = Laplace(model, 'classification', 'all') la.fit(dataloader) @@ -83,9 +93,13 @@ def test_validate(): assert res > 0 res = validate( - la, dataloader, RunningNLLMetric(), pred_type='nn', link_approx='mc', n_samples=10 + la, + dataloader, + RunningNLLMetric(), + pred_type='nn', + link_approx='mc', + n_samples=10, ) assert res != math.nan assert isinstance(res, float) assert res > 0 - From 5ccb81d20f59a090bf84b8747b0308c9a62a54a6 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Mon, 10 Jun 2024 16:44:35 -0400 Subject: [PATCH 3/3] More concise `isinstance` check --- laplace/utils/utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py index 82c93fef..f1291097 100644 --- a/laplace/utils/utils.py +++ b/laplace/utils/utils.py @@ -1,15 +1,15 @@ -from collections.abc import MutableMapping import logging +from collections.abc import MutableMapping from typing import Union + import numpy as np import torch import torch.nn.functional as F -from torch.nn.utils import parameters_to_vector -from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.distributions.multivariate_normal import _precision_to_scale_tril +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d +from torch.nn.utils import parameters_to_vector from torchmetrics import Metric - __all__ = [ 'get_nll', 'validate', @@ -112,11 +112,7 @@ def invsqrt_precision(M): def _is_batchnorm(module): - if ( - isinstance(module, BatchNorm1d) - or isinstance(module, BatchNorm2d) - or isinstance(module, BatchNorm3d) - ): + if isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)): return True return False
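
Usage sketch for the new `dict_key_x` / `dict_key_y` arguments: the snippet below is a minimal fit-and-predict round trip for a model that consumes dict-shaped batches, the Huggingface-style case this patch series targets. The toy model, loader, collate function, and the key names 'token_ids' / 'targets' are illustrative assumptions; only the `dict_key_x` / `dict_key_y` keyword arguments themselves (defaulting to 'input_ids' / 'labels') come from the changes above.

from collections.abc import MutableMapping

import torch
from torch.utils.data import DataLoader

from laplace import Laplace


class DictInputModel(torch.nn.Module):
    """Toy stand-in for a model that, like a Huggingface LM, takes a mapping as input."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(5, 2)

    def forward(self, data):
        # Pull the input tensor out of the mapping under the key the loader uses.
        x = data['token_ids'] if isinstance(data, MutableMapping) else data
        return self.net(x)


def collate(batch):
    # Keep each batch a dict, mirroring Huggingface-style data collators.
    return {
        'token_ids': torch.stack([b['token_ids'] for b in batch]),
        'targets': torch.cat([b['targets'] for b in batch]),
    }


data = [
    {'token_ids': torch.randn(5), 'targets': torch.randint(2, (1,))} for _ in range(32)
]
loader = DataLoader(data, batch_size=8, collate_fn=collate)

la = Laplace(
    DictInputModel(),
    'classification',
    subset_of_weights='all',
    hessian_structure='kron',
    dict_key_x='token_ids',  # where fit() and the curvature backends find the inputs
    dict_key_y='targets',  # where fit() finds the targets
)
la.fit(loader)
probs = la(next(iter(loader)), pred_type='glm')  # (batch, n_classes) probabilities

The same keyword names are threaded through `marglik_training` and `validate` in the later commits, so mapping-based datasets never need to rename their fields to 'input_ids' / 'labels'.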