diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 02b77feb17..4259ae8bf4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -121,6 +121,7 @@ jobs: tests/logprob/test_order.py tests/logprob/test_rewriting.py tests/logprob/test_scan.py + tests/logprob/test_set_subtensor.py tests/logprob/test_tensor.py tests/logprob/test_transform_value.py tests/logprob/test_transforms.py diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py index 6b4911ae62..9774e43f0a 100644 --- a/pymc/logprob/__init__.py +++ b/pymc/logprob/__init__.py @@ -53,6 +53,7 @@ import pymc.logprob.mixture import pymc.logprob.order import pymc.logprob.scan +import pymc.logprob.set_subtensor import pymc.logprob.tensor import pymc.logprob.transforms diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index c8c21ef61c..14fe82f0d6 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -92,7 +92,7 @@ def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): (value,) = values # transfer assertion from rv to value assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value}) - value = op(value, *assertions) + value = CheckAndRaise(**op._props_dict())(value, *assertions) return _logprob_helper(inner_rv, value) diff --git a/pymc/logprob/set_subtensor.py b/pymc/logprob/set_subtensor.py new file mode 100644 index 0000000000..472a266edc --- /dev/null +++ b/pymc/logprob/set_subtensor.py @@ -0,0 +1,215 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytensor.graph.basic import Variable
+from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.tensor import eq
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedIncSubtensor1,
+    IncSubtensor,
+    indices_from_subtensor,
+)
+from pytensor.tensor.type import TensorType
+from pytensor.tensor.type_other import NoneTypeT
+
+from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
+from pymc.logprob.checks import MeasurableCheckAndRaise
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.utils import (
+    check_potential_measurability,
+    dirac_delta,
+    filter_measurable_variables,
+)
+
+
+class MeasurableSetSubtensor(IncSubtensor, MeasurableOp):
+    """Measurable SetSubtensor Op."""
+
+    def __str__(self):
+        return f"Measurable{super().__str__()}"
+
+
+class MeasurableAdvancedSetSubtensor(AdvancedIncSubtensor, MeasurableOp):
+    """Measurable AdvancedSetSubtensor Op."""
+
+    def __str__(self):
+        return f"Measurable{super().__str__()}"
+
+
+set_subtensor_does_not_broadcast = MeasurableCheckAndRaise(
+    exc_type=NotImplementedError,
+    msg="Measurable SetSubtensor not supported when set value is broadcasted.",
+)
+
+
+@node_rewriter(tracks=[IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor])
+def find_measurable_set_subtensor(fgraph, node) -> list | None:
+    """Find `SetSubtensor` for which a `logprob` can be computed."""
+    if isinstance(node.op, MeasurableOp):
+        return None
+
+    if not node.op.set_instead_of_inc:
+        return None
+
+    x, y, *idx_elements = node.inputs
+
+    measurable_inputs = filter_measurable_variables([x, y])
+
+    if y not in measurable_inputs:
+        return None
+
+    if x not in measurable_inputs:
+        # x is potentially measurable, wait for its logprob IR to be inferred
+        if check_potential_measurability([x]):
+            return None
+        # x has no link to measurable variables, so its value should be constant
+        else:
+            x = dirac_delta(x, rtol=0, atol=0)
+
+    if check_potential_measurability(idx_elements):
+        return None
+
+    measurable_class: type[MeasurableSetSubtensor | MeasurableAdvancedSetSubtensor]
+    if isinstance(node.op, IncSubtensor):
+        measurable_class = MeasurableSetSubtensor
+        idx = indices_from_subtensor(idx_elements, node.op.idx_list)
+    else:
+        measurable_class = MeasurableAdvancedSetSubtensor
+        idx = tuple(idx_elements)
+
+    # Give up when the static shapes already prove that y is broadcasted into the indexed block.
+    indexed_block = x[idx]
+    missing_y_dims = indexed_block.type.ndim - y.type.ndim
+    y_bcast = [True] * missing_y_dims + list(y.type.broadcastable)
+    if any(
+        y_dim_bcast and indexed_block_dim_len not in (None, 1)
+        for y_dim_bcast, indexed_block_dim_len in zip(
+            y_bcast, indexed_block.type.shape, strict=True
+        )
+    ):
+        return None
+
+    measurable_set_subtensor = measurable_class(**node.op._props_dict())(x, y, *idx_elements)
+
+    # Often with indexing we don't know the static shape of the indexed block.
+    # And, what's more, the indexing operations actually support runtime broadcasting.
+    # As the logp is not valid under broadcasting, we have to add a runtime check.
+    # This will hopefully be removed during shape inference when not violated.
+    potential_broadcasted_dims = [
+        i
+        for i, (y_bcast_dim, indexed_block_dim_len) in enumerate(
+            zip(y_bcast, indexed_block.type.shape)
+        )
+        if y_bcast_dim and indexed_block_dim_len is None
+    ]
+    if potential_broadcasted_dims:
+        indexed_block_shape = tuple(indexed_block.shape)
+        measurable_set_subtensor = set_subtensor_does_not_broadcast(
+            measurable_set_subtensor,
+            *(eq(indexed_block_shape[i], 1) for i in potential_broadcasted_dims),
+        )
+
+    return [measurable_set_subtensor]
+
+
+measurable_ir_rewrites_db.register(
+    find_measurable_set_subtensor.__name__,
+    find_measurable_set_subtensor,
+    "basic",
+    "set_subtensor",
+)
+
+
+def indexed_dims(idx) -> list[int | None]:
+    """Return the indices of the dimensions of the indexed tensor that are being indexed."""
+    dims: list[int | None] = []
+    idx_counter = 0
+    for idx_elem in idx:
+        if isinstance(idx_elem, Variable) and isinstance(idx_elem.type, NoneTypeT):
+            # None in the index corresponds to newaxis, and doesn't map to any existing dimension
+            dims.append(None)
+
+        elif (
+            isinstance(idx_elem, Variable)
+            and isinstance(idx_elem.type, TensorType)
+            and idx_elem.type.dtype == "bool"
+        ):
+            # Boolean indexes map to as many dimensions as the mask has
+            for i in range(idx_elem.type.ndim):
+                dims.append(idx_counter)
+                idx_counter += 1
+        else:
+            dims.append(idx_counter)
+            idx_counter += 1
+
+    return dims
+
+
+@_logprob.register(MeasurableSetSubtensor)
+@_logprob.register(MeasurableAdvancedSetSubtensor)
+def logprob_setsubtensor(op, values, x, y, *idx_elements, **kwargs):
+    """Compute the log-likelihood graph for a `SetSubtensor`.
+
+    For a generative graph like:
+        o = zeros(2)
+        x = o[0].set(X)
+        y = x[1].set(Y)
+
+    The log-likelihood graph is:
+        logp(y, value) = (
+            logp(x, value)
+            [1].set(logp(y, value[1]))
+        )
+
+    Unrolling the logp(x, value) gives:
+        logp(y, value) = (
+            DiracDelta(zeros(2), value)  # Irrelevant if all entries are set
+            [0].set(logp(x, value[0]))
+            [1].set(logp(y, value[1]))
+        )
+    """
+    [value] = values
+    if isinstance(op, MeasurableSetSubtensor):
+        # For basic indexing we have to recreate the index from the input list
+        idx = indices_from_subtensor(idx_elements, op.idx_list)
+    else:
+        # For advanced indexing we can use the idx_elements directly
+        idx = tuple(idx_elements)
+
+    x_logp = _logprob_helper(x, value)
+    y_logp = _logprob_helper(y, value[idx])
+
+    y_ndim_supp = x[idx].type.ndim - y_logp.type.ndim
+    x_ndim_supp = x.type.ndim - x_logp.type.ndim
+    ndim_supp = max(y_ndim_supp, x_ndim_supp)
+    if ndim_supp > 0:
+        # Multivariate logp only valid if we are not doing indexing along the reduced dimensions
+        # Otherwise we don't know if successive writings are overlapping or not
+        core_dims = set(range(x.type.ndim)[-ndim_supp:])
+        if set(indexed_dims(idx)) & core_dims:
+            # When we have IR meta-info about support_ndim, we can fail at the rewriting stage
+            raise NotImplementedError(
+                "Indexing along core dimensions of multivariate SetSubtensor not supported"
+            )
+
+        ndim_supp_diff = y_ndim_supp - x_ndim_supp
+        if ndim_supp_diff > 0:
+            # In this case y_logp will have fewer dimensions than x_logp after indexing, so we need to reduce x before indexing.
+            x_logp = x_logp.sum(axis=tuple(range(-ndim_supp_diff, 0)))
+        elif ndim_supp_diff < 0:
+            # In this case x_logp will have fewer dimensions than y_logp after indexing, so we need to reduce y before indexing.
+            y_logp = y_logp.sum(axis=tuple(range(ndim_supp_diff, 0)))
+
+    out_logp = x_logp[idx].set(y_logp)
+    return out_logp
diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py
index 9865226e42..7cdfa6a03c 100644
--- a/pymc/logprob/utils.py
+++ b/pymc/logprob/utils.py
@@ -46,10 +46,9 @@
 from pytensor.graph.basic import Constant, Variable, clone_get_equiv, graph_inputs, walk
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
-from pytensor.link.c.type import CType
 from pytensor.raise_op import CheckAndRaise
 from pytensor.scalar.basic import Mul
