Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a handler for managing PyroSample plate contexts #3386

Draft
wants to merge 7 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyro/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
PyroModuleList,
PyroParam,
PyroSample,
PyroSamplePlateScope,
pyro_method,
)

Expand All @@ -28,4 +29,5 @@
"PyroSample",
"pyro_method",
"PyroModuleList",
"PyroSamplePlateScope",
]
56 changes: 53 additions & 3 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _copy_to_script_wrapper(fn):
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
NamedTuple,
Expand All @@ -53,8 +54,10 @@ def _copy_to_script_wrapper(fn):

import pyro
import pyro.params.param_store
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.ops.provenance import detach_provenance
from pyro.poutine.runtime import _PYRO_PARAM_STORE
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import _PYRO_PARAM_STORE, InferDict

_MODULE_LOCAL_PARAMS: bool = False

Expand All @@ -63,7 +66,6 @@ def _copy_to_script_wrapper(fn):
_PyroModule = TypeVar("_PyroModule", bound="PyroModule")

if TYPE_CHECKING:
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.params.param_store import StateDict


Expand Down Expand Up @@ -232,6 +234,48 @@ def __get__(
return value


class _PyroSampleInferDict(InferDict):
    # Extension of InferDict used to tag sample sites created by PyroSample:
    # the original prior distribution is stashed under this key so that
    # PyroSamplePlateScope._pyro_sample can recognize the site and re-execute
    # it with local plates blocked.
    _original_pyrosample_dist: TorchDistributionMixin


class PyroSamplePlateScope(Messenger):
    """
    Handler for executing PyroSample statements in a more intuitive plate context.

    Any plate entered *inside* this handler's context (and not named in
    ``allowed_plates``) is blocked while a ``PyroSample`` site's original
    distribution is re-executed, so module priors are sampled outside local
    data plates.

    :param allowed_plates: names of plates that are permitted to affect
        ``PyroSample`` sites even though they are entered inside this
        handler's context.
    """

    def __init__(self, allowed_plates: Iterable[str] = ()):
        # Initialize the Messenger base so handler bookkeeping (decorator
        # support, context management) is set up correctly.
        super().__init__()
        self._inner_allowed_plates = frozenset(allowed_plates)

    def __enter__(self):
        # Snapshot the plates that already exist when the handler is entered;
        # these, plus the explicitly allowed names, are treated as global and
        # are never blocked.
        self._plates = (
            frozenset(p.name for p in pyro.poutine.runtime.get_plates())
            | self._inner_allowed_plates
        )
        return super().__enter__()

    def _is_local_plate(self, m: Messenger) -> bool:
        # A plate is "local" if it was entered inside this handler's context
        # (i.e. not in the snapshot taken at __enter__) and was not allowed.
        return (
            isinstance(m, pyro.poutine.plate_messenger.PlateMessenger)
            and m.name not in self._plates
        )

    def _pyro_sample(self, msg) -> None:
        # Only intercept sites created by PyroSample, which PyroModule tags
        # with the original distribution; explicit `is None` avoids relying
        # on distribution truthiness.
        if msg["infer"].get("_original_pyrosample_dist") is None:
            return
        msg["stop"] = True
        msg["done"] = True
        # Re-run the sample statement with this handler and every local plate
        # blocked, using the original (untagged) PyroSample distribution.
        with pyro.poutine.messenger.block_messengers(
            lambda m: m is self or self._is_local_plate(m)
        ):
            d = msg["infer"].pop("_original_pyrosample_dist")
            msg["value"] = pyro.sample(
                msg["name"],
                d,
                obs=msg["value"] if msg["is_observed"] else None,
                infer=msg["infer"],
            )


def _make_name(prefix: str, name: str) -> str:
return "{}.{}".format(prefix, name) if prefix else name

Expand Down Expand Up @@ -615,7 +659,13 @@ def __getattr__(self, name: str) -> Any:
value = (
pyro.deterministic(fullname, prior)
if isinstance(prior, torch.Tensor)
else pyro.sample(fullname, prior)
else pyro.sample(
fullname,
prior,
infer=_PyroSampleInferDict(
_original_pyrosample_dist=prior
),
)
)
context.set(fullname, value)
return value
Expand Down
43 changes: 43 additions & 0 deletions tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,3 +1084,46 @@ def forward(self):
with pyro.settings.context(module_local_params=use_module_local_params):
model = Model()
pyro.render_model(model)


def test_pyrosample_platescope():
    """Priors under PyroSamplePlateScope are sampled outside local plates."""

    class Model(pyro.nn.PyroModule):
        def __init__(self, num_inputs, num_outputs):
            super().__init__()
            self.num_inputs = num_inputs
            self.num_outputs = num_outputs
            self.linear = pyro.nn.PyroModule[torch.nn.Linear](num_inputs, num_outputs)
            weight_prior = (
                dist.Normal(0, 1).expand([num_outputs, num_inputs]).to_event(2)
            )
            bias_prior = dist.Normal(0, 1).expand([num_outputs]).to_event(1)
            self.linear.weight = pyro.nn.PyroSample(weight_prior)
            self.linear.bias = pyro.nn.PyroSample(bias_prior)

        @pyro.nn.PyroSample
        def scale(self):
            base = pyro.distributions.LogNormal(0, 1)
            return base.expand([self.num_outputs]).to_event(1)

        @pyro.nn.PyroSamplePlateScope()
        def forward(self, x):
            with pyro.plate("data", x.shape[-2], dim=-1):
                weight_shape = self.linear.weight.shape
                # weight prior is sampled outside the data plate
                assert len(weight_shape) == 2 or weight_shape[-3] != 1
                loc = self.linear(x)
                scale_shape = self.scale.shape
                # scale prior is sampled outside the data plate
                assert len(scale_shape) == 1 or scale_shape[-2] == 1
                y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1))
                # ordinary pyro.sample statement is batched over the plate
                assert y.shape[-2] == x.shape[-2]
                return y

    model = Model(3, 2)
    x = torch.randn(4, 3)
    assert model(x).shape == (4, 2)
Loading