Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expectation examples #263

Merged
merged 5 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions docs/expectation_example.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
## Motivation

Often, we want to compute non-analytic expectation \(\mathbb{E}_{p(f(x) \mid \mathcal{D})} [g(f(x))]\) of a function \(g(f(x))\) given a Laplace posterior over functions \(p(f(x) \mid \mathcal{D})\) at input \(x\).
This naturally arises in decision-making:
Given a posterior belief about an unknown function \(f\), we would like to compute the expected utility of \(x\).
In Bayesian optimization, this is called \_acquisition function_.
For some utility function \(g\) and posterior belief \(p(f(x) \mid \mathcal{D})\), the resulting acquisition function can be computed analytically, e.g. if \(g\) is linear and the posterior is Gaussian (process).
But, in general, closed-form solutions don't exist.

In this example, we will see how easy it is to compute a _differentiable_, Monte-Carlo approximated acquisition function under the posterior distribution over neural network functions implied by a Laplace approximation.

## Laplace approximations

As always, the first step of a Laplace approximation is MAP estimation.

```python
import torch
import torch.utils.data as data_utils
from torch import autograd, nn, optim

from laplace import Laplace
from laplace.utils.enums import HessianStructure, Likelihood, PredType, SubsetOfWeights

torch.manual_seed(123)

model = nn.Sequential(nn.Linear(2, 10), nn.GELU(), nn.Linear(10, 1))
X, Y = torch.randn(5, 2), torch.randn(5, 1)
train_loader = data_utils.DataLoader(data_utils.TensorDataset(X, Y), batch_size=3)
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-4)
loss_fn = nn.MSELoss()

for epoch in range(10):
model.train()

for x, y in train_loader:
opt.zero_grad()
out = model(x)
loss = loss_fn(out, y)
loss.backward()
opt.step()
```

Then, we are ready to obtain the Laplace approximation.
In this example, we focus on the weight-space approximation, but the same can be done in the function space directly by simply specifying `hessian_structure=HessianStructure.GP`.

```python
la = Laplace(
model,
Likelihood.REGRESSION,
subset_of_weights=SubsetOfWeights.ALL,
hessian_structure=HessianStructure.KRON,
enable_backprop=True,
)
la.fit(train_loader)
la.optimize_prior_precision(PredType.GLM)
```

!!! tip

If you need the gradient of quantities that depend on Laplace's predictive
distribution/samples, then be sure to specify `enable_backprop=True`. This is in
fact necessary for continuous Bayesian optimization.

## Thompson Sampling

The simplest acquisition function that can be obtained from the Laplace posterior is Thompson sampling.
This is defined as \(a \sim p(f(x) \mid \mathcal{D})\).
Thus, given `la`, it can be obtained very simply:

```python
f_sample = la.functional_samples(x_test, n_samples=1)
```

Note that `f_sample` can be obtained through the `"nn"` and `"glm"` predictives.
The `la.functional_samples` function supports both options.

```python
for pred_type in [PredType.GLM, PredType.NN]:
print(f"Thompson sampling, {pred_type}")

x_test = torch.randn(10, 2)
x_test.requires_grad = True

f_sample = la.functional_samples(x_test, pred_type=pred_type, n_samples=1)
f_sample = f_sample.squeeze(0) # We only use a single sample
print(f"TS shape: {f_sample.shape}, TS requires grad: {f_sample.requires_grad}")

grad_x = autograd.grad(f_sample.sum(), x_test)[0]

print(
f"Grad x_test shape: {grad_x.shape}, Grad x_test vanishing: {torch.allclose(grad_x, torch.tensor(0.0))}"
)
print()
```

The snippet above will output:

```
Thompson sampling, glm
TS shape: torch.Size([10, 1]), TS requires grad: True
Grad x_test shape: torch.Size([10, 2]), Grad x_test vanishing: False

Thompson sampling, nn
TS shape: torch.Size([10, 1]), TS requires grad: True
Grad x_test shape: torch.Size([10, 2]), Grad x_test vanishing: False
```

As we can see, the gradient can be computed through Laplace's predictive and its non-vanishing.

## Monte-Carlo EI

In general, given a choice of utility function \(u(f(x))\), any acquisition function can be obtained w.r.t. the Laplace posterior.
For example, to compute Monte-Carlo-approximated EI, we can do so via:

```python
f_samples = la.functional_samples(x_test, pred_type=pred_type, n_samples=10)
ei = (f_samples - f_best.reshape(1, 1, 1)).clamp(0.0).mean(0)
```

Again, if \(u\) is differentiable, then we can obtain the gradient w.r.t. the input, and we can do continuous Bayesian optimization.
78 changes: 78 additions & 0 deletions examples/expectation_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch
import torch.utils.data as data_utils
from torch import autograd, nn, optim