-from pytensor.tensor.basic import get_underlying_scalar_constant_value
+from pytensor.tensor.basic import AllocEmpty, get_underlying_scalar_constant_value
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.random.op import RandomVariable
@@ -150,27 +149,6 @@ def expand(r):
     }
 
 
-def convert_indices(indices, entry):
-    if indices and isinstance(entry, CType):
-        rval = indices.pop(0)
-        return rval
-    elif isinstance(entry, slice):
-        return slice(
-            convert_indices(indices, entry.start),
-            convert_indices(indices, entry.stop),
-            convert_indices(indices, entry.step),
-        )
-    else:
-        return entry
-
-
-def indices_from_subtensor(idx_list, indices):
-    """Compute a useable index tuple from the inputs of a ``*Subtensor**`` ``Op``."""
-    return tuple(
-        tuple(convert_indices(list(indices), idx) for idx in idx_list) if idx_list else indices
-    )
-
-
 def filter_measurable_variables(inputs):
     return [
         inp for inp in inputs if (inp.owner is not None and isinstance(inp.owner.op, MeasurableOp))
@@ -266,7 +244,7 @@ class DiracDelta(MeasurableOp, Op):
 
     __props__ = ("rtol", "atol")
 
-    def __init__(self, rtol=1e-5, atol=1e-8):
+    def __init__(self, rtol, atol):
         self.rtol = rtol
         self.atol = atol
 
@@ -289,15 +267,25 @@ def infer_shape(self, fgraph, node, input_shapes):
         return input_shapes
 
 
-dirac_delta = DiracDelta()
+def dirac_delta(x, rtol=1e-5, atol=1e-8):
+    return DiracDelta(rtol, atol)(x)
 
 
 @_logprob.register(DiracDelta)
-def diracdelta_logprob(op, values, *inputs, **kwargs):
-    (values,) = values
-    (const_value,) = inputs
-    values, const_value = pt.broadcast_arrays(values, const_value)
-    return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)
+def diracdelta_logprob(op, values, const_value, **kwargs):
+    [value] = values
+
+    if const_value.owner and isinstance(const_value.owner.op, AllocEmpty):
+        # Any value is considered valid for an AllocEmpty array
+        return pt.zeros_like(value)
+
+    if op.rtol == 0 and op.atol == 0:
+        # Strict equality, cheaper logp
+        match = pt.eq(value, const_value)
+    else:
+        # Loose equality, more expensive logp
+        match = pt.isclose(value, const_value, rtol=op.rtol, atol=op.atol)
+    return pt.switch(match, np.array(0, dtype=value.dtype), -np.inf)
 
 
 def find_negated_var(var):
