Skip to content

Commit

Permalink
Merge pull request #168 from aleximmer/generalized-dict-input
Browse files Browse the repository at this point in the history
Make the dict keys for models with dict-like inputs general
  • Loading branch information
wiseodd authored Jun 10, 2024
2 parents f404594 + 5ccb81d commit fddcf55
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 108 deletions.
46 changes: 42 additions & 4 deletions laplace/baselaplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ class BaseLaplace:
enable_backprop: bool, default=False
whether to enable backprop to the input `x` through the Laplace predictive.
Useful for e.g. Bayesian optimization.
dict_key_x: str, default='input_ids'
The dictionary key under which the input tensor `x` is stored. Only has effect
when the model takes a `MutableMapping` as the input. Useful for Huggingface
LLM models.
dict_key_y: str, default='labels'
The dictionary key under which the target tensor `y` is stored. Only has effect
when the model takes a `MutableMapping` as the input. Useful for Huggingface
LLM models.
backend : subclasses of `laplace.curvature.CurvatureInterface`
backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
backend_kwargs : dict, default=None
Expand All @@ -68,6 +76,8 @@ def __init__(
prior_mean=0.0,
temperature=1.0,
enable_backprop=False,
dict_key_x='input_ids',
dict_key_y='labels',
backend=None,
backend_kwargs=None,
asdl_fisher_kwargs=None,
Expand Down Expand Up @@ -104,6 +114,10 @@ def __init__(
self.temperature = temperature
self.enable_backprop = enable_backprop

# For models with dict-like inputs (e.g. Huggingface LLMs)
self.dict_key_x = dict_key_x
self.dict_key_y = dict_key_y

if backend is None:
backend = CurvlinopsGGN
else:
Expand Down Expand Up @@ -132,7 +146,11 @@ def _device(self):
def backend(self):
if self._backend is None:
self._backend = self._backend_cls(
self.model, self.likelihood, **self._backend_kwargs
self.model,
self.likelihood,
dict_key_x=self.dict_key_x,
dict_key_y=self.dict_key_y,
**self._backend_kwargs,
)
return self._backend

Expand Down Expand Up @@ -412,6 +430,7 @@ def _gridsearch(
link_approx=link_approx,
n_samples=n_samples,
loss_with_var=loss_with_var,
dict_key_y=self.dict_key_y,
)
except RuntimeError:
result = np.inf
Expand Down Expand Up @@ -487,6 +506,8 @@ def __init__(
prior_mean=0.0,
temperature=1.0,
enable_backprop=False,
dict_key_x='inputs_id',
dict_key_y='labels',
backend=None,
backend_kwargs=None,
asdl_fisher_kwargs=None,
Expand All @@ -499,6 +520,8 @@ def __init__(
prior_mean,
temperature,
enable_backprop,
dict_key_x,
dict_key_y,
backend,
backend_kwargs,
asdl_fisher_kwargs,
Expand Down Expand Up @@ -544,12 +567,15 @@ def fit(self, train_loader, override=True, progress_bar=False):
data = next(iter(train_loader))
with torch.no_grad():
if isinstance(data, MutableMapping): # To support Huggingface dataset
if isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF:
if 'backpack' in self._backend_cls.__name__.lower() or (
isinstance(self, DiagLaplace) and self._backend_cls == CurvlinopsEF
):
raise ValueError(
'Currently DiagEF is not supported under CurvlinopsEF backend '
+ 'for custom models with non-tensor inputs '
+ '(https://github.com/pytorch/functorch/issues/159). Consider '
+ 'using AsdlEF backend instead.'
+ 'using AsdlEF backend instead. The same limitation applies '
+ 'to all BackPACK backend'
)

out = self.model(data)
Expand All @@ -571,7 +597,7 @@ def fit(self, train_loader, override=True, progress_bar=False):

for data in pbar:
if isinstance(data, MutableMapping): # To support Huggingface dataset
X, y = data, data['labels'].to(self._device)
X, y = data, data[self.dict_key_y].to(self._device)
else:
X, y = data
X, y = X.to(self._device), y.to(self._device)
Expand Down Expand Up @@ -1112,6 +1138,8 @@ def __init__(
prior_mean=0.0,
temperature=1.0,
enable_backprop=False,
dict_key_x='input_ids',
dict_key_y='labels',
backend=None,
backend_kwargs=None,
):
Expand All @@ -1123,6 +1151,8 @@ def __init__(
prior_mean,
temperature,
enable_backprop,
dict_key_x,
dict_key_y,
backend,
backend_kwargs,
)
Expand Down Expand Up @@ -1229,6 +1259,8 @@ def __init__(
prior_mean=0.0,
temperature=1.0,
enable_backprop=False,
dict_key_x='inputs_id',
dict_key_y='labels',
backend=None,
damping=False,
backend_kwargs=None,
Expand All @@ -1244,6 +1276,8 @@ def __init__(
prior_mean,
temperature,
enable_backprop,
dict_key_x,
dict_key_y,
backend,
backend_kwargs,
asdl_fisher_kwargs,
Expand Down Expand Up @@ -1376,6 +1410,8 @@ def __init__(
prior_mean=0,
temperature=1,
enable_backprop=False,
dict_key_x='inputs_id',
dict_key_y='labels',
backend=AsdfghjklHessian,
backend_kwargs=None,
):
Expand All @@ -1387,6 +1423,8 @@ def __init__(
prior_mean=prior_mean,
temperature=temperature,
enable_backprop=enable_backprop,
dict_key_x=dict_key_x,
dict_key_y=dict_key_y,
backend=backend,
backend_kwargs=backend_kwargs,
)
Expand Down
36 changes: 31 additions & 5 deletions laplace/curvature/asdfghjkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,23 @@ def kron(self, X, y, N, **kwargs):


class AsdfghjklHessian(AsdfghjklInterface):
def __init__(self, model, likelihood, last_layer=False, low_rank=10):
super().__init__(model, likelihood, last_layer)
def __init__(
self,
model,
likelihood,
last_layer=False,
dict_key_x='input_ids',
dict_key_y='labels',
low_rank=10,
):
super().__init__(
model,
likelihood,
last_layer,
None,
dict_key_x='input_ids',
dict_key_y='labels',
)
self.low_rank = low_rank

@property
Expand Down Expand Up @@ -187,11 +202,15 @@ def __init__(
likelihood,
last_layer=False,
subnetwork_indices=None,
dict_key_x='input_ids',
dict_key_y='labels',
stochastic=False,
):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer, subnetwork_indices)
super().__init__(
model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
)
self.stochastic = stochastic

@property
Expand All @@ -202,10 +221,17 @@ def _ggn_type(self):
class AsdfghjklEF(AsdfghjklInterface, EFInterface):
"""Implementation of the `EFInterface` using asdfghjkl."""

def __init__(self, model, likelihood, last_layer=False):
def __init__(
self,
model,
likelihood,
last_layer=False,
dict_key_x='input_ids',
dict_key_y='labels',
):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer)
super().__init__(model, likelihood, last_layer, None, dict_key_x, dict_key_y)

@property
def _ggn_type(self):
Expand Down
69 changes: 48 additions & 21 deletions laplace/curvature/asdl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import MutableMapping
import warnings

import numpy as np
Expand All @@ -19,16 +20,24 @@
from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
from laplace.utils import Kron, _is_batchnorm

from collections import UserDict

EPS = 1e-6


class AsdlInterface(CurvatureInterface):
"""Interface for asdfghjkl backend."""

def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
super().__init__(model, likelihood, last_layer, subnetwork_indices)
def __init__(
self,
model,
likelihood,
last_layer=False,
subnetwork_indices=None,
dict_key_x='input_ids',
dict_key_y='labels',
):
super().__init__(
model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
)

@property
def loss_type(self):
Expand All @@ -40,9 +49,9 @@ def jacobians(self, x, enable_backprop=False):
Parameters
----------
x : torch.Tensor or UserDict
x : torch.Tensor or MutableMapping (e.g. dict, UserDict)
input data `(batch, input_shape)` on compatible device with model if torch.Tensor.
If UserDict, then at least contains key ['input_ids'] or ['input_ids_0', 'input_ids_1'].
If MutableMapping, then at least contains `self.dict_key_x`.
The latter is specific for reward modeling.
enable_backprop : bool, default = False
whether to enable backprop through the Js and f w.r.t. x
Expand Down Expand Up @@ -217,28 +226,35 @@ def kron(self, X, y, N, **kwargs):
curv_factor = 1.0 # ASDL uses proper 1/2 * MSELoss
return self.factor * loss, curv_factor * kron

@staticmethod
def _get_batch_size(x):
def _get_batch_size(self, x):
"""
ASDL assumes that all leading dimensions are the batch size by default (batch_size = None).
Here, we want to specify that only the first dimension is the actual batch size.
This is the case for LLMs.
"""
# If x is UserDict, then it has weight-sharing dimension (from Huggingface datasets)
if isinstance(x, UserDict) or isinstance(x, dict):
try:
return x['input_ids'].shape[0]
except KeyError:
# The case of reward modeling; the UserDict contains ['input_ids_0', 'input_ids_1']
return x['input_ids_0'].shape[0]
if isinstance(x, MutableMapping):
return x[self.dict_key_x].shape[0]
else:
return None # Use ASDL default behavior


class AsdlHessian(AsdlInterface):
def __init__(self, model, likelihood, last_layer=False, low_rank=10):
super().__init__(model, likelihood, last_layer)
self.low_rank = low_rank
def __init__(
self,
model,
likelihood,
last_layer=False,
dict_key_x='input_ids',
dict_key_y='labels',
):
super().__init__(
model,
likelihood,
last_layer,
subnetwork_indices=None,
dict_key_x=dict_key_x,
dict_key_y=dict_key_y,
)

@property
def _ggn_type(self):
Expand Down Expand Up @@ -271,9 +287,13 @@ def __init__(
likelihood,
last_layer=False,
subnetwork_indices=None,
dict_key_x='input_ids',
dict_key_y='labels',
stochastic=False,
):
super().__init__(model, likelihood, last_layer, subnetwork_indices)
super().__init__(
model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
)
self.stochastic = stochastic

@property
Expand All @@ -284,8 +304,15 @@ def _ggn_type(self):
class AsdlEF(AsdlInterface, EFInterface):
"""Implementation of the `EFInterface` using asdfghjkl."""

def __init__(self, model, likelihood, last_layer=False):
super().__init__(model, likelihood, last_layer)
def __init__(
self,
model,
likelihood,
last_layer=False,
dict_key_x='input_ids',
dict_key_y='labels',
):
super().__init__(model, likelihood, last_layer, None, dict_key_x, dict_key_y)

@property
def _ggn_type(self):
Expand Down
21 changes: 18 additions & 3 deletions laplace/curvature/backpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,19 @@
class BackPackInterface(CurvatureInterface):
"""Interface for Backpack backend."""

def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
super().__init__(model, likelihood, last_layer, subnetwork_indices)
def __init__(
self,
model,
likelihood,
last_layer=False,
subnetwork_indices=None,
dict_key_x='input_ids',
dict_key_y='labels',
):
super().__init__(
model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
)

extend(self._model)
extend(self.lossfunc)

Expand Down Expand Up @@ -115,9 +126,13 @@ def __init__(
likelihood,
last_layer=False,
subnetwork_indices=None,
dict_key_x='input_ids',
dict_key_y='labels',
stochastic=False,
):
super().__init__(model, likelihood, last_layer, subnetwork_indices)
super().__init__(
model, likelihood, last_layer, subnetwork_indices, dict_key_x, dict_key_y
)
self.stochastic = stochastic

def _get_diag_ggn(self):
Expand Down
Loading

0 comments on commit fddcf55

Please sign in to comment.