from laplace import Laplace
from laplace.utils.enums import HessianStructure, Likelihood, PredType, SubsetOfWeights

torch.manual_seed(123)

model = nn.Sequential(nn.Linear(2, 10), nn.GELU(), nn.Linear(10, 1))
X, Y = torch.randn(5, 2), torch.randn(5, 1)
train_loader = data_utils.DataLoader(data_utils.TensorDataset(X, Y), batch_size=3)
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-4)
loss_fn = nn.MSELoss()

for epoch in range(10):
model.train()

for x, y in train_loader:
opt.zero_grad()
out = model(x)
loss = loss_fn(out, y)
loss.backward()
opt.step()

la = Laplace(
model,
Likelihood.REGRESSION,
subset_of_weights=SubsetOfWeights.ALL,
hessian_structure=HessianStructure.KRON,
enable_backprop=True,
)
la.fit(train_loader)
la.optimize_prior_precision(PredType.GLM)

# Thompson sampling
for pred_type in [PredType.GLM, PredType.NN]:
print(f"Thompson sampling, {pred_type}")

x_test = torch.randn(10, 2)
x_test.requires_grad = True

f_sample = la.functional_samples(x_test, pred_type=pred_type, n_samples=1)
f_sample = f_sample.squeeze(0) # We only use a single sample
print(f"TS shape: {f_sample.shape}, TS requires grad: {f_sample.requires_grad}")

# Get the gradient of the Thompson sample w.r.t. input x.
# Summed since it doesn't change the grad and autograd requires a scalar function.
grad_x = autograd.grad(f_sample.sum(), x_test)[0]

print(
f"Grad x_test shape: {grad_x.shape}, Grad x_test vanishing: {torch.allclose(grad_x, torch.tensor(0.0))}"
)
print()

print()

# Monte-Carlo expected improvement (EI): E_{f(x) ~ p(f(x) | D)} [max(f(x) - best_f, 0)]
f_best = torch.tensor(0.123) # Arbitrary in this example

for pred_type in [PredType.GLM, PredType.NN]:
print(f"MC-EI, {pred_type}")

x_test = torch.randn(10, 2)
x_test.requires_grad = True

f_samples = la.functional_samples(x_test, pred_type=pred_type, n_samples=10)
ei = (f_samples - f_best.reshape(1, 1, 1)).clamp(0.0).mean(0)
print(f"EI shape: {ei.shape}, EI requires grad: {ei.requires_grad}")

# Get the gradient of the EI w.r.t. input x.
# Summed since it doesn't change the grad and autograd requires a scalar function.
grad_x = autograd.grad(ei.sum(), x_test)[0]

print(
f"Grad x_test shape: {grad_x.shape}, Grad x_test vanishing: {torch.allclose(grad_x, torch.tensor(0.0))}"
)
print()
56 changes: 56 additions & 0 deletions laplace/baselaplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,62 @@ def _glm_forward_call(
"Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
)

def sample(
self, n_samples: int = 1, generator: torch.Generator | None = None
) -> torch.Tensor:
"""Sample from the Laplace posterior approximation, i.e.,
\\( \\theta \\sim \\mathcal{N}(\\theta_{MAP}, P^{-1})\\).

Parameters
----------
n_samples : int, default=100
number of samples

generator : torch.Generator, optional
random number generator to control the samples

Returns
-------
samples: torch.Tensor
"""
raise NotImplementedError

def functional_samples(
self,
x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
pred_type: PredType | str = PredType.GLM,
n_samples: int = 1,
diagonal_output: bool = False,
generator: torch.Generator | None = None,
) -> torch.Tensor:
"""Sample from the functional posterior on input data `x`.
Can be used, for example, for Thompson sampling.

Parameters
----------
x : torch.Tensor or MutableMapping
input data `(batch_size, input_shape)`

pred_type : {'glm'}, default='glm'
type of posterior predictive, linearized GLM predictive.

n_samples : int
number of samples

diagonal_output : bool
whether to use a diagonalized glm posterior predictive on the outputs.
Only applies when `pred_type='glm'`.

generator : torch.Generator, optional
random number generator to control the samples (if sampling used)

Returns
-------
samples : torch.Tensor
samples `(n_samples, batch_size, output_shape)`
"""
raise NotImplementedError

def _glm_functional_samples(
self,
f_mu: torch.Tensor,
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ nav:
- "Example: Regression": regression_example.md
- "Example: Calibration": calibration_example.md
- "Example: GP Inference": calibration_gp_example.md
- "Example: MC Acquistion Functions": expectation_example.md
- "Example: Huggingface LLMs": huggingface_example.md
- "Example: Reward Modeling": reward_modeling_example.md
- API Reference:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "laplace-torch"
version = "0.2.2.2"
version = "0.2.3"
description = "laplace - Laplace approximations for deep learning"
readme = "README.md"
authors = [
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.