diff --git a/tests/logprob/test_set_subtensor.py b/tests/logprob/test_set_subtensor.py
new file mode 100644
index 0000000000..d3eedea3aa
--- /dev/null
+++ b/tests/logprob/test_set_subtensor.py
@@ -0,0 +1,158 @@
+# Copyright 2024 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytensor
+import pytensor.tensor as pt
+import pytest
+
+from pymc.distributions import Beta, Dirichlet, MvNormal, MvStudentT, Normal, StudentT
+from pymc.logprob.basic import logp
+
+
+@pytest.mark.parametrize("univariate", [True, False])
+def test_complete_set_subtensor(univariate):
+    if univariate:
+        rv0 = Normal.dist(mu=-10)
+        rv1 = StudentT.dist(nu=3, mu=0)
+        rv2 = Normal.dist(mu=10, sigma=3)
+        rv34 = Beta.dist(alpha=[np.pi, np.e], beta=[1, 1])
+        base = pt.empty((5,))
+        test_val = [2, 0, -2, 0.25, 0.5]
+    else:
+        rv0 = MvNormal.dist(mu=[-11, -9], cov=pt.eye(2))
+        rv1 = MvStudentT.dist(nu=3, mu=[-1, 1], cov=pt.eye(2))
+        rv2 = MvNormal.dist(mu=[9, 11], cov=pt.eye(2) * 3)
+        rv34 = Dirichlet.dist(a=[[np.pi, 1], [np.e, 1]])
+        base = pt.empty((3, 2))
+        test_val = [[2, 0], [0, -2], [-2, 2], [0.25, 0.75], [0.5, 0.5]]
+
+    # fmt: off
+    rv = (
+        # Boolean indexing
+        base[np.array([True, False, False, False, False])].set(rv0)
+        # Slice indexing
+        [1:2].set(rv1)
+        # Integer indexing
+        [2].set(rv2)
+        # Vector indexing
+        [[3, 4]].set(rv34)
+    )
+    # fmt: on
+    ref_rv = pt.join(0, [rv0], [rv1], [rv2], rv34)
+
+    np.testing.assert_allclose(
+        logp(rv, test_val).eval(),
+        logp(ref_rv, test_val).eval(),
+    )
+
+
+def test_partial_set_subtensor():
+    rv123 = Normal.dist(mu=[-10, 0, 10])
+
+    # When base is empty, it doesn't matter what the missing values are
+    base = pt.empty((5,))
+    rv = base[:3].set(rv123)
+
+    np.testing.assert_allclose(
+        logp(rv, [0, 0, 0, 1, np.pi]).eval(),
+        [*logp(rv123, [0, 0, 0]).eval(), 0, 0],
+    )
+
+    # Otherwise the entries that are not set must match the base values (logp is -inf if not)
+    base = pt.ones((5,))
+    rv = base[:3].set(rv123)
+
+    np.testing.assert_allclose(
+        logp(rv, [0, 0, 0, 1, np.pi]).eval(),
+        [*logp(rv123, [0, 0, 0]).eval(), 0, -np.inf],
+    )
+
+
+def test_overwrite_set_subtensor():
+    """Test that order of overwriting in the generative graph is respected."""
+    x = Normal.dist(mu=[0, 1, 2])
+    y = x[1:].set(Normal.dist([10, 20]))
+    z = y[2:].set(Normal.dist([300]))
+
+    np.testing.assert_allclose(
+        logp(z, [0, 0, 0]).eval(),
+        logp(Normal.dist([0, 10, 300]), [0, 0, 0]).eval(),
+    )
+
+
+def test_mixed_dimensionality_set_subtensor():
+    x = Normal.dist(mu=0, size=(3, 2))
+    y = x[1].set(MvNormal.dist(mu=[1, 1], cov=np.eye(2)))
+    z = y[2].set(Normal.dist(mu=2, size=(2,)))
+
+    # Because `y` is multivariate the last dimension of `z` must be summed over
+    test_val = np.zeros((3, 2))
+    logp_eval = logp(z, test_val).eval()
+    assert logp_eval.shape == (3,)
+    np.testing.assert_allclose(
+        logp_eval,
+        logp(Normal.dist(mu=[[0, 0], [1, 1], [2, 2]]), test_val).sum(-1).eval(),
+    )
+
+
+def test_invalid_indexing_core_dims():
+    x = pt.empty((2, 2))
+    rv = MvNormal.dist(cov=np.eye(2))
+    vv = x.type()
+
+    match_msg = "Indexing along core dimensions of multivariate SetSubtensor not supported"
+
+    y = x[[0, 1], [1, 0]].set(rv)
+    with pytest.raises(NotImplementedError, match=match_msg):
+        logp(y, vv)
+
+    y = x[np.array([[False, True], [True, False]])].set(rv)
+    with pytest.raises(NotImplementedError, match=match_msg):
+        logp(y, vv)
+
+    # Univariate indexing above multivariate core dims also not supported
+    z = y[0].set(rv)[0, 1].set(Normal.dist())
+    with pytest.raises(NotImplementedError, match=match_msg):
+        logp(z, vv)
+
+
+def test_invalid_broadcasted_set_subtensor():
+    rv_bcast = Normal.dist(mu=0)
+    base = pt.empty((5,))
+
+    rv = base[:3].set(rv_bcast)
+    vv = rv.type()
+
+    # Broadcasting is known at write time, and PyMC does not attempt to make SetSubtensor measurable
+    with pytest.raises(NotImplementedError):
+        logp(rv, vv)
+
+    mask = pt.tensor(shape=(5,), dtype=bool)
+    rv = base[mask].set(rv_bcast)
+
+    # Broadcasting is only known at runtime, and PyMC raises an error when it happens
+    logp_rv = logp(rv, vv)
+    fn = pytensor.function([mask, vv], logp_rv)
+    test_vv = np.zeros(5)
+
+    np.testing.assert_allclose(
+        fn([False, False, True, False, False], test_vv),
+        [0, 0, -0.91893853, 0, 0],
+    )
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Measurable SetSubtensor not supported when set value is broadcasted.",
+    ):
+        fn([False, False, True, False, True], test_vv)