diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 62b8d6d..269e0fe 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -50,6 +50,14 @@ class BaseLaplace:
     enable_backprop: bool, default=False
         whether to enable backprop to the input `x` through the Laplace predictive.
         Useful for e.g. Bayesian optimization.
+    dict_key_x: str, default='input_ids'
+        The dictionary key under which the input tensor `x` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
+    dict_key_y: str, default='labels'
+        The dictionary key under which the target tensor `y` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
     backend : subclasses of `laplace.curvature.CurvatureInterface`
         backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
     backend_kwargs : dict, default=None
@@ -68,6 +76,8 @@ def __init__(
         prior_mean=0.0,
         temperature=1.0,
         enable_backprop=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         backend=None,
         backend_kwargs=None,
         asdl_fisher_kwargs=None,
@@ -104,6 +114,10 @@ def __init__(
         self.temperature = temperature
         self.enable_backprop = enable_backprop

+        # For models with dict-like inputs (e.g. Huggingface LLMs)
+        self.dict_key_x = dict_key_x
+        self.dict_key_y = dict_key_y
+
         if backend is None:
             backend = CurvlinopsGGN
         else:
@@ -132,7 +146,11 @@ def _device(self):
     def backend(self):
         if self._backend is None:
             self._backend = self._backend_cls(
-                self.model, self.likelihood, **self._backend_kwargs
+                self.model,
+                self.likelihood,
+                dict_key_x=self.dict_key_x,
+                dict_key_y=self.dict_key_y,
+                **self._backend_kwargs,
             )
         return self._backend
@@ -412,6 +430,7 @@ def _gridsearch(
                     link_approx=link_approx,
                     n_samples=n_samples,
                     loss_with_var=loss_with_var,
+                    dict_key_y=self.dict_key_y,
                 )
             except RuntimeError:
                 result = np.inf
@@ -487,6 +506,8 @@ def __init__(
         prior_mean=0.0,
         temperature=1.0,
         enable_backprop=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         backend=None,
         backend_kwargs=None,
         asdl_fisher_kwargs=None,
@@ -499,6 +520,8 @@ def __init__(
             prior_mean,
             temperature,
             enable_backprop,
+            dict_key_x,
+            dict_key_y,
             backend,
             backend_kwargs,
             asdl_fisher_kwargs,
@@ -544,12 +567,15 @@ def fit(self, train_loader, override=True, progress_bar=False):
         data = next(iter(train_loader))
         with torch.no_grad():
             if isinstance(data, MutableMapping):  # To support Huggingface dataset
-                if isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF:
+                if 'backpack' in self._backend_cls.__name__.lower() or (
+                    isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF
+                ):
                     raise ValueError(
                         'Currently DiagEF is not supported under CurvlinopsEF backend '
                         + 'for custom models with non-tensor inputs '
                         + '(https://github.com/pytorch/functorch/issues/159). Consider '
-                        + 'using AsdlEF backend instead.'
+                        + 'using AsdlEF backend instead. The same limitation applies '
+                        + 'to all BackPACK backends.'
                     )

                 out = self.model(data)
@@ -571,7 +597,7 @@ def fit(self, train_loader, override=True, progress_bar=False):

         for data in pbar:
             if isinstance(data, MutableMapping):  # To support Huggingface dataset
-                X, y = data, data['labels'].to(self._device)
+                X, y = data, data[self.dict_key_y].to(self._device)
             else:
                 X, y = data
                 X, y = X.to(self._device), y.to(self._device)
@@ -1112,6 +1138,8 @@ def __init__(
         prior_mean=0.0,
         temperature=1.0,
         enable_backprop=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         backend=None,
         backend_kwargs=None,
     ):
@@ -1123,6 +1151,8 @@ def __init__(
             prior_mean,
             temperature,
             enable_backprop,
+            dict_key_x,
+            dict_key_y,
             backend,
             backend_kwargs,
         )
@@ -1229,6 +1259,8 @@ def __init__(
         prior_mean=0.0,
         temperature=1.0,
         enable_backprop=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         backend=None,
         damping=False,
         backend_kwargs=None,
@@ -1244,6 +1276,8 @@ def __init__(
             prior_mean,
             temperature,
             enable_backprop,
+            dict_key_x,
+            dict_key_y,
             backend,
             backend_kwargs,
             asdl_fisher_kwargs,
@@ -1376,6 +1410,8 @@ def __init__(
         prior_mean=0,
         temperature=1,
         enable_backprop=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         backend=AsdfghjklHessian,
         backend_kwargs=None,
     ):
@@ -1387,6 +1423,8 @@ def __init__(
             prior_mean=prior_mean,
             temperature=temperature,
             enable_backprop=enable_backprop,
+            dict_key_x=dict_key_x,
+            dict_key_y=dict_key_y,
             backend=backend,
             backend_kwargs=backend_kwargs,
         )
diff --git a/laplace/curvature/asdfghjkl.py b/laplace/curvature/asdfghjkl.py
index 7b0d9a4..9fff7a8 100644
--- a/laplace/curvature/asdfghjkl.py
+++ b/laplace/curvature/asdfghjkl.py
@@ -139,8 +139,23 @@ def kron(self, X, y, N, **kwargs):

 class AsdfghjklHessian(AsdfghjklInterface):
-    def __init__(self, model, likelihood, last_layer=False, low_rank=10):
-        super().__init__(model, likelihood, last_layer)
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+        low_rank=10,
+    ):
+        super().__init__(
+            model,
+            likelihood,
+            last_layer,
+            None,
+            dict_key_x=dict_key_x,
+            dict_key_y=dict_key_y,
+        )
         self.low_rank = low_rank

     @property
@@ -187,11 +202,15 @@ def __init__(
         likelihood,
         last_layer=False,
         subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         stochastic=False,
     ):
         if likelihood != 'classification':
             raise ValueError('This backend only supports classification currently.')
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )
         self.stochastic = stochastic

     @property
@@ -202,10 +221,17 @@ def _ggn_type(self):
 class AsdfghjklEF(AsdfghjklInterface, EFInterface):
     """Implementation of the `EFInterface` using asdfghjkl."""

-    def __init__(self, model, likelihood, last_layer=False):
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+    ):
         if likelihood != 'classification':
             raise ValueError('This backend only supports classification currently.')
-        super().__init__(model, likelihood, last_layer)
+        super().__init__(model, likelihood, last_layer, None, dict_key_x, dict_key_y)

     @property
     def _ggn_type(self):
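Editor's note: for readers unfamiliar with the new arguments, here is a minimal end-to-end sketch of the intended usage. The `DictModel`, the key names, and the loader are hypothetical illustrations, not part of this diff; the sketch also assumes the `Laplace` factory forwards extra keyword arguments to the chosen subclass.

```python
import torch
from torch.utils.data import DataLoader

from laplace import Laplace


class DictModel(torch.nn.Module):
    """Toy stand-in for a model whose forward takes a Huggingface-style dict."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(5, 2)

    def forward(self, data):
        # Pull the input tensor out of the MutableMapping batch
        return self.net(data['my_input_key'])


model = DictModel()
data = [
    {'my_input_key': torch.randn(5), 'my_label_key': torch.randint(2, ())}
    for _ in range(10)
]
loader = DataLoader(data, batch_size=5)  # default collate stacks dict values

# dict_key_x/dict_key_y tell Laplace where x and y live inside each batch dict
la = Laplace(
    model,
    'classification',
    subset_of_weights='all',
    hessian_structure='kron',
    dict_key_x='my_input_key',
    dict_key_y='my_label_key',
)
la.fit(loader)
```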
diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py
index f0a8839..b62fa4f 100644
--- a/laplace/curvature/asdl.py
+++ b/laplace/curvature/asdl.py
@@ -1,3 +1,4 @@
+from collections.abc import MutableMapping
 import warnings

 import numpy as np
@@ -19,16 +20,24 @@ from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
 from laplace.utils import Kron, _is_batchnorm

-from collections import UserDict
-
 EPS = 1e-6


 class AsdlInterface(CurvatureInterface):
     """Interface for asdfghjkl backend."""

-    def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+    ):
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )

     @property
     def loss_type(self):
@@ -40,9 +49,8 @@ def jacobians(self, x, enable_backprop=False):

         Parameters
         ----------
-        x : torch.Tensor or UserDict
+        x : torch.Tensor or MutableMapping (e.g. dict, UserDict)
             input data `(batch, input_shape)` on compatible device with model if torch.Tensor.
-            If UserDict, then at least contains key ['input_ids'] or ['input_ids_0', 'input_ids_1'].
-            The latter is specific for reward modeling.
+            If MutableMapping, then it must at least contain the key `self.dict_key_x`.
         enable_backprop : bool, default = False
             whether to enable backprop through the Js and f w.r.t. x
@@ -217,28 +226,35 @@ def kron(self, X, y, N, **kwargs):
             curv_factor = 1.0  # ASDL uses proper 1/2 * MSELoss
         return self.factor * loss, curv_factor * kron

-    @staticmethod
-    def _get_batch_size(x):
+    def _get_batch_size(self, x):
         """
         ASDL assumes that all leading dimensions are the batch size by default
         (batch_size = None). Here, we want to specify that only the first dimension
         is the actual batch size. This is the case for LLMs.
         """
-        # If x is UserDict, then it has weight-sharing dimension (from Huggingface datasets)
-        if isinstance(x, UserDict) or isinstance(x, dict):
-            try:
-                return x['input_ids'].shape[0]
-            except KeyError:
-                # The case of reward modeling; the UserDict contains ['input_ids_0', 'input_ids_1']
-                return x['input_ids_0'].shape[0]
+        if isinstance(x, MutableMapping):
+            return x[self.dict_key_x].shape[0]
         else:
             return None  # Use ASDL default behavior


 class AsdlHessian(AsdlInterface):
-    def __init__(self, model, likelihood, last_layer=False, low_rank=10):
-        super().__init__(model, likelihood, last_layer)
-        self.low_rank = low_rank
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+    ):
+        super().__init__(
+            model,
+            likelihood,
+            last_layer,
+            subnetwork_indices=None,
+            dict_key_x=dict_key_x,
+            dict_key_y=dict_key_y,
+        )

     @property
     def _ggn_type(self):
@@ -271,9 +287,13 @@ def __init__(
         likelihood,
         last_layer=False,
         subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         stochastic=False,
     ):
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )
         self.stochastic = stochastic

     @property
@@ -284,8 +304,15 @@ def _ggn_type(self):
 class AsdlEF(AsdlInterface, EFInterface):
     """Implementation of the `EFInterface` using asdfghjkl."""

-    def __init__(self, model, likelihood, last_layer=False):
-        super().__init__(model, likelihood, last_layer)
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+    ):
+        super().__init__(model, likelihood, last_layer, None, dict_key_x, dict_key_y)

     @property
     def _ggn_type(self):
diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py
index ee91941..6b66015 100644
--- a/laplace/curvature/backpack.py
+++ b/laplace/curvature/backpack.py
@@ -19,8 +19,19 @@ class BackPackInterface(CurvatureInterface):
     """Interface for Backpack backend."""

-    def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+    ):
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )
+
         extend(self._model)
         extend(self.lossfunc)

@@ -115,9 +126,13 @@ def __init__(
         likelihood,
         last_layer=False,
         subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         stochastic=False,
     ):
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )
         self.stochastic = stochastic

     def _get_diag_ggn(self):
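Editor's note: the batch-size logic that the backends now share for dict inputs can be exercised in isolation. A standalone sketch (not library code) that mirrors `AsdlInterface._get_batch_size` above:

```python
from collections.abc import MutableMapping

import torch


def get_batch_size(x, dict_key_x='input_ids'):
    """Mirror of the backend logic: only the first dimension of the tensor
    stored under dict_key_x is treated as the batch dimension."""
    if isinstance(x, MutableMapping):
        return x[dict_key_x].shape[0]
    return None  # tensor input: defer to the backend's default heuristic


batch = {'input_ids': torch.randint(100, (8, 32)), 'labels': torch.randint(2, (8,))}
assert get_batch_size(batch) == 8  # batch of 8 sequences of length 32
assert get_batch_size(torch.randn(8, 32)) is None
```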
diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py
index 0617d2f..6d1e040 100644
--- a/laplace/curvature/curvature.py
+++ b/laplace/curvature/curvature.py
@@ -20,6 +20,14 @@ class CurvatureInterface:
     subnetwork_indices : torch.Tensor, default=None
         indices of the vectorized model parameters that define the subnetwork
         to apply the Laplace approximation over
+    dict_key_x: str, default='input_ids'
+        The dictionary key under which the input tensor `x` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
+    dict_key_y: str, default='labels'
+        The dictionary key under which the target tensor `y` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.

     Attributes
     ----------
@@ -29,18 +37,30 @@ class CurvatureInterface:
         For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss.
     """

-    def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+    ):
         assert likelihood in ['regression', 'classification']
         self.likelihood = likelihood
         self.model = model
         self.last_layer = last_layer
         self.subnetwork_indices = subnetwork_indices
+        self.dict_key_x = dict_key_x
+        self.dict_key_y = dict_key_y
+
         if likelihood == 'regression':
             self.lossfunc = MSELoss(reduction='sum')
             self.factor = 0.5
         else:
             self.lossfunc = CrossEntropyLoss(reduction='sum')
             self.factor = 1.0
+
         self.params = [p for p in self._model.parameters() if p.requires_grad]
         self.params_dict = {
             k: v for k, v in self._model.named_parameters() if v.requires_grad
@@ -281,9 +301,17 @@ class GGNInterface(CurvatureInterface):
     subnetwork_indices : torch.Tensor, default=None
         indices of the vectorized model parameters that define the subnetwork
         to apply the Laplace approximation over
+    dict_key_x: str, default='input_ids'
+        The dictionary key under which the input tensor `x` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
+    dict_key_y: str, default='labels'
+        The dictionary key under which the target tensor `y` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
     stochastic : bool, default=False
         Fisher if stochastic else GGN
-    num_samples: int, default=100
+    num_samples: int, default=1
         Number of samples used to approximate the stochastic Fisher
     """

@@ -293,12 +321,16 @@ def __init__(
         likelihood,
         last_layer=False,
         subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         stochastic=False,
         num_samples=1,
     ):
         self.stochastic = stochastic
         self.num_samples = num_samples
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )

     def _get_mc_functional_fisher(self, f):
         """Approximate the Fisher's middle matrix (expected outer product of the functional gradient)
@@ -398,6 +430,14 @@ class EFInterface(CurvatureInterface):
     subnetwork_indices : torch.Tensor, default=None
         indices of the vectorized model parameters that define the subnetwork
         to apply the Laplace approximation over
+    dict_key_x: str, default='input_ids'
+        The dictionary key under which the input tensor `x` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
+    dict_key_y: str, default='labels'
+        The dictionary key under which the target tensor `y` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.

     Attributes
     ----------
diff --git a/laplace/curvature/curvlinops.py b/laplace/curvature/curvlinops.py
index 20b1e49..d7bb9d7 100644
--- a/laplace/curvature/curvlinops.py
+++ b/laplace/curvature/curvlinops.py
@@ -18,8 +18,18 @@ class CurvlinopsInterface(CurvatureInterface):
     """Interface for Curvlinops backend."""

-    def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+    def __init__(
+        self,
+        model,
+        likelihood,
+        last_layer=False,
+        subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
+    ):
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )

     @property
     def _kron_fisher_type(self):
@@ -62,7 +72,7 @@ def _get_kron_factors(self, linop):
     def kron(self, X, y, N, **kwargs):
         if isinstance(X, (dict, UserDict)):
-            kwargs['batch_size_fn'] = lambda x: x['input_ids'].shape[0]
+            kwargs['batch_size_fn'] = lambda x: x[self.dict_key_x].shape[0]

         linop = KFACLinearOperator(
             self.model,
             self.lossfunc,
@@ -94,7 +104,7 @@ def full(self, X, y, **kwargs):
         curvlinops_kwargs = {k: v for k, v in kwargs.items() if k != 'N'}
         if isinstance(X, (dict, UserDict)):
-            curvlinops_kwargs['batch_size_fn'] = lambda x: x['input_ids'].shape[0]
+            curvlinops_kwargs['batch_size_fn'] = lambda x: x[self.dict_key_x].shape[0]

         linop = self._linop_context(
             self.model,
@@ -124,9 +134,13 @@ def __init__(
         likelihood,
         last_layer=False,
         subnetwork_indices=None,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         stochastic=False,
     ):
-        super().__init__(model, likelihood, last_layer, subnetwork_indices)
+        super().__init__(
+            model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
+        )
         self.stochastic = stochastic

     @property
diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py
index c08e0bb..3325f98 100644
--- a/laplace/lllaplace.py
+++ b/laplace/lllaplace.py
@@ -52,6 +52,14 @@ class LLLaplace(ParametricLaplace):
     enable_backprop: bool, default=False
         whether to enable backprop to the input `x` through the Laplace predictive.
         Useful for e.g. Bayesian optimization.
+    dict_key_x: str, default='input_ids'
+        The dictionary key under which the input tensor `x` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
+    dict_key_y: str, default='labels'
+        The dictionary key under which the target tensor `y` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
     backend : subclasses of `laplace.curvature.CurvatureInterface`
         backend for access to curvature/Hessian approximations
     last_layer_name: str, default=None
@@ -70,6 +78,8 @@ def __init__(
         prior_mean=0.0,
         temperature=1.0,
         enable_backprop=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         backend=None,
         last_layer_name=None,
         backend_kwargs=None,
@@ -86,6 +96,8 @@ def __init__(
             prior_mean=0.0,
             temperature=temperature,
             enable_backprop=enable_backprop,
+            dict_key_x=dict_key_x,
+            dict_key_y=dict_key_y,
             backend=backend,
             backend_kwargs=backend_kwargs,
         )
@@ -276,10 +288,12 @@ def __init__(
         prior_mean=0.0,
         temperature=1.0,
         enable_backprop=False,
+        dict_key_x='input_ids',
+        dict_key_y='labels',
         backend=None,
         last_layer_name=None,
         damping=False,
-        **backend_kwargs
+        **backend_kwargs,
     ):
         self.damping = damping
         super().__init__(
@@ -290,6 +304,8 @@ def __init__(
             prior_mean,
             temperature,
             enable_backprop,
+            dict_key_x,
+            dict_key_y,
             backend,
             last_layer_name,
             backend_kwargs,
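Editor's note: the same keys thread through the last-layer classes. A short sketch, reusing the hypothetical model and loader from the first example (and again assuming the factory forwards the keyword arguments):

```python
# Continuation of the first sketch: last-layer Laplace over a dict-input model.
la_ll = Laplace(
    model,
    'classification',
    subset_of_weights='last_layer',
    hessian_structure='kron',
    dict_key_x='my_input_key',
    dict_key_y='my_label_key',
)
la_ll.fit(loader)
probs = la_ll(next(iter(loader)), pred_type='glm')  # predictive on a dict batch
```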
diff --git a/laplace/marglik_training.py b/laplace/marglik_training.py
index b57ac52..c968823 100644
--- a/laplace/marglik_training.py
+++ b/laplace/marglik_training.py
@@ -1,3 +1,4 @@
+from collections.abc import MutableMapping
 from copy import deepcopy
 import numpy as np
 import torch
@@ -6,7 +7,6 @@ from torch.nn.utils import parameters_to_vector
 import warnings
 import logging
-from collections import UserDict
 import tqdm

 from laplace import Laplace
@@ -36,6 +36,8 @@ def marglik_training(
     fix_sigma_noise=False,
     progress_bar=False,
     enable_backprop=False,
+    dict_key_x='input_ids',
+    dict_key_y='labels',
 ):
     """Marginal-likelihood based training (Algorithm 1 in [1]).
     Optimize model parameters and hyperparameters jointly.
@@ -115,6 +117,14 @@ def marglik_training(
         whether to show a progress bar (updated per epoch) or not
     enable_backprop : bool, default=False
         make the returned Laplace instance backpropable---useful for e.g. Bayesian optimization.
+    dict_key_x: str, default='input_ids'
+        The dictionary key under which the input tensor `x` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.
+    dict_key_y: str, default='labels'
+        The dictionary key under which the target tensor `y` is stored. Only has an effect
+        when the model takes a `MutableMapping` as the input. Useful for Huggingface
+        LLM models.

     Returns
     -------
@@ -194,8 +204,8 @@ def marglik_training(

         # standard NN training per batch
         for data in train_loader:
-            if isinstance(data, UserDict) or isinstance(data, dict):
-                X, y = data, data['labels']
+            if isinstance(data, MutableMapping):
+                X, y = data, data[dict_key_y]
                 y = y.to(device, non_blocking=True)
             else:
                 X, y = data
@@ -257,6 +267,8 @@ def marglik_training(
             temperature=temperature,
             backend=backend,
             subset_of_weights='all',
+            dict_key_x=dict_key_x,
+            dict_key_y=dict_key_y,
         )
         lap.fit(train_loader)
@@ -311,6 +323,8 @@ def marglik_training(
         backend=backend,
         subset_of_weights='all',
         enable_backprop=enable_backprop,
+        dict_key_x=dict_key_x,
+        dict_key_y=dict_key_y,
     )
     lap.fit(train_loader)
     return lap, model, margliks, losses
diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py
index 315784f..f129109 100644
--- a/laplace/utils/utils.py
+++ b/laplace/utils/utils.py
@@ -1,18 +1,26 @@
 import logging
+from collections.abc import MutableMapping
 from typing import Union
+
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch.nn.utils import parameters_to_vector
-from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
 from torch.distributions.multivariate_normal import _precision_to_scale_tril
+from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
+from torch.nn.utils import parameters_to_vector
 from torchmetrics import Metric
-from collections import UserDict
-import math
-
-__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
-           'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision']
+
+__all__ = [
+    'get_nll',
+    'validate',
+    'parameters_per_layer',
+    'invsqrt_precision',
+    'kron',
+    'diagonal_add_scalar',
+    'symeig',
+    'block_diag',
+    'expand_prior_precision',
+]


 def get_nll(out_dist, targets):
@@ -20,7 +28,16 @@ def get_nll(out_dist, targets):


 @torch.no_grad()
-def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100, loss_with_var=False) -> float:
+def validate(
+    laplace,
+    val_loader,
+    loss,
+    pred_type='glm',
+    link_approx='probit',
+    n_samples=100,
+    loss_with_var=False,
+    dict_key_y='labels',
+) -> float:
     laplace.model.eval()
     assert callable(loss) or isinstance(loss, Metric)
     is_offline = not isinstance(loss, Metric)
@@ -30,17 +47,15 @@ def validate(
     targets = list()

     for data in val_loader:
-        # If x is UserDict, then it is a from Huggingface dataset
-        if isinstance(data, UserDict) or isinstance(data, dict):
-            X, y = data, data['labels']
+        if isinstance(data, MutableMapping):
+            X, y = data, data[dict_key_y]
         else:
             X, y = data
             X = X.to(laplace._device)
         y = y.to(laplace._device)
         out = laplace(
-            X, pred_type=pred_type,
-            link_approx=link_approx,
-            n_samples=n_samples)
+            X, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples
+        )

         if type(out) == tuple:
             if is_offline:
@@ -97,9 +112,7 @@ def invsqrt_precision(M):


 def _is_batchnorm(module):
-    if isinstance(module, BatchNorm1d) or \
-            isinstance(module, BatchNorm2d) or \
-            isinstance(module, BatchNorm3d):
+    if isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)):
         return True
     return False
@@ -134,9 +147,9 @@ def kron(t1, t2):
     tiled_t2 = t2.repeat(t1_height, t1_width)
     expanded_t1 = (
         t1.unsqueeze(2)
-            .unsqueeze(3)
-            .repeat(1, t2_height, t2_width, 1)
-            .view(out_height, out_width)
+        .unsqueeze(3)
+        .repeat(1, t2_height, t2_width, 1)
+        .view(out_height, out_width)
     )
     return expanded_t1 * tiled_t2
@@ -185,7 +198,7 @@ def symeig(M):
     M = M + torch.eye(M.shape[0], device=M.device)
     try:
         L, W = torch.linalg.eigh(M, UPLO='U')
-        L -= 1.
+        L -= 1.0
     except RuntimeError:
         stats = f'diag: {M.diagonal()}, max: {M.abs().max()}, '
         stats = stats + f'min: {M.abs().min()}, mean: {M.abs().mean()}'
@@ -214,7 +227,7 @@ def block_diag(blocks):
     p_cur = 0
     for block in blocks:
         p_block = block.shape[0]
-        M[p_cur:p_cur+p_block, p_cur:p_cur+p_block] = block
+        M[p_cur : p_cur + p_block, p_cur : p_cur + p_block] = block
         p_cur += p_block
     return M
@@ -243,11 +256,17 @@ def expand_prior_precision(prior_prec, model):
     elif len(prior_prec) == P:  # full diagonal
         return prior_prec.to(device)
     else:
-        return torch.cat([delta * torch.ones_like(m).flatten() for delta, m
-                          in zip(prior_prec, trainable_params)])
+        return torch.cat(
+            [
+                delta * torch.ones_like(m).flatten()
+                for delta, m in zip(prior_prec, trainable_params)
+            ]
+        )


-def fix_prior_prec_structure(prior_prec_init, prior_structure, n_layers, n_params, device):
+def fix_prior_prec_structure(
+    prior_prec_init, prior_structure, n_layers, n_params, device
+):
     if prior_structure == 'scalar':
         prior_prec_init = torch.full((1,), prior_prec_init, device=device)
     elif prior_structure == 'layerwise':
@@ -275,8 +294,12 @@ def normal_samples(mean, var, n_samples, generator=None):
     """
     assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.'
     _, output_dim = mean.shape
-    randn_samples = torch.randn((output_dim, n_samples), device=mean.device,
-                                dtype=mean.dtype, generator=generator)
+    randn_samples = torch.randn(
+        (output_dim, n_samples),
+        device=mean.device,
+        dtype=mean.dtype,
+        generator=generator,
+    )

     if mean.shape == var.shape:
         # diagonal covariance
@@ -285,7 +308,9 @@ def normal_samples(mean, var, n_samples, generator=None):
     elif mean.shape == var.shape[:2] and var.shape[-1] == mean.shape[1]:
         # full covariance
         scale = torch.linalg.cholesky(var)
-        scaled_samples = torch.matmul(scale, randn_samples.unsqueeze(0))  # expand batch dim
+        scaled_samples = torch.matmul(
+            scale, randn_samples.unsqueeze(0)
+        )  # expand batch dim
         return (mean.unsqueeze(-1) + scaled_samples).permute((2, 0, 1))
     else:
         raise ValueError('Invalid input shapes.')
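Editor's note: `marglik_training` and `validate` accept the same keys. A sketch built on the hypothetical model and loader from the first example:

```python
from laplace import marglik_training
from laplace.utils import get_nll, validate

# Hypothetical call; `model` and `loader` come from the first sketch above.
lap, model, margliks, losses = marglik_training(
    model,
    loader,
    likelihood='classification',
    n_epochs=10,
    dict_key_x='my_input_key',
    dict_key_y='my_label_key',
)

# validate() now also needs to know where the targets live in each batch dict
nll = validate(lap, loader, get_nll, pred_type='glm', dict_key_y='my_label_key')
```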
diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py
index 862b3d6..1664d59 100644
--- a/tests/test_baselaplace.py
+++ b/tests/test_baselaplace.py
@@ -10,7 +10,7 @@ from torch.nn.utils import parameters_to_vector
 from torch.utils.data import DataLoader, TensorDataset
 from torch.distributions import Normal, Categorical

-from laplace.curvature.asdfghjkl import AsdfghjklGGN, AsdfghjklEF
+from laplace.curvature.backpack import BackPackGGN, BackPackEF
 from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN
 from torchvision.models import wide_resnet50_2
@@ -51,7 +51,7 @@ def __init__(self):

     def forward(self, data: MutableMapping | torch.Tensor):
         if isinstance(data, MutableMapping):
-            x = data['input_ids'].to(next(self.parameters()).device)
+            x = data['test_input_key'].to(next(self.parameters()).device)
         else:
             x = data
@@ -107,10 +107,25 @@ def reg_loader():


 @pytest.fixture
-def custom_loader():
+def custom_loader_clf():
     data = []
     for _ in range(10):
-        datum = {'input_ids': torch.randn(5), 'labels': torch.randint(2, (1,))}
+        datum = {
+            'test_input_key': torch.randn(5),
+            'test_label_key': torch.randint(2, (1,)),
+        }
+        data.append(datum)
+    return DataLoader(ListDataset(data), batch_size=3, collate_fn=dict_data_collator)
+
+
+@pytest.fixture
+def custom_loader_reg():
+    data = []
+    for _ in range(10):
+        datum = {
+            'test_input_key': torch.randn(5),
+            'test_label_key': torch.randn(2),
+        }
         data.append(datum)
     return DataLoader(ListDataset(data), batch_size=3, collate_fn=dict_data_collator)
@@ -574,33 +589,74 @@ def test_reward_modeling(laplace, reward_model, reward_loader, reward_test_X):


 @pytest.mark.parametrize('laplace', [KronLaplace, DiagLaplace])
-@pytest.mark.parametrize('backend', [AsdlEF, AsdlGGN, CurvlinopsEF, CurvlinopsGGN])
-def test_dict_data(laplace, backend, custom_model, custom_loader):
-    if laplace == DiagLaplace and backend == CurvlinopsEF:
-        pytest.skip(
-            'DiagEF is unsupported with Curvlinops when the input is non-tensor.'
-        )
+@pytest.mark.parametrize(
+    'backend', [AsdlEF, AsdlGGN, CurvlinopsEF, CurvlinopsGGN, BackPackGGN, BackPackEF]
+)
+@pytest.mark.parametrize(
+    'lik,custom_loader',
+    [
+        ('classification', 'custom_loader_clf'),
+        ('regression', 'custom_loader_reg'),
+        ('reward_modeling', 'custom_loader_clf'),
+    ],
+)
+def test_dict_data(laplace, backend, lik, custom_loader, custom_model, request):
+    custom_loader = request.getfixturevalue(custom_loader)
+
+    if 'backpack' not in backend.__name__.lower() and not (
+        laplace == DiagLaplace and backend == CurvlinopsEF
+    ):
+        with pytest.raises(KeyError):
+            # Raises an error since custom_loader's input is under the key
+            # 'test_input_key' but the default is 'input_ids'
+            lap = laplace(custom_model, lik, backend=backend)
+            lap.fit(custom_loader)
+
+    lap = laplace(
+        custom_model,
+        lik,
+        backend=backend,
+        dict_key_x='test_input_key',
+        dict_key_y='test_label_key',
+    )
+
+    if ('backpack' in backend.__name__.lower()) or (
+        laplace == DiagLaplace and backend == CurvlinopsEF
+    ):
+        # Unsupported, thus raises an exception
+        with pytest.raises(ValueError):
+            lap.fit(custom_loader)

-    for data in custom_loader:
-        print(data)
+        return

-    lap = laplace(custom_model, 'classification', backend=backend)
     lap.fit(custom_loader)

     test_data = next(iter(custom_loader))
     f = custom_model(test_data)

-    # GLM predictive
-    f_pred = lap(test_data, pred_type='glm')
-    assert f_pred.shape == f.shape
-    assert torch.allclose(
-        f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)
-    )  # sum up to 1
+    if lik == 'classification':
+        # GLM predictive
+        f_pred = lap(test_data, pred_type='glm')
+        assert f_pred.shape == f.shape
+        assert torch.allclose(
+            f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)
+        )  # sum up to 1
+
+        # NN predictive
+        f_pred = lap(test_data, pred_type='nn', link_approx='mc')
+        assert f_pred.shape == f.shape
+        assert torch.allclose(
+            f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)
+        )
+    else:
+        f_pred, f_var = lap(test_data, pred_type='glm')
+        assert f_pred.shape == f.shape
+        assert torch.allclose(f_pred, f)
+        assert f_var.shape == (f_pred.shape[0], f_pred.shape[1], f_pred.shape[1])

-    # NN predictive
-    f_pred = lap(test_data, pred_type='nn', link_approx='mc')
-    assert f_pred.shape == f.shape
-    assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double))
+        f_pred, f_var = lap(test_data, pred_type='nn', link_approx='mc')
+        assert f_pred.shape == f.shape
+        assert f_var.shape == (f_pred.shape[0], f_pred.shape[1])


 @pytest.mark.parametrize('laplace', [FullLaplace, KronLaplace, DiagLaplace])
@@ -689,11 +745,3 @@ def test_backprop_nn(laplace, model, reg_loader, backend):
             assert grad_X_var.shape == X.shape
     except ValueError:
         assert False
-
-
-@pytest.mark.parametrize('likelihood', ['classification', 'regression'])
-def test_dict_data_diagEF_curvlinops_fails(custom_model, custom_loader, likelihood):
-    lap = DiagLaplace(custom_model, likelihood=likelihood, backend=CurvlinopsEF)
-
-    with pytest.raises(ValueError):
-        lap.fit(custom_loader)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d6f9165..29c142b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,15 @@
 import torch
 from torch.utils.data import TensorDataset, DataLoader
 from laplace import Laplace
-from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig, normal_samples, validate, get_nll, RunningNLLMetric
+from laplace.utils import (
+    invsqrt_precision,
+    diagonal_add_scalar,
+    symeig,
+    normal_samples,
+    validate,
+    get_nll,
+    RunningNLLMetric,
+)
 import math
@@ -71,7 +79,9 @@ def test_validate():
     y = torch.randint(3, size=(50,))
     dataloader = DataLoader(TensorDataset(X, y), batch_size=10)

-    model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3))
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3)
+    )
     la = Laplace(model, 'classification', 'all')
     la.fit(dataloader)
@@ -83,9 +93,13 @@ def test_validate():
     assert res > 0

     res = validate(
-        la, dataloader, RunningNLLMetric(), pred_type='nn', link_approx='mc', n_samples=10
+        la,
+        dataloader,
+        RunningNLLMetric(),
+        pred_type='nn',
+        link_approx='mc',
+        n_samples=10,
     )
     assert res != math.nan
     assert isinstance(res, float)
     assert res > 0
-
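Editor's note: the fixtures above depend on `ListDataset` and `dict_data_collator`, which live in the shared test helpers and are not shown in this diff. A minimal stand-in consistent with how they are used here could look like the following (hypothetical, for illustration only):

```python
from collections import UserDict

import torch
from torch.utils.data import Dataset


class ListDataset(Dataset):
    """Wrap a plain list of dict examples as a torch Dataset."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def dict_data_collator(batch):
    """Stack each key of a list of dict examples into one batched UserDict."""
    out = UserDict()
    for key in batch[0].keys():
        out[key] = torch.stack([example[key] for example in batch])
    return out
```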