From 28b24eb476c8dff3b4ca37b6f6c9393dc83747fe Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 6 Mar 2021 04:51:45 -0500 Subject: [PATCH 01/20] initial commit --- funsor/domains.py | 7 +++++++ funsor/ops/array.py | 6 ++++++ funsor/tensor.py | 3 +++ funsor/terms.py | 9 +++++++++ funsor/torch/ops.py | 15 +++++++++++++++ 5 files changed, 40 insertions(+) diff --git a/funsor/domains.py b/funsor/domains.py index 617e965c5..25c1b1ff0 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -353,6 +353,13 @@ def _transform_log_abs_det_jacobian(op, domain, codomain): return Real +@find_domain.register(ops.MeanOp) +@find_domain.register(ops.StdOp) +@find_domain.register(ops.VarOp) +def _find_domain_mean_std_var(op, domain): + return Array["real", ()] + + __all__ = [ "Bint", "BintType", diff --git a/funsor/ops/array.py b/funsor/ops/array.py index d31f95668..24bd5eff9 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -40,10 +40,13 @@ einsum = make_op("einsum") full_like = make_op(np.full_like) isnan = make_op(np.isnan) +mean = make_op(np.mean) prod = make_op(np.prod) stack = make_op("stack") +std = make_op(np.std) sum = make_op(np.sum) transpose = make_op("transpose") +var = make_op(np.var) sqrt.register(array)(np.sqrt) exp.register(array)(np.exp) @@ -362,6 +365,7 @@ def unsqueeze(x, dim): "isnan", "logaddexp", "logsumexp", + "mean", "new_arange", "new_eye", "new_full", @@ -372,10 +376,12 @@ def unsqueeze(x, dim): "scatter", "scatter_add", "stack", + "std", "sum", "transpose", "triangular_solve", "unsqueeze", + "var", ] declare_op_types(globals(), __all__, __name__) diff --git a/funsor/tensor.py b/funsor/tensor.py index 63f5ffb1e..6a3f10f54 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1225,6 +1225,9 @@ def stack(parts, dim=0): ops.sample: ops.logsumexp, ops.min: ops.amin, ops.max: ops.amax, + ops.mean: ops.mean, + ops.std: ops.std, + ops.var: ops.var, } diff --git a/funsor/terms.py b/funsor/terms.py index 9423e0266..be86fa761 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -577,6 +577,15 @@ def min(self): def max(self): return Unary(ops.max, self) + def mean(self): + return Unary(ops.mean, self) + + def std(self): + return Unary(ops.std, self) + + def var(self): + return Unary(ops.var, self) + def __add__(self, other): return Binary(ops.add, self, to_funsor(other)) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index f1a2fd11e..fa9d76193 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -172,6 +172,11 @@ def _max(x, y): return x.clamp(min=y) +@ops.mean.register(torch.Tensor, (int, type(None))) +def _mean(x, dim): + return x.mean() if dim is None else x.mean(dim) + + @ops.min.register(torch.Tensor, torch.Tensor) def _min(x, y): return torch.min(x, y) @@ -279,6 +284,11 @@ def _stack(dim, *x): return torch.stack(x, dim=dim) +@ops.std.register(torch.Tensor, (int, type(None))) +def _std(x, dim): + return x.std() if dim is None else x.std(dim) + + @ops.sum.register(torch.Tensor, (int, type(None))) def _sum(x, dim): return x.sum() if dim is None else x.sum(dim) @@ -287,3 +297,8 @@ def _sum(x, dim): @ops.triangular_solve.register(torch.Tensor, torch.Tensor) def _triangular_solve(x, y, upper=False, transpose=False): return x.triangular_solve(y, upper, transpose).solution + + +@ops.var.register(torch.Tensor, (int, type(None))) +def _var(x, dim): + return x.var() if dim is None else x.var(dim) From 8a6b8c7df560e5d6cc77e9a0d97f59eb772230f5 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 6 Mar 2021 15:03:27 -0500 Subject: [PATCH 02/20] jax 
implementation --- funsor/jax/ops.py | 15 +++++++++++++++ funsor/tensor.py | 9 +++++++++ test/test_terms.py | 23 +++++++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 12befbb21..01a9ed326 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -147,6 +147,11 @@ def _log(x): return np.log(x) +@ops.mean.register(array, (int, type(None))) +def _mean(x, dim): + return x.mean() if dim is None else x.mean(dim) + + @ops.logaddexp.register(array, array) def _safe_logaddexp_tensor_tensor(x, y): finfo = np.finfo(x.dtype) @@ -166,6 +171,11 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) +@ops.std.register(array, (int, type(None))) +def _std(x, dim): + return x.std() if dim is None else x.std(dim) + + @ops.logsumexp.register(array, (int, type(None))) def _logsumexp(x, dim): return logsumexp(x, axis=dim) @@ -324,3 +334,8 @@ def _triangular_solve(x, y, upper=False, transpose=False): permute_inv_dims += (sol.ndim - 1, prepend_ndim + y.ndim - 2) sol = np.transpose(sol, permute_inv_dims) return sol.reshape(batch_shape + (n, m)) + + +@ops.var.register(array, (int, type(None))) +def _var(x, dim): + return x.var() if dim is None else x.var(dim) diff --git a/funsor/tensor.py b/funsor/tensor.py index 6a3f10f54..d4b115392 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -302,6 +302,10 @@ def eager_unary(self, op): data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) data = REDUCE_OP_TO_NUMERIC[op](data, -1) return Tensor(data, self.inputs, dtype) + if op in OP_TO_NUMERIC: + batch_dim = len(self.data.shape) - len(self.output.shape) + data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) + return Tensor(op(data, -1), self.inputs, dtype) return Tensor(op(self.data), self.inputs, dtype) def eager_reduce(self, op, reduced_vars): @@ -1225,6 +1229,10 @@ def stack(parts, dim=0): ops.sample: ops.logsumexp, ops.min: ops.amin, ops.max: ops.amax, +} + + +OP_TO_NUMERIC = { ops.mean: ops.mean, ops.std: ops.std, ops.var: ops.var, @@ -1235,6 +1243,7 @@ def stack(parts, dim=0): "Einsum", "Function", "REDUCE_OP_TO_NUMERIC", + "OP_TO_NUMERIC", "Tensor", "align_tensor", "align_tensors", diff --git a/test/test_terms.py b/test/test_terms.py index e495e3e32..88ed47a7b 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -287,6 +287,29 @@ def test_unary(symbol, data): check_funsor(actual, {}, Array[dtype, ()], expected_data) +@pytest.mark.parametrize("event_shape", [(4,), (3, 2)], ids=str) +@pytest.mark.parametrize( + "name", + [ + "all", + "any", + "logsumexp", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + ], +) +def test_reduce_event(name, event_shape): + dtype = 2 if name in ("any", "all") else "real" + x = random_tensor(OrderedDict(i=Bint[5]), output=Array[dtype, event_shape]) + actual = getattr(x, name)() + check_funsor(actual, x.inputs, Array[dtype, ()]) + + BINARY_OPS = [ "+", "-", From fadac8c29df2a62ea12e0fe5ddb5e5292d9c1e50 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 6 Mar 2021 15:17:38 -0500 Subject: [PATCH 03/20] test_reduce_event --- test/test_terms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_terms.py b/test/test_terms.py index 88ed47a7b..9792ef1ee 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -291,6 +291,7 @@ def test_unary(symbol, data): @pytest.mark.parametrize( "name", [ + "var", "all", "any", "logsumexp", @@ -300,7 +301,6 @@ def test_unary(symbol, data): "prod", "std", "sum", - "var", ], ) def 
test_reduce_event(name, event_shape): From 6b24ef44f92b85b27efbfd4bb00410c22084cd42 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 6 Mar 2021 16:57:52 -0500 Subject: [PATCH 04/20] ddof --- funsor/jax/ops.py | 8 ++++---- funsor/ops/array.py | 14 ++++++++++++-- funsor/torch/ops.py | 21 +++++++++++++++++---- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 01a9ed326..5b44942bb 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -172,8 +172,8 @@ def _safe_logaddexp_tensor_number(x, y): @ops.std.register(array, (int, type(None))) -def _std(x, dim): - return x.std() if dim is None else x.std(dim) +def _std(x, dim, ddof=0): + return x.std(ddof=ddof) if dim is None else x.std(dim, ddof=ddof) @ops.logsumexp.register(array, (int, type(None))) @@ -337,5 +337,5 @@ def _triangular_solve(x, y, upper=False, transpose=False): @ops.var.register(array, (int, type(None))) -def _var(x, dim): - return x.var() if dim is None else x.var(dim) +def _var(x, dim, ddof=0): + return x.var(ddof=ddof) if dim is None else x.var(dim, ddof=ddof) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 24bd5eff9..be53966ed 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -43,10 +43,10 @@ mean = make_op(np.mean) prod = make_op(np.prod) stack = make_op("stack") -std = make_op(np.std) +std = make_op("std") sum = make_op(np.sum) transpose = make_op("transpose") -var = make_op(np.var) +var = make_op("var") sqrt.register(array)(np.sqrt) exp.register(array)(np.exp) @@ -321,6 +321,11 @@ def _stack(dim, *x): return np.stack(x, axis=dim) +@std.register(array, (int, type(None))) +def _std(x, dim, ddof=0): + return x.std(ddof=ddof) if dim is None else x.std(dim, ddof=ddof) + + @transpose.register(array, int, int) def _transpose(x, dim1, dim2): return np.swapaxes(x, dim1, dim2) @@ -338,6 +343,11 @@ def unsqueeze(x, dim): return np.expand_dims(x, axis=dim) +@var.register(array, (int, type(None))) +def _var(x, dim, ddof=0): + return x.var(ddof=ddof) if dim is None else x.var(dim, ddof=ddof) + + DISTRIBUTIVE_OPS.add((logaddexp, add)) DISTRIBUTIVE_OPS.add((sample, add)) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index fa9d76193..b32cd8d2a 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +import math import numbers import torch @@ -285,8 +286,14 @@ def _stack(dim, *x): @ops.std.register(torch.Tensor, (int, type(None))) -def _std(x, dim): - return x.std() if dim is None else x.std(dim) +def _std(x, dim, ddof=0): + if ddof == 0: + return x.std(unbiased=False) if dim is None else x.std(dim, unbiased=False) + if ddof == 1: + return x.std() if dim is None else x.std(dim) + N = x.numel() if dim is None else x.shape[dim] + correction = math.sqrt((N - 1) / (N - ddof)) + return x.std() * correction if dim is None else x.std(dim) * correction @ops.sum.register(torch.Tensor, (int, type(None))) @@ -300,5 +307,11 @@ def _triangular_solve(x, y, upper=False, transpose=False): @ops.var.register(torch.Tensor, (int, type(None))) -def _var(x, dim): - return x.var() if dim is None else x.var(dim) +def _var(x, dim, ddof=0): + if ddof == 0: + return x.var(unbiased=False) if dim is None else x.var(dim, unbiased=False) + if ddof == 1: + return x.var() if dim is None else x.var(dim) + N = x.numel() if dim is None else x.shape[dim] + correction = (N - 1) / (N - ddof) + return x.var() * correction if dim is None else x.var(dim) * correction From 10a85a60a271b31a6e77504e7bd4a8a9707d5c68 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 6 Mar 2021 17:06:09 -0500 Subject: [PATCH 05/20] NUMERIC_OPS --- funsor/tensor.py | 10 +++++----- test/test_terms.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index d4b115392..d53e6f41c 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1232,11 +1232,11 @@ def stack(parts, dim=0): } -OP_TO_NUMERIC = { - ops.mean: ops.mean, - ops.std: ops.std, - ops.var: ops.var, -} +NUMERIC_OPS = [ + ops.mean, + ops.std, + ops.var, +] __all__ = [ diff --git a/test/test_terms.py b/test/test_terms.py index 7db200866..61032f918 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -280,7 +280,6 @@ def test_unary(symbol, data): @pytest.mark.parametrize( "name", [ - "var", "all", "any", "logsumexp", @@ -290,6 +289,7 @@ def test_unary(symbol, data): "prod", "std", "sum", + "var", ], ) def test_reduce_event(name, event_shape): From b95ea83c906f99ecef7c458e2c12393fbcd37b83 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 6 Mar 2021 17:08:22 -0500 Subject: [PATCH 06/20] fix missed renaming --- funsor/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index d53e6f41c..95e1849db 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -302,7 +302,7 @@ def eager_unary(self, op): data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) data = REDUCE_OP_TO_NUMERIC[op](data, -1) return Tensor(data, self.inputs, dtype) - if op in OP_TO_NUMERIC: + if op in NUMERIC_OPS: batch_dim = len(self.data.shape) - len(self.output.shape) data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) return Tensor(op(data, -1), self.inputs, dtype) @@ -1243,7 +1243,7 @@ def stack(parts, dim=0): "Einsum", "Function", "REDUCE_OP_TO_NUMERIC", - "OP_TO_NUMERIC", + "NUMERIC_OPS", "Tensor", "align_tensor", "align_tensors", From f59f5e24371197ad13d08c2bd16d184c7fd2a3c2 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 7 Mar 2021 01:59:30 -0500 Subject: [PATCH 07/20] use CachedOpMeta --- funsor/domains.py | 11 +++++++- funsor/jax/ops.py | 18 ++++++------- funsor/ops/array.py | 64 ++++++++++++++++++++++++++++++++++++++++++--- funsor/tensor.py | 12 +++++++-- funsor/terms.py | 12 ++++----- funsor/torch/ops.py | 33 
+++++++++++------------ test/test_terms.py | 2 +- 7 files changed, 112 insertions(+), 40 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 25c1b1ff0..474b25999 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -357,7 +357,16 @@ def _transform_log_abs_det_jacobian(op, domain, codomain): @find_domain.register(ops.StdOp) @find_domain.register(ops.VarOp) def _find_domain_mean_std_var(op, domain): - return Array["real", ()] + event_dim = len(domain.shape) + if op.axis is None: + shape = () + elif isinstance(op.axis, int): + shape = tuple(domain[i] for i in range(event_dim) if i != op.axis) + elif isinstance(op.axis, tuple): + shape = tuple(domain[i] for i in range(event_dim) if i not in op.axis) + else: + raise ValueError + return Array["real", shape] __all__ = [ diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 1cf5accde..9f73c9592 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -147,9 +147,9 @@ def _log(x): return np.log(x) -@ops.mean.register(array, (int, type(None))) -def _mean(x, dim): - return x.mean() if dim is None else x.mean(dim) +@ops.mean.register(array, (tuple, int, type(None)), bool) +def _mean(x, axis, keepdims): + return x.mean(axis, keepdims=keepdims) @ops.logaddexp.register(array, array) @@ -171,9 +171,9 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.std.register(array, (int, type(None))) -def _std(x, dim, ddof=0): - return x.std(ddof=ddof) if dim is None else x.std(dim, ddof=ddof) +@ops.std.register(array, (tuple, int, type(None)), int, bool) +def _std(x, axis, ddof, keepdims): + return x.std(axis, ddof=ddof, keepdims=keepdims) @ops.logsumexp.register(array, (int, type(None))) @@ -333,6 +333,6 @@ def _triangular_solve(x, y, upper=False, transpose=False): return sol.reshape(batch_shape + (n, m)) -@ops.var.register(array, (int, type(None))) -def _var(x, dim, ddof=0): - return x.var(ddof=ddof) if dim is None else x.var(dim, ddof=ddof) +@ops.var.register(array, (tuple, int, type(None)), int, bool) +def _var(x, axis, ddof, keepdims): + return x.var(axis, ddof=ddof, keepdims=keepdims) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index be53966ed..3d2ae98e0 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -40,13 +40,10 @@ einsum = make_op("einsum") full_like = make_op(np.full_like) isnan = make_op(np.isnan) -mean = make_op(np.mean) prod = make_op(np.prod) stack = make_op("stack") -std = make_op("std") sum = make_op(np.sum) transpose = make_op("transpose") -var = make_op("var") sqrt.register(array)(np.sqrt) exp.register(array)(np.exp) @@ -72,6 +69,67 @@ def _logaddexp(x, y): sample = make_op(_logaddexp, type(logaddexp), name="sample") +class MeanOpMeta(CachedOpMeta): + def __call__(cls, axis=None, keepdims=False): + return super().__call__(axis, keepdims) + + +class MeanOp(Op, metaclass=MeanOpMeta): + def __init__(self, axis, keepdims): + self.axis = axis + self.keepdims = keepdims + super().__init__(self._default) + + def _reduce(self): + return MeanOp, (self.axis, self.keepdims) + + def _default(self, x): + return x.mean(self.axis, keepdims=self.keepdims) + + +class StdOpMeta(CachedOpMeta): + def __call__(cls, axis=None, ddof=0, keepdims=False): + return super().__call__(axis, ddof, keepdims) + + +class StdOp(Op, metaclass=StdOpMeta): + def __init__(self, axis, ddof, keepdims): + self.axis = axis + self.ddof = ddof + self.keepdims = keepdims + super().__init__(self._default) + + def _reduce(self): + return StdOp, (self.axis, self.ddof, self.keepdims) + + def 
_default(self, x): + return x.std(self.axis, ddof=self.ddof, keepdims=self.keepdims) + + +class VarOpMeta(CachedOpMeta): + def __call__(cls, axis=None, ddof=0, keepdims=False): + return super().__call__(axis, ddof, keepdims) + + +class VarOp(Op, metaclass=VarOpMeta): + def __init__(self, axis, ddof, keepdims): + self.axis = axis + self.ddof = ddof + self.keepdims = keepdims + super().__init__(self._default) + + def _reduce(self): + return VarOp, (self.axis, self.ddof, self.keepdims) + + def _default(self, x): + return x.var(self.axis, ddof=self.ddof, keepdims=self.keepdims) + + +mean = MeanOp() +std = StdOp() +var = VarOp() + + class ReshapeMeta(CachedOpMeta): def __call__(cls, shape): shape = tuple(shape) # necessary to convert torch.Size to tuple diff --git a/funsor/tensor.py b/funsor/tensor.py index 95e1849db..8676b6ef3 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -304,8 +304,16 @@ def eager_unary(self, op): return Tensor(data, self.inputs, dtype) if op in NUMERIC_OPS: batch_dim = len(self.data.shape) - len(self.output.shape) - data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) - return Tensor(op(data, -1), self.inputs, dtype) + event_dim = len(self.output.shape) + if op.axis is None: + op.axis = tuple(batch_dim + i for i in range(event_dim)) + elif isinstance(op.axis, int): + op.axis = batch_dim + op.axis % event_dim + elif isinstance(op.axis, tuple): + op.axis = tuple(batch_dim + i % event_dim for i in op.axis) + else: + raise ValueError + return Tensor(op(self.data), self.inputs, dtype) return Tensor(op(self.data), self.inputs, dtype) def eager_reduce(self, op, reduced_vars): diff --git a/funsor/terms.py b/funsor/terms.py index be86fa761..5bf8cb04b 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -577,14 +577,14 @@ def min(self): def max(self): return Unary(ops.max, self) - def mean(self): - return Unary(ops.mean, self) + def mean(self, axis=None, keepdims=False): + return Unary(ops.MeanOp(axis, keepdims), self) - def std(self): - return Unary(ops.std, self) + def std(self, axis=None, ddof=0, keepdims=False): + return Unary(ops.StdOp(axis, ddof, keepdims), self) - def var(self): - return Unary(ops.var, self) + def var(self, axis=None, ddof=0, keepdims=False): + return Unary(ops.VarOp(axis, ddof, keepdims), self) def __add__(self, other): return Binary(ops.add, self, to_funsor(other)) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index b32cd8d2a..4babc30cb 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -1,7 +1,6 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -import math import numbers import torch @@ -173,9 +172,9 @@ def _max(x, y): return x.clamp(min=y) -@ops.mean.register(torch.Tensor, (int, type(None))) -def _mean(x, dim): - return x.mean() if dim is None else x.mean(dim) +@ops.mean.register(torch.Tensor, (tuple, int, type(None)), bool) +def _mean(x, dim, keepdim): + return x.flatten().mean(dim, keepdim=keepdim) @ops.min.register(torch.Tensor, torch.Tensor) @@ -285,15 +284,14 @@ def _stack(dim, *x): return torch.stack(x, dim=dim) -@ops.std.register(torch.Tensor, (int, type(None))) -def _std(x, dim, ddof=0): +@ops.std.register(torch.Tensor, (tuple, int, type(None)), int, bool) +def _std(x, dim, ddof, keepdim): + dim = tuple(x.shape) if dim is None else dim if ddof == 0: - return x.std(unbiased=False) if dim is None else x.std(dim, unbiased=False) + return x.std(dim, unbiased=False, keepdim=keepdim) if ddof == 1: - return x.std() if dim is None else x.std(dim) - N = x.numel() if dim is None else x.shape[dim] - correction = math.sqrt((N - 1) / (N - ddof)) - return x.std() * correction if dim is None else x.std(dim) * correction + return x.std(dim, keepdim=keepdim) + raise NotImplementedError @ops.sum.register(torch.Tensor, (int, type(None))) @@ -306,12 +304,11 @@ def _triangular_solve(x, y, upper=False, transpose=False): return x.triangular_solve(y, upper, transpose).solution -@ops.var.register(torch.Tensor, (int, type(None))) -def _var(x, dim, ddof=0): +@ops.var.register(torch.Tensor, (tuple, int, type(None)), int, bool) +def _var(x, dim, ddof, keepdim): + dim = tuple(x.shape) if dim is None else dim if ddof == 0: - return x.var(unbiased=False) if dim is None else x.var(dim, unbiased=False) + return x.var(dim, unbiased=False, keepdim=keepdim) if ddof == 1: - return x.var() if dim is None else x.var(dim) - N = x.numel() if dim is None else x.shape[dim] - correction = (N - 1) / (N - ddof) - return x.var() * correction if dim is None else x.var(dim) * correction + return x.var(dim, keepdim=keepdim) + raise NotImplementedError diff --git a/test/test_terms.py b/test/test_terms.py index 61032f918..d18f4ad6b 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -276,7 +276,7 @@ def test_unary(symbol, data): check_funsor(actual, {}, Array[dtype, ()], expected_data) -@pytest.mark.parametrize("event_shape", [(4,), (3, 2)], ids=str) +@pytest.mark.parametrize("event_shape", [(3, 2)], ids=str) @pytest.mark.parametrize( "name", [ From c999470ca3636a301becf598f9ae5005a9cae827 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 7 Mar 2021 14:44:54 -0500 Subject: [PATCH 08/20] remove obsolete code --- funsor/ops/array.py | 10 ---------- funsor/torch/ops.py | 4 +++- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 3d2ae98e0..6e7115725 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -379,11 +379,6 @@ def _stack(dim, *x): return np.stack(x, axis=dim) -@std.register(array, (int, type(None))) -def _std(x, dim, ddof=0): - return x.std(ddof=ddof) if dim is None else x.std(dim, ddof=ddof) - - @transpose.register(array, int, int) def _transpose(x, dim1, dim2): return np.swapaxes(x, dim1, dim2) @@ -401,11 +396,6 @@ def unsqueeze(x, dim): return np.expand_dims(x, axis=dim) -@var.register(array, (int, type(None))) -def _var(x, dim, ddof=0): - return x.var(ddof=ddof) if dim is None else x.var(dim, ddof=ddof) - - DISTRIBUTIVE_OPS.add((logaddexp, add)) DISTRIBUTIVE_OPS.add((sample, add)) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 
4babc30cb..f0bce2a06 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -172,8 +172,10 @@ def _max(x, y): return x.clamp(min=y) -@ops.mean.register(torch.Tensor, (tuple, int, type(None)), bool) +# @ops.mean.register(torch.Tensor, (tuple, int, type(None)), bool) +@ops.mean.register(torch.Tensor) def _mean(x, dim, keepdim): + breakpoint() return x.flatten().mean(dim, keepdim=keepdim) From 5085d25e7b4f7d22f6c0ee2721e710bf53b31e39 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 18 Mar 2021 16:09:23 -0400 Subject: [PATCH 09/20] ReductionOp; fix old definitions --- funsor/domains.py | 24 ++++-------------------- funsor/jax/ops.py | 18 +++++++++--------- funsor/ops/array.py | 22 +++++++++++++++++++--- funsor/ops/op.py | 8 ++++++++ funsor/tensor.py | 8 ++++---- funsor/terms.py | 16 ++++++++-------- funsor/torch/ops.py | 11 +++++------ 7 files changed, 57 insertions(+), 50 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 4208a236f..a5b5eb42d 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -248,7 +248,7 @@ def _find_domain_log_exp(op, domain): return Array["real", domain.shape] -@find_domain.register(ops.SumOp) +@find_domain.register(ops.ReductionOp) def _find_domain_sum(op, domain): # Canonicalize dim. dim = op.defaults.get("dim", None) @@ -261,10 +261,10 @@ def _find_domain_sum(op, domain): dims = {i % ndims for i in dim} # Compute shape. - if op.defaults.get("keepdims", False): - shape = tuple(1 if i in dims else size for i, size in enumerate(domain.shape)) + if op.defaults.get("keepdim", False): + shape = tuple(1 if i in dims else domain[i] for i in range(ndims)) else: - shape = tuple(size for i, size in enumerate(domain.shape) if i not in dims) + shape = tuple(domain[i] for i in range(ndims) if i not in dims) # Compute domain. 
if domain.dtype == "real": @@ -381,22 +381,6 @@ def _transform_log_abs_det_jacobian(op, domain, codomain): return Real -@find_domain.register(ops.MeanOp) -@find_domain.register(ops.StdOp) -@find_domain.register(ops.VarOp) -def _find_domain_mean_std_var(op, domain): - event_dim = len(domain.shape) - if op.axis is None: - shape = () - elif isinstance(op.axis, int): - shape = tuple(domain[i] for i in range(event_dim) if i != op.axis) - elif isinstance(op.axis, tuple): - shape = tuple(domain[i] for i in range(event_dim) if i not in op.axis) - else: - raise ValueError - return Array["real", shape] - - @find_domain.register(ops.StackOp) def _find_domain_stack(op, parts): shape = broadcast_shape(*(x.shape for x in parts)) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 600ea8a09..eb053883b 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -142,9 +142,9 @@ def _log(x): return np.log(x) -@ops.mean.register(array, (tuple, int, type(None)), bool) -def _mean(x, axis, keepdims): - return x.mean(axis, keepdims=keepdims) +@ops.mean.register(array) +def _mean(x, dim, keepdim): + return x.mean(dim, keepdims=keepdim) @ops.logaddexp.register(array, array) @@ -166,9 +166,9 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.std.register(array, (tuple, int, type(None)), int, bool) -def _std(x, axis, ddof, keepdims): - return x.std(axis, ddof=ddof, keepdims=keepdims) +@ops.std.register(array) +def _std(x, dim, ddof, keepdim): + return x.std(dim, ddof=ddof, keepdims=keepdim) @ops.logsumexp.register(array) @@ -326,6 +326,6 @@ def _triangular_solve(x, y, upper=False, transpose=False): return sol.reshape(batch_shape + (n, m)) -@ops.var.register(array, (tuple, int, type(None)), int, bool) -def _var(x, axis, ddof, keepdims): - return x.var(axis, ddof=ddof, keepdims=keepdims) +@ops.var.register(array) +def _var(x, dim, ddof, keepdim): + return x.var(dim, ddof=ddof, keepdims=keepdim) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 97faf4693..498bde7fa 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -30,6 +30,7 @@ FinitaryOp, Op, OpMeta, + ReductionOp, TernaryOp, UnaryOp, declare_op_types, @@ -69,9 +70,24 @@ def amin(x, dim=None, keepdims=False): return np.amax(x, dim, keepdims=keepdims) -@UnaryOp.make -def sum(x, dim=None, keepdims=False): - return np.sum(x, dim, keepdims=keepdims) +@ReductionOp.make +def sum(x, dim=None, keepdim=False): + return np.sum(x, dim, keepdims=keepdim) + + +@ReductionOp.make +def mean(x, dim=None, keepdim=False): + return np.mean(x, dim, keepdims=keepdim) + + +@ReductionOp.make +def std(x, dim=None, ddof=0, keepdim=False): + return np.std(x, dim, ddof=ddof, keepdims=keepdim) + + +@ReductionOp.make +def var(x, dim=None, ddof=0, keepdim=False): + return np.var(x, dim, ddof=ddof, keepdims=keepdim) @UnaryOp.make diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 34bfe20b7..3080b6ee5 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -282,6 +282,14 @@ def _list_to_tuple(cls, arg, *args, **kwargs): return op(arg) +class ReductionOp(UnaryOp): + """ + Here reduction operations are defined in a broad sense, not only + associative operations. This helps to unify find_domain logic. 
+ """ + pass + + class TransformOp(UnaryOp): def set_inv(self, fn): """ diff --git a/funsor/tensor.py b/funsor/tensor.py index 644959249..e1c74f31f 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -784,14 +784,14 @@ def eager_reshape_tensor(op, arg): return Tensor(data, arg.inputs, arg.dtype) -@eager.register(Unary, ops.SumOp, Tensor) -def eager_sum_tensor(op, arg): +@eager.register(Unary, ops.ReductionOp, Tensor) +def eager_reduction_tensor(op, arg): if not arg.inputs: return Tensor(op(arg.data), arg.inputs, arg.dtype) # Work around batch inputs. dim = op.defaults.get("dim", None) - keepdims = op.defaults.get("keepdims", False) + keepdim = op.defaults.get("keepdim", False) ndims = len(arg.output.shape) if dim is None: dim = tuple(range(-ndims, 0)) @@ -799,7 +799,7 @@ def eager_sum_tensor(op, arg): dim = dim % ndims - ndims else: dim = tuple(d % ndims - ndims for d in dim) - data = op(arg.data, dim, keepdims) + data = op(arg.data, dim, keepdim) return Tensor(data, arg.inputs, arg.dtype) diff --git a/funsor/terms.py b/funsor/terms.py index 3b833674d..7249cc4b5 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -559,8 +559,8 @@ def reshape(self, shape): # reduce over output shape while preserving all inputs. # To reduce over inputs, instead call .reduce(op, reduced_vars). - def sum(self): - return Unary(ops.add, self) + def sum(self, dim=None, keepdim=False): + return Unary(ops.SumOp(dim, keepdim=keepdim), self) def prod(self): return Unary(ops.mul, self) @@ -580,14 +580,14 @@ def min(self): def max(self): return Unary(ops.max, self) - def mean(self, axis=None, keepdims=False): - return Unary(ops.MeanOp(axis, keepdims), self) + def mean(self, dim=None, keepdim=False): + return Unary(ops.MeanOp(dim, keepdim), self) - def std(self, axis=None, ddof=0, keepdims=False): - return Unary(ops.StdOp(axis, ddof, keepdims), self) + def std(self, dim=None, ddof=0, keepdim=False): + return Unary(ops.StdOp(dim, ddof, keepdim), self) - def var(self, axis=None, ddof=0, keepdims=False): - return Unary(ops.VarOp(axis, ddof, keepdims), self) + def var(self, dim=None, ddof=0, keepdim=False): + return Unary(ops.VarOp(dim, ddof, keepdim), self) def __add__(self, other): return Binary(ops.add, self, to_funsor(other)) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index cca6ea99f..4ed299386 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -166,11 +166,10 @@ def _max(x, y): return x.clamp(min=y) -# @ops.mean.register(torch.Tensor, (tuple, int, type(None)), bool) @ops.mean.register(torch.Tensor) -def _mean(x, dim, keepdim): - breakpoint() - return x.flatten().mean(dim, keepdim=keepdim) +def _mean(x, dim=None, keepdim=False): + dim = tuple(x.shape) if dim is None else dim + return x.mean(dim, keepdim=keepdim) @ops.min.register(torch.Tensor, torch.Tensor) @@ -277,7 +276,7 @@ def _scatter_add(destin, indices, source): ops.stack.register(typing.Tuple[torch.Tensor, ...])(torch.stack) -@ops.std.register(torch.Tensor, (tuple, int, type(None)), int, bool) +@ops.std.register(torch.Tensor) def _std(x, dim, ddof, keepdim): dim = tuple(x.shape) if dim is None else dim if ddof == 0: @@ -302,7 +301,7 @@ def _triangular_solve(x, y, upper=False, transpose=False): return x.triangular_solve(y, upper, transpose).solution -@ops.var.register(torch.Tensor, (tuple, int, type(None)), int, bool) +@ops.var.register(torch.Tensor) def _var(x, dim, ddof, keepdim): dim = tuple(x.shape) if dim is None else dim if ddof == 0: From c91b758233b4b3ad873356a61ff0fc0df433cea7 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev 
Date: Thu, 18 Mar 2021 19:18:01 -0400 Subject: [PATCH 10/20] fix all reduction ops --- funsor/domains.py | 3 +- funsor/jax/ops.py | 103 ++++++++++++++++------------ funsor/ops/array.py | 84 +++++++++++++---------- funsor/ops/op.py | 1 + funsor/tensor.py | 29 -------- funsor/terms.py | 34 ++++++---- funsor/torch/ops.py | 160 ++++++++++++++++++++++++++++---------------- test/test_terms.py | 33 +++++++-- 8 files changed, 262 insertions(+), 185 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index a5b5eb42d..bec28ad40 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -249,7 +249,7 @@ def _find_domain_log_exp(op, domain): @find_domain.register(ops.ReductionOp) -def _find_domain_sum(op, domain): +def _find_domain_reduction(op, domain): # Canonicalize dim. dim = op.defaults.get("dim", None) ndims = len(domain.shape) @@ -272,6 +272,7 @@ def _find_domain_sum(op, domain): else: raise NotImplementedError("TODO") + breakpoint() return Array[dtype, shape] diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index eb053883b..d740d40e2 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -34,29 +34,78 @@ ops.unsqueeze.register(array)(np.expand_dims) +########################################### +# Reduction Ops +########################################### + + @ops.all.register(array) -def _all(x, dim): - return np.all(x, axis=dim) +def _all(x, dim, keepdim): + return np.all(x, dim, keepdims=keepdim) + + +@ops.any.register(array) +def _any(x, dim, keepdim): + return np.any(x, dim, keepdims=keepdim) + + +@ops.argmax.register(array) +def _argmax(x, dim, keepdim): + if keepdim: + return np.expand_dims(np.argmax(x, dim), dim) + else: + np.argmax(x, dim) + + +@ops.argmin.register(array) +def _argmin(x, dim, keepdim): + if keepdim: + return np.expand_dims(np.argmin(x, dim), dim) + else: + return np.argmin(x, dim) @ops.amax.register(array) -def _amax(x, dim, keepdims=False): - return np.amax(x, axis=dim, keepdims=keepdims) +def _amax(x, dim, keepdim): + return np.amax(x, dim, keepdims=keepdim) @ops.amin.register(array) -def _amin(x, dim, keepdims=False): - return np.amin(x, axis=dim, keepdims=keepdims) +def _amin(x, dim, keepdim): + return np.amax(x, dim, keepdims=keepdim) -@ops.argmax.register(array) -def _argmax(x, dim): - return np.argmax(x, dim) +@ops.sum.register(array) +def _sum(x, dim, keepdim): + return np.sum(x, dim, keepdims=keepdim) -@ops.any.register(array) -def _any(x, dim): - return np.any(x, axis=dim) +@ops.prod.register(array) +def _prod(x, dim, keepdim): + return np.prod(x, dim, keepdims=keepdim) + + +@ops.logsumexp.register(array) +def _logsumexp(x, dim, keepdim): + return logsumexp(x, dim, keepdims=keepdim) + + +@ops.mean.register(array) +def _mean(x, dim, keepdim): + return np.mean(x, dim, keepdims=keepdim) + + +@ops.std.register(array) +def _std(x, dim, ddof, keepdim): + return np.std(x, dim, ddof=ddof, keepdims=keepdim) + + +@ops.var.register(array) +def _var(x, dim, ddof, keepdim): + return np.var(x, dim, ddof=ddof, keepdims=keepdim) + + +########################################### @ops.astype.register(array) @@ -142,11 +191,6 @@ def _log(x): return np.log(x) -@ops.mean.register(array) -def _mean(x, dim, keepdim): - return x.mean(dim, keepdims=keepdim) - - @ops.logaddexp.register(array, array) def _safe_logaddexp_tensor_tensor(x, y): finfo = np.finfo(np.result_type(x)) @@ -166,16 +210,6 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.std.register(array) -def _std(x, dim, ddof, keepdim): - return x.std(dim, ddof=ddof, 
keepdims=keepdim) - - -@ops.logsumexp.register(array) -def _logsumexp(x, dim): - return logsumexp(x, axis=dim) - - ops.max.register(array, array)(np.maximum) ops.min.register(array, array)(np.minimum) @@ -226,11 +260,6 @@ def _new_zeros(x, shape): return onp.zeros(shape, dtype=np.result_type(x)) -@ops.prod.register(array) -def _prod(x, dim): - return np.prod(x, axis=dim) - - @ops.reciprocal.register(array) def _reciprocal(x): result = np.clip(np.reciprocal(x), a_max=np.finfo(np.result_type(x)).max) @@ -267,11 +296,6 @@ def _stack(parts, dim=0): return np.stack(parts, axis=dim) -@ops.sum.register(array) -def _sum(x, dim, keepdims): - return np.sum(x, dim, keepdims=keepdims) - - @ops.triangular_solve.register(array, array) def _triangular_solve(x, y, upper=False, transpose=False): assert np.ndim(x) >= 2 and np.ndim(y) >= 2 @@ -324,8 +348,3 @@ def _triangular_solve(x, y, upper=False, transpose=False): permute_inv_dims += (sol.ndim - 1, prepend_ndim + y.ndim - 2) sol = np.transpose(sol, permute_inv_dims) return sol.reshape(batch_shape + (n, m)) - - -@ops.var.register(array) -def _var(x, dim, ddof, keepdim): - return x.var(dim, ddof=ddof, keepdims=keepdim) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 498bde7fa..709ab6d2d 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -50,24 +50,45 @@ atanh.register(array)(np.arctanh) -@UnaryOp.make -def all(x, dim=None): - return np.all(x, dim) +########################################### +# Reduction Ops +########################################### -@UnaryOp.make -def any(x, dim=None): - return np.any(x, dim) +@ReductionOp.make +def all(x, dim=None, keepdim=False): + return np.all(x, dim, keepdims=keepdim) -@UnaryOp.make -def amax(x, dim=None, keepdims=False): - return np.amax(x, dim, keepdims=keepdims) +@ReductionOp.make +def any(x, dim=None, keepdim=False): + return np.any(x, dim, keepdims=keepdim) -@UnaryOp.make -def amin(x, dim=None, keepdims=False): - return np.amax(x, dim, keepdims=keepdims) +@ReductionOp.make +def argmax(x, dim=None, keepdim=False): + if keepdim: + return np.expand_dims(np.argmax(x, dim), dim) + else: + return np.argmax(x, dim) + + +@ReductionOp.make +def argmin(x, dim=None, keepdim=False): + if keepdim: + return np.expand_dims(np.argmin(x, dim), dim) + else: + return np.argmin(x, dim) + + +@ReductionOp.make +def amax(x, dim=None, keepdim=False): + return np.amax(x, dim, keepdims=keepdim) + + +@ReductionOp.make +def amin(x, dim=None, keepdim=False): + return np.amax(x, dim, keepdims=keepdim) @ReductionOp.make @@ -75,6 +96,21 @@ def sum(x, dim=None, keepdim=False): return np.sum(x, dim, keepdims=keepdim) +@ReductionOp.make +def prod(x, dim=None, keepdim=False): + return np.prod(x, dim, keepdims=keepdim) + + +@ReductionOp.make +def logsumexp(x, dim, keepdim=False): + amax = np.amax(x, axis=dim, keepdims=True) + # treat the case x = -inf + amax = np.where(np.isfinite(amax), amax, 0.0) + unnormalized_lse = log(np.sum(np.exp(x - amax), dim, keepdims=keepdim)) + amax = amax if keepdim else amax.squeeze(dim) + return unnormalized_lse + amax + + @ReductionOp.make def mean(x, dim=None, keepdim=False): return np.mean(x, dim, keepdims=keepdim) @@ -90,11 +126,6 @@ def var(x, dim=None, ddof=0, keepdim=False): return np.var(x, dim, ddof=ddof, keepdims=keepdim) -@UnaryOp.make -def prod(x, dim=None): - return np.prod(x, dim) - - @UnaryOp.make def isnan(x): return np.isnan(x) @@ -263,14 +294,6 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@UnaryOp.make -def logsumexp(x, dim): - 
amax = np.amax(x, axis=dim, keepdims=True) - # treat the case x = -inf - amax = np.where(np.isfinite(amax), amax, 0.0) - return log(np.sum(np.exp(x - amax), axis=dim)) + amax.squeeze(axis=dim) - - max.register(array, array)(np.maximum) min.register(array, array)(np.minimum) @@ -295,16 +318,6 @@ def _min(x, y): return np.clip(x, a_min=None, a_max=y) -@UnaryOp.make -def argmax(x, dim): - raise NotImplementedError - - -@argmax.register(array) -def _argmax(x, dim): - return np.argmax(x, dim) - - @UnaryOp.make def new_arange(x, start=None, stop=None, step=None): raise NotImplementedError @@ -429,6 +442,7 @@ def unsqueeze(x, dim): "amin", "any", "argmax", + "argmin", "astype", "cat", "cholesky", diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 3080b6ee5..047412a46 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -287,6 +287,7 @@ class ReductionOp(UnaryOp): Here reduction operations are defined in a broad sense, not only associative operations. This helps to unify find_domain logic. """ + pass diff --git a/funsor/tensor.py b/funsor/tensor.py index e1c74f31f..8b4e4a0c2 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -295,27 +295,6 @@ def eager_subs(self, subs): data = self.data[tuple(index)] return Tensor(data, inputs, self.dtype) - def eager_unary(self, op): - dtype = find_domain(op, self.output).dtype - if op in REDUCE_OP_TO_NUMERIC: - batch_dim = len(self.data.shape) - len(self.output.shape) - data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) - data = REDUCE_OP_TO_NUMERIC[op](data, -1) - return Tensor(data, self.inputs, dtype) - if op in NUMERIC_OPS: - batch_dim = len(self.data.shape) - len(self.output.shape) - event_dim = len(self.output.shape) - if op.axis is None: - op.axis = tuple(batch_dim + i for i in range(event_dim)) - elif isinstance(op.axis, int): - op.axis = batch_dim + op.axis % event_dim - elif isinstance(op.axis, tuple): - op.axis = tuple(batch_dim + i % event_dim for i in op.axis) - else: - raise ValueError - return Tensor(op(self.data), self.inputs, dtype) - return Tensor(op(self.data), self.inputs, dtype) - def eager_reduce(self, op, reduced_vars): if op in REDUCE_OP_TO_NUMERIC: numeric_op = REDUCE_OP_TO_NUMERIC[op] @@ -1227,18 +1206,10 @@ def tensordot(x, y, dims): } -NUMERIC_OPS = [ - ops.mean, - ops.std, - ops.var, -] - - __all__ = [ "Einsum", "Function", "REDUCE_OP_TO_NUMERIC", - "NUMERIC_OPS", "Tensor", "align_tensor", "align_tensors", diff --git a/funsor/terms.py b/funsor/terms.py index 7249cc4b5..8b8d2bf31 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -559,26 +559,32 @@ def reshape(self, shape): # reduce over output shape while preserving all inputs. # To reduce over inputs, instead call .reduce(op, reduced_vars). 
- def sum(self, dim=None, keepdim=False): - return Unary(ops.SumOp(dim, keepdim=keepdim), self) + def all(self, dim=None, keepdim=False): + return Unary(ops.AllOp(dim, keepdim), self) + + def any(self, dim=None, keepdim=False): + return Unary(ops.AnyOp(dim, keepdim), self) - def prod(self): - return Unary(ops.mul, self) + def argmax(self, dim=None, keepdim=False): + return Unary(ops.ArgmaxOp(dim, keepdim), self) - def logsumexp(self): - return Unary(ops.logaddexp, self) + def argmin(self, dim=None, keepdim=False): + return Unary(ops.ArgminOp(dim, keepdim), self) - def all(self): - return Unary(ops.and_, self) + def max(self, dim=None, keepdim=False): + return Unary(ops.AmaxOp(dim, keepdim), self) - def any(self): - return Unary(ops.or_, self) + def min(self, dim=None, keepdim=False): + return Unary(ops.AmaxOp(dim, keepdim), self) + + def sum(self, dim=None, keepdim=False): + return Unary(ops.SumOp(dim, keepdim), self) - def min(self): - return Unary(ops.min, self) + def prod(self, dim=None, keepdim=False): + return Unary(ops.ProdOp(dim, keepdim), self) - def max(self): - return Unary(ops.max, self) + def logsumexp(self, dim=None, keepdim=False): + return Unary(ops.LogsumexpOp(dim, keepdim), self) def mean(self, dim=None, keepdim=False): return Unary(ops.MeanOp(dim, keepdim), self) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 4ed299386..9a1252afb 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -3,6 +3,7 @@ import numbers import typing +from functools import partial, reduce import torch @@ -26,29 +27,118 @@ ops.unsqueeze.register(torch.Tensor)(torch.unsqueeze) +########################################### +# Reduction Ops +########################################### + + @ops.all.register(torch.Tensor) -def _all(x, dim): - return x.all() if dim is None else x.all(dim=dim) +def _all(x, dim, keepdim): + if dim is None and not keepdim: + return x.all() + + if isinstance(dim, int): + return x.all(dim, keepdim=keepdim) + + dim = tuple(range(x.dim())) if dim is None else dim + return reduce(partial(torch.all, keepdim=keepdim), dim, x) + + +@ops.any.register(torch.Tensor) +def _any(x, dim, keepdim): + if dim is None and not keepdim: + return x.any() + + if isinstance(dim, int): + return x.any(dim, keepdim=keepdim) + + dim = tuple(range(x.dim())) if dim is None else dim + return reduce(partial(torch.any, keepdim=keepdim), dim, x) + + +@ops.argmax.register(torch.Tensor) +def _argmax(x, dim, keepdim): + # FIXME find_domain + return x.argmax(dim, keepdim=keepdim) + + +@ops.argmin.register(torch.Tensor) +def _argmin(x, dim, keepdim): + # FIXME find_domain + return x.argmin(dim, keepdim=keepdim) @ops.amax.register(torch.Tensor) -def _amax(x, dim, keepdims=False): - return x.max() if dim is None else x.max(dim, keepdims)[0] +def _amax(x, dim, keepdim): + if dim is None and not keepdim: + return x.amax() + dim = tuple(range(x.dim())) if dim is None else dim + return x.amax(dim, keepdim=keepdim) @ops.amin.register(torch.Tensor) -def _amin(x, dim, keepdims=False): - return x.min() if dim is None else x.min(dim, keepdims)[0] +def _amin(x, dim, keepdim): + if dim is None and not keepdim: + return x.amin() + dim = tuple(range(x.dim())) if dim is None else dim + return x.amin(dim, keepdim=keepdim) -@ops.argmax.register(torch.Tensor) -def _argmax(x, dim): - return x.max(dim).indices +@ops.sum.register(torch.Tensor) +def _sum(x, dim, keepdim): + if dim is None and not keepdim: + return x.sum() + dim = tuple(range(x.dim())) if dim is None else dim + return x.sum(dim, keepdim=keepdim) 
-@ops.any.register(torch.Tensor) -def _any(x, dim): - return x.any() if dim is None else x.any(dim=dim) +@ops.prod.register(torch.Tensor) +def _prod(x, dim, keepdim): + if dim is None and not keepdim: + return x.prod() + + if isinstance(dim, int): + return x.prod(dim, keepdim=keepdim) + + dim = tuple(range(x.dim())) if dim is None else dim + return reduce(partial(torch.prod, keepdim=keepdim), dim, x) + + +@ops.logsumexp.register(torch.Tensor) +def _logsumexp(x, dim, keepdim): + dim = tuple(range(x.dim())) if dim is None else dim + return x.logsumexp(dim, keepdim=keepdim) + + +@ops.mean.register(torch.Tensor) +def _mean(x, dim, keepdim): + if dim is None and not keepdim: + return x.mean() + dim = tuple(range(x.dim())) if dim is None else dim + return x.mean(dim, keepdim=keepdim) + + +@ops.std.register(torch.Tensor) +def _std(x, dim, ddof, keepdim): + dim = tuple(range(x.dim())) if dim is None else dim + if ddof == 0: + return x.std(dim, unbiased=False, keepdim=keepdim) + if ddof == 1: + return x.std(dim, keepdim=keepdim) + raise NotImplementedError + + +@ops.var.register(torch.Tensor) +def _var(x, dim, ddof, keepdim): + dim = tuple(range(x.dim())) if dim is None else dim + if ddof == 0: + return x.var(dim, unbiased=False, keepdim=keepdim) + if ddof == 1: + return x.var(dim, keepdim=keepdim) + raise NotImplementedError + + +########################################### @ops.astype.register(torch.Tensor) @@ -146,11 +236,6 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.logsumexp.register(torch.Tensor) -def _logsumexp(x, dim): - return x.reshape(-1).logsumexp(0) if dim is None else x.logsumexp(dim) - - @ops.max.register(torch.Tensor, torch.Tensor) def _max(x, y): return torch.max(x, y) @@ -166,12 +251,6 @@ def _max(x, y): return x.clamp(min=y) -@ops.mean.register(torch.Tensor) -def _mean(x, dim=None, keepdim=False): - dim = tuple(x.shape) if dim is None else dim - return x.mean(dim, keepdim=keepdim) - - @ops.min.register(torch.Tensor, torch.Tensor) def _min(x, y): return torch.min(x, y) @@ -229,11 +308,6 @@ def _pow(x, y): return x ** y -@ops.prod.register(torch.Tensor) -def _prod(x, dim): - return x.prod() if dim is None else x.prod(dim=dim) - - @ops.reciprocal.register(torch.Tensor) def _reciprocal(x): result = x.reciprocal().clamp(max=torch.finfo(x.dtype).max) @@ -276,36 +350,6 @@ def _scatter_add(destin, indices, source): ops.stack.register(typing.Tuple[torch.Tensor, ...])(torch.stack) -@ops.std.register(torch.Tensor) -def _std(x, dim, ddof, keepdim): - dim = tuple(x.shape) if dim is None else dim - if ddof == 0: - return x.std(dim, unbiased=False, keepdim=keepdim) - if ddof == 1: - return x.std(dim, keepdim=keepdim) - raise NotImplementedError - - -@ops.sum.register(torch.Tensor) -def _sum(x, dim, keepdims): - if dim is None: - if keepdims: - dim = tuple(range(x.dim())) - return x.sum(dim, True) - return x.sum() - return x.sum(dim, keepdims) - - @ops.triangular_solve.register(torch.Tensor, torch.Tensor) def _triangular_solve(x, y, upper=False, transpose=False): return x.triangular_solve(y, upper, transpose).solution - - -@ops.var.register(torch.Tensor) -def _var(x, dim, ddof, keepdim): - dim = tuple(x.shape) if dim is None else dim - if ddof == 0: - return x.var(dim, unbiased=False, keepdim=keepdim) - if ddof == 1: - return x.var(dim, keepdim=keepdim) - raise NotImplementedError diff --git a/test/test_terms.py b/test/test_terms.py index d18f4ad6b..12444c647 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -277,26 +277,47 @@ def 
test_unary(symbol, data): @pytest.mark.parametrize("event_shape", [(3, 2)], ids=str) +@pytest.mark.parametrize("dim", [None, 0, (1,), (0, 1)], ids=str) +@pytest.mark.parametrize("keepdim", [False, True], ids=str) @pytest.mark.parametrize( "name", [ "all", "any", - "logsumexp", + # "argmax", + # "argmin", "max", - "mean", "min", + "sum", "prod", + "logsumexp", + "mean", "std", - "sum", "var", ], ) -def test_reduce_event(name, event_shape): +def test_reduce_event(name, event_shape, dim, keepdim): + if name in ("argmax", "argmin"): + if dim is None and keepdim: + pytest.xfail(reason="find_domain needs to be fixed") + elif isinstance(dim, tuple): + pytest.xfail(reason="argmax and argmin don't support tuple dim") + dtype = 2 if name in ("any", "all") else "real" x = random_tensor(OrderedDict(i=Bint[5]), output=Array[dtype, event_shape]) - actual = getattr(x, name)() - check_funsor(actual, x.inputs, Array[dtype, ()]) + actual = getattr(x, name)(dim=dim, keepdim=keepdim) + + # compute expected shape + dim = (0, 1) if dim is None else dim + dim = (dim,) if isinstance(dim, int) else dim + if keepdim: + shape = tuple( + 1 if i in dim else event_shape[i] for i in range(len(event_shape)) + ) + else: + shape = tuple(event_shape[i] for i in range(len(event_shape)) if i not in dim) + + check_funsor(actual, x.inputs, Array[dtype, shape]) BINARY_OPS = ["+", "-", "*", "/", "**", "==", "!=", "<", "<=", ">", ">=", "min", "max"] From 19f6147f18dfc2009b38ab2e3e38235bc2ea58a0 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 18 Mar 2021 19:23:27 -0400 Subject: [PATCH 11/20] revert eager_unary --- funsor/domains.py | 1 - funsor/tensor.py | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/funsor/domains.py b/funsor/domains.py index bec28ad40..484cce7ce 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -272,7 +272,6 @@ def _find_domain_reduction(op, domain): else: raise NotImplementedError("TODO") - breakpoint() return Array[dtype, shape] diff --git a/funsor/tensor.py b/funsor/tensor.py index 8b4e4a0c2..6e2583fd9 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -295,6 +295,9 @@ def eager_subs(self, subs): data = self.data[tuple(index)] return Tensor(data, inputs, self.dtype) + def eager_unary(self, op): + return Tensor(op(self.data), self.inputs, self.dtype) + def eager_reduce(self, op, reduced_vars): if op in REDUCE_OP_TO_NUMERIC: numeric_op = REDUCE_OP_TO_NUMERIC[op] From 3d0f691b225cb435d45562158ffd65c23d29c785 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 00:09:15 -0400 Subject: [PATCH 12/20] fix tests --- funsor/affine.py | 2 +- funsor/domains.py | 4 ++-- funsor/ops/array.py | 2 +- funsor/tensor.py | 6 +++++- funsor/terms.py | 2 +- funsor/torch/ops.py | 13 +++++++++++-- test/test_tensor.py | 30 +++++++++++++++--------------- 7 files changed, 36 insertions(+), 23 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index 5c6b31535..0751f84b2 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -61,7 +61,7 @@ def _(fn): @affine_inputs.register(Unary) def _(fn): - if fn.op in (ops.neg, ops.add) or isinstance(fn.op, ops.ReshapeOp): + if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp): return affine_inputs(fn.arg) return frozenset() diff --git a/funsor/domains.py b/funsor/domains.py index 484cce7ce..827cb82d2 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -262,9 +262,9 @@ def _find_domain_reduction(op, domain): # Compute shape. 
if op.defaults.get("keepdim", False): - shape = tuple(1 if i in dims else domain[i] for i in range(ndims)) + shape = tuple(1 if i in dims else domain.shape[i] for i in range(ndims)) else: - shape = tuple(domain[i] for i in range(ndims) if i not in dims) + shape = tuple(domain.shape[i] for i in range(ndims) if i not in dims) # Compute domain. if domain.dtype == "real": diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 709ab6d2d..0afc2adb4 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -88,7 +88,7 @@ def amax(x, dim=None, keepdim=False): @ReductionOp.make def amin(x, dim=None, keepdim=False): - return np.amax(x, dim, keepdims=keepdim) + return np.amin(x, dim, keepdims=keepdim) @ReductionOp.make diff --git a/funsor/tensor.py b/funsor/tensor.py index 6e2583fd9..84fb465af 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -768,9 +768,13 @@ def eager_reshape_tensor(op, arg): @eager.register(Unary, ops.ReductionOp, Tensor) def eager_reduction_tensor(op, arg): + if not arg.output.shape: + return arg + if not arg.inputs: return Tensor(op(arg.data), arg.inputs, arg.dtype) + dtype = find_domain(op, arg.output).dtype # Work around batch inputs. dim = op.defaults.get("dim", None) keepdim = op.defaults.get("keepdim", False) @@ -782,7 +786,7 @@ def eager_reduction_tensor(op, arg): else: dim = tuple(d % ndims - ndims for d in dim) data = op(arg.data, dim, keepdim) - return Tensor(data, arg.inputs, arg.dtype) + return Tensor(data, arg.inputs, dtype) @eager.register(Binary, GetitemOp, Tensor, Number) diff --git a/funsor/terms.py b/funsor/terms.py index 8b8d2bf31..d2f08df4d 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -575,7 +575,7 @@ def max(self, dim=None, keepdim=False): return Unary(ops.AmaxOp(dim, keepdim), self) def min(self, dim=None, keepdim=False): - return Unary(ops.AmaxOp(dim, keepdim), self) + return Unary(ops.AminOp(dim, keepdim), self) def sum(self, dim=None, keepdim=False): return Unary(ops.SumOp(dim, keepdim), self) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 9a1252afb..2f9dd01ca 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -100,8 +100,17 @@ def _prod(x, dim, keepdim): if isinstance(dim, int): return x.prod(dim, keepdim=keepdim) - dim = tuple(range(x.dim())) if dim is None else dim - return reduce(partial(torch.prod, keepdim=keepdim), dim, x) + # reduce over multiple dims. 
+ reduced_dim = ( + tuple(range(x.dim())) if dim is None else tuple(d % x.dim() for d in dim) + ) + nonreduced_dim = tuple(i for i in range(x.dim()) if i not in reduced_dim) + permutation = nonreduced_dim + reduced_dim + result = torch.prod(x.permute(permutation).flatten(-len(reduced_dim), -1), -1) + if keepdim: + shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) + result = result.view(shape) + return result @ops.logsumexp.register(torch.Tensor) diff --git a/test/test_tensor.py b/test/test_tensor.py index 521c63864..6d505224a 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1317,23 +1317,23 @@ def test_sum(event_shape): for dim in DIMS: assert_close(Tensor(ops.sum(data, dim)), ops.sum(Tensor(data), dim)) assert_close(Tensor(ops.sum(data, dim=dim)), ops.sum(Tensor(data), dim=dim)) - for keepdims in KEEPDIMS: + for keepdim in KEEPDIMS: assert_close( - Tensor(ops.sum(data, keepdims=keepdims)), - ops.sum(Tensor(data), keepdims=keepdims), + Tensor(ops.sum(data, keepdim=keepdim)), + ops.sum(Tensor(data), keepdim=keepdim), ) for dim in DIMS: assert_close( - Tensor(ops.sum(data, dim, keepdims)), - ops.sum(Tensor(data), dim, keepdims), + Tensor(ops.sum(data, dim, keepdim)), + ops.sum(Tensor(data), dim, keepdim), ) assert_close( - Tensor(ops.sum(data, dim, keepdims=keepdims)), - ops.sum(Tensor(data), dim, keepdims=keepdims), + Tensor(ops.sum(data, dim, keepdim=keepdim)), + ops.sum(Tensor(data), dim, keepdim=keepdim), ) assert_close( - Tensor(ops.sum(data, dim=dim, keepdims=keepdims)), - ops.sum(Tensor(data), dim=dim, keepdims=keepdims), + Tensor(ops.sum(data, dim=dim, keepdim=keepdim)), + ops.sum(Tensor(data), dim=dim, keepdim=keepdim), ) @@ -1345,14 +1345,14 @@ def test_sum_batch(batch_shape, event_shape): DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] - def raw_sum(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): + def raw_sum(x, dim=None, keepdim=False, batch_ndims=len(batch_shape)): if batch_ndims == 0: - return ops.sum(x, dim, keepdims) - return ops.stack([raw_sum(part, dim, keepdims, batch_ndims - 1) for part in x]) + return ops.sum(x, dim, keepdim) + return ops.stack([raw_sum(part, dim, keepdim, batch_ndims - 1) for part in x]) rtol = 1e-5 if get_backend() == "jax" else 1e-6 - for keepdims in KEEPDIMS: + for keepdim in KEEPDIMS: for dim in DIMS: - actual = ops.sum(Tensor(data, inputs), dim, keepdims) - expected = Tensor(raw_sum(data, dim, keepdims), inputs) + actual = ops.sum(Tensor(data, inputs), dim, keepdim) + expected = Tensor(raw_sum(data, dim, keepdim), inputs) assert_close(actual, expected, rtol=rtol) From 9e6aee94d49faf2766660e4bf9f5c1862256e7ba Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 01:51:22 -0400 Subject: [PATCH 13/20] fix more tests --- funsor/domains.py | 4 +++- funsor/einsum/numpy_log.py | 2 +- funsor/ops/array.py | 3 +++ funsor/ops/op.py | 2 +- funsor/tensor.py | 3 ++- funsor/torch/ops.py | 39 +++++++++++++++++++++++++------------- 6 files changed, 36 insertions(+), 17 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 827cb82d2..5abb7afb5 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -269,8 +269,10 @@ def _find_domain_reduction(op, domain): # Compute domain. 
if domain.dtype == "real": dtype = "real" + elif op.name in ("all", "any"): + dtype = domain.dtype else: - raise NotImplementedError("TODO") + raise NotImplementedError return Array[dtype, shape] diff --git a/funsor/einsum/numpy_log.py b/funsor/einsum/numpy_log.py index 33d9b8b11..b71997e54 100644 --- a/funsor/einsum/numpy_log.py +++ b/funsor/einsum/numpy_log.py @@ -27,7 +27,7 @@ def einsum(equation, *operands): shift = ops.detach(operand) for i, dim in enumerate(dims): if dim not in output: - shift = ops.amax(shift, i, keepdims=True) + shift = ops.amax(shift, i, keepdim=True) # avoid nan due to -inf - -inf shift = ops.clamp(shift, ops.finfo(shift).min, None) exp_operands.append(ops.exp(operand - shift)) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 0afc2adb4..37341fb8e 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -126,6 +126,9 @@ def var(x, dim=None, ddof=0, keepdim=False): return np.var(x, dim, ddof=ddof, keepdims=keepdim) +########################################### + + @UnaryOp.make def isnan(x): return np.isnan(x) diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 047412a46..833434374 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -284,7 +284,7 @@ def _list_to_tuple(cls, arg, *args, **kwargs): class ReductionOp(UnaryOp): """ - Here reduction operations are defined in a broad sense, not only + Reduction operations are defined in a broad sense, not only associative operations. This helps to unify find_domain logic. """ diff --git a/funsor/tensor.py b/funsor/tensor.py index 84fb465af..eaeae0efd 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -296,7 +296,8 @@ def eager_subs(self, subs): return Tensor(data, inputs, self.dtype) def eager_unary(self, op): - return Tensor(op(self.data), self.inputs, self.dtype) + dtype = find_domain(op, self.output).dtype + return Tensor(op(self.data), self.inputs, dtype) def eager_reduce(self, op, reduced_vars): if op in REDUCE_OP_TO_NUMERIC: diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 2f9dd01ca..14b17a887 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -3,7 +3,6 @@ import numbers import typing -from functools import partial, reduce import torch @@ -32,6 +31,17 @@ ########################################### +def _flatten_reduced_dims(x, dim): + # Canonicalize reduced dim. + reduced_dim = ( + tuple(range(x.dim())) if dim is None else tuple(d % x.dim() for d in dim) + ) + nonreduced_dim = tuple(i for i in range(x.dim()) if i not in reduced_dim) + # permute & flatten reduced dim. + permutation = nonreduced_dim + reduced_dim + return x.permute(permutation).flatten(-len(reduced_dim), -1), reduced_dim + + @ops.all.register(torch.Tensor) def _all(x, dim, keepdim): if dim is None and not keepdim: @@ -40,8 +50,12 @@ def _all(x, dim, keepdim): if isinstance(dim, int): return x.all(dim, keepdim=keepdim) - dim = tuple(range(x.dim())) if dim is None else dim - return reduce(partial(torch.all, keepdim=keepdim), dim, x) + # reduce over multiple dims. + x_flattened, reduced_dim = _flatten_reduced_dims(x, dim) + if keepdim: + shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) + return torch.all(x_flattened, -1).view(shape) + return torch.all(x_flattened, -1) @ops.any.register(torch.Tensor) @@ -52,8 +66,12 @@ def _any(x, dim, keepdim): if isinstance(dim, int): return x.any(dim, keepdim=keepdim) - dim = tuple(range(x.dim())) if dim is None else dim - return reduce(partial(torch.any, keepdim=keepdim), dim, x) + # reduce over multiple dims. 
+ x_flattened, reduced_dim = _flatten_reduced_dims(x, dim) + if keepdim: + shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) + return torch.any(x_flattened, -1).view(shape) + return torch.any(x_flattened, -1) @ops.argmax.register(torch.Tensor) @@ -101,16 +119,11 @@ def _prod(x, dim, keepdim): return x.prod(dim, keepdim=keepdim) # reduce over multiple dims. - reduced_dim = ( - tuple(range(x.dim())) if dim is None else tuple(d % x.dim() for d in dim) - ) - nonreduced_dim = tuple(i for i in range(x.dim()) if i not in reduced_dim) - permutation = nonreduced_dim + reduced_dim - result = torch.prod(x.permute(permutation).flatten(-len(reduced_dim), -1), -1) + x_flattened, reduced_dim = _flatten_reduced_dims(x, dim) if keepdim: shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) - result = result.view(shape) - return result + return torch.prod(x_flattened, -1).view(shape) + return torch.prod(x_flattened, -1) @ops.logsumexp.register(torch.Tensor) From b9d4faf5885e69122669e9d254b16cc7e0b326b4 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 02:01:55 -0400 Subject: [PATCH 14/20] raise xfail for argmax and argmin --- funsor/domains.py | 2 +- funsor/ops/op.py | 2 +- funsor/torch/ops.py | 8 ++++---- test/test_terms.py | 6 ++++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 5abb7afb5..c045c6694 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -272,7 +272,7 @@ def _find_domain_reduction(op, domain): elif op.name in ("all", "any"): dtype = domain.dtype else: - raise NotImplementedError + raise NotImplementedError("TODO") return Array[dtype, shape] diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 833434374..4c4b9b036 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -284,7 +284,7 @@ def _list_to_tuple(cls, arg, *args, **kwargs): class ReductionOp(UnaryOp): """ - Reduction operations are defined in a broad sense, not only + Reduction operations are defined in a broad sense - not only associative operations. This helps to unify find_domain logic. """ diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 14b17a887..98d7e4fa9 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -31,7 +31,7 @@ ########################################### -def _flatten_reduced_dims(x, dim): +def _flatten_reduced_dim(x, dim): # Canonicalize reduced dim. reduced_dim = ( tuple(range(x.dim())) if dim is None else tuple(d % x.dim() for d in dim) @@ -51,7 +51,7 @@ def _all(x, dim, keepdim): return x.all(dim, keepdim=keepdim) # reduce over multiple dims. - x_flattened, reduced_dim = _flatten_reduced_dims(x, dim) + x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) if keepdim: shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) return torch.all(x_flattened, -1).view(shape) @@ -67,7 +67,7 @@ def _any(x, dim, keepdim): return x.any(dim, keepdim=keepdim) # reduce over multiple dims. - x_flattened, reduced_dim = _flatten_reduced_dims(x, dim) + x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) if keepdim: shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) return torch.any(x_flattened, -1).view(shape) @@ -119,7 +119,7 @@ def _prod(x, dim, keepdim): return x.prod(dim, keepdim=keepdim) # reduce over multiple dims. 
- x_flattened, reduced_dim = _flatten_reduced_dims(x, dim) + x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) if keepdim: shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) return torch.prod(x_flattened, -1).view(shape) diff --git a/test/test_terms.py b/test/test_terms.py index 12444c647..e52db29e4 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -284,8 +284,8 @@ def test_unary(symbol, data): [ "all", "any", - # "argmax", - # "argmin", + "argmax", + "argmin", "max", "min", "sum", @@ -300,6 +300,8 @@ def test_reduce_event(name, event_shape, dim, keepdim): if name in ("argmax", "argmin"): if dim is None and keepdim: pytest.xfail(reason="find_domain needs to be fixed") + elif dim is None and not keepdim: + pytest.xfail(reason="eager_reduction_tensor converts None to tuple") elif isinstance(dim, tuple): pytest.xfail(reason="argmax and argmin don't support tuple dim") From 87e491510b09536848e2f325b5b943512812951b Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 02:19:45 -0400 Subject: [PATCH 15/20] fix argmax in jax/ops --- funsor/jax/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index d740d40e2..5a207d5f8 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -54,7 +54,7 @@ def _argmax(x, dim, keepdim): if keepdim: return np.expand_dims(np.argmax(x, dim), dim) else: - np.argmax(x, dim) + return np.argmax(x, dim) @ops.argmin.register(array) From 27304964a1bec06838f86733c6b7fd7818200b89 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 02:43:36 -0400 Subject: [PATCH 16/20] torch renaming --- funsor/jax/ops.py | 6 ++---- funsor/ops/array.py | 6 ++---- funsor/torch/ops.py | 42 +++++++++++++++++++++--------------------- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 5a207d5f8..512583cfc 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -53,16 +53,14 @@ def _any(x, dim, keepdim): def _argmax(x, dim, keepdim): if keepdim: return np.expand_dims(np.argmax(x, dim), dim) - else: - return np.argmax(x, dim) + return np.argmax(x, dim) @ops.argmin.register(array) def _argmin(x, dim, keepdim): if keepdim: return np.expand_dims(np.argmin(x, dim), dim) - else: - return np.argmin(x, dim) + return np.argmin(x, dim) @ops.amax.register(array) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 37341fb8e..8e1b1fdfa 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -69,16 +69,14 @@ def any(x, dim=None, keepdim=False): def argmax(x, dim=None, keepdim=False): if keepdim: return np.expand_dims(np.argmax(x, dim), dim) - else: - return np.argmax(x, dim) + return np.argmax(x, dim) @ReductionOp.make def argmin(x, dim=None, keepdim=False): if keepdim: return np.expand_dims(np.argmin(x, dim), dim) - else: - return np.argmin(x, dim) + return np.argmin(x, dim) @ReductionOp.make diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 98d7e4fa9..56d351b2b 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -45,10 +45,10 @@ def _flatten_reduced_dim(x, dim): @ops.all.register(torch.Tensor) def _all(x, dim, keepdim): if dim is None and not keepdim: - return x.all() + return torch.all(x) if isinstance(dim, int): - return x.all(dim, keepdim=keepdim) + return torch.all(x, dim, keepdim=keepdim) # reduce over multiple dims. 
x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) @@ -61,10 +61,10 @@ def _all(x, dim, keepdim): @ops.any.register(torch.Tensor) def _any(x, dim, keepdim): if dim is None and not keepdim: - return x.any() + return torch.any(x) if isinstance(dim, int): - return x.any(dim, keepdim=keepdim) + return torch.any(x, dim, keepdim=keepdim) # reduce over multiple dims. x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) @@ -77,46 +77,46 @@ def _any(x, dim, keepdim): @ops.argmax.register(torch.Tensor) def _argmax(x, dim, keepdim): # FIXME find_domain - return x.argmax(dim, keepdim=keepdim) + return torch.argmax(x, dim, keepdim=keepdim) @ops.argmin.register(torch.Tensor) def _argmin(x, dim, keepdim): # FIXME find_domain - return x.argmin(dim, keepdim=keepdim) + return torch.argmin(x, dim, keepdim=keepdim) @ops.amax.register(torch.Tensor) def _amax(x, dim, keepdim): if dim is None and not keepdim: - return x.amax() + return torch.amax(x) dim = tuple(range(x.dim())) if dim is None else dim - return x.amax(dim, keepdim=keepdim) + return torch.amax(x, dim, keepdim=keepdim) @ops.amin.register(torch.Tensor) def _amin(x, dim, keepdim): if dim is None and not keepdim: - return x.amin() + return torch.amin(x) dim = tuple(range(x.dim())) if dim is None else dim - return x.amin(dim, keepdim=keepdim) + return torch.amin(x, dim, keepdim=keepdim) @ops.sum.register(torch.Tensor) def _sum(x, dim, keepdim): if dim is None and not keepdim: - return x.sum() + return torch.sum(x) dim = tuple(range(x.dim())) if dim is None else dim - return x.sum(dim, keepdim=keepdim) + return torch.sum(x, dim, keepdim=keepdim) @ops.prod.register(torch.Tensor) def _prod(x, dim, keepdim): if dim is None and not keepdim: - return x.prod() + return torch.prod(x) if isinstance(dim, int): - return x.prod(dim, keepdim=keepdim) + return torch.prod(x, dim, keepdim=keepdim) # reduce over multiple dims. 
x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) @@ -129,24 +129,24 @@ def _prod(x, dim, keepdim): @ops.logsumexp.register(torch.Tensor) def _logsumexp(x, dim, keepdim): dim = tuple(range(x.dim())) if dim is None else dim - return x.logsumexp(dim, keepdim=keepdim) + return torch.logsumexp(x, dim, keepdim=keepdim) @ops.mean.register(torch.Tensor) def _mean(x, dim, keepdim): if dim is None and not keepdim: - return x.mean() + return torch.mean(x) dim = tuple(range(x.dim())) if dim is None else dim - return x.mean(dim, keepdim=keepdim) + return torch.mean(x, dim, keepdim=keepdim) @ops.std.register(torch.Tensor) def _std(x, dim, ddof, keepdim): dim = tuple(range(x.dim())) if dim is None else dim if ddof == 0: - return x.std(dim, unbiased=False, keepdim=keepdim) + return torch.std(x, dim, unbiased=False, keepdim=keepdim) if ddof == 1: - return x.std(dim, keepdim=keepdim) + return torch.std(x, dim, keepdim=keepdim) raise NotImplementedError @@ -154,9 +154,9 @@ def _std(x, dim, ddof, keepdim): def _var(x, dim, ddof, keepdim): dim = tuple(range(x.dim())) if dim is None else dim if ddof == 0: - return x.var(dim, unbiased=False, keepdim=keepdim) + return torch.var(x, dim, unbiased=False, keepdim=keepdim) if ddof == 1: - return x.var(dim, keepdim=keepdim) + return torch.var(x, dim, keepdim=keepdim) raise NotImplementedError From 0ab6b0791ee81780c47b81fd796660b23ab862d2 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 16:36:20 -0400 Subject: [PATCH 17/20] address comments --- funsor/ops/array.py | 2 +- funsor/tensor.py | 2 +- test/test_tensor.py | 133 ++++++++++++++++++++++++++++++++++++-------- test/test_terms.py | 10 ---- 4 files changed, 111 insertions(+), 36 deletions(-) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 8e1b1fdfa..f55c55be0 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -100,7 +100,7 @@ def prod(x, dim=None, keepdim=False): @ReductionOp.make -def logsumexp(x, dim, keepdim=False): +def logsumexp(x, dim=None, keepdim=False): amax = np.amax(x, axis=dim, keepdims=True) # treat the case x = -inf amax = np.where(np.isfinite(amax), amax, 0.0) diff --git a/funsor/tensor.py b/funsor/tensor.py index eaeae0efd..7bb3a6ea9 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -786,7 +786,7 @@ def eager_reduction_tensor(op, arg): dim = dim % ndims - ndims else: dim = tuple(d % ndims - ndims for d in dim) - data = op(arg.data, dim, keepdim) + data = op(arg.data, dim=dim, keepdim=keepdim) return Tensor(data, arg.inputs, dtype) diff --git a/test/test_tensor.py b/test/test_tensor.py index 6d505224a..489085c0d 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1307,52 +1307,137 @@ def test_scatter_pure_renaming(): assert ((actual - expected).abs() < 1e-4).data.all() +@pytest.mark.parametrize( + "op", + [ + ops.any, + ops.all, + ops.amin, + ops.amax, + ops.sum, + ops.logsumexp, + ops.prod, + ops.mean, + ], +) @pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str) -def test_sum(event_shape): +def test_reduction(op, event_shape): data = randn(*event_shape) DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] + op_name = op.name[1:] if op.name in {"amin", "amax"} else op.name + + expected = Tensor(op(data)) + assert_close(op(Tensor(data)), expected) + assert_close(getattr(Tensor(data), op_name)(), expected) - assert_close(Tensor(ops.sum(data)), ops.sum(Tensor(data))) for dim in DIMS: - assert_close(Tensor(ops.sum(data, dim)), ops.sum(Tensor(data), dim)) - assert_close(Tensor(ops.sum(data, dim=dim)), 
ops.sum(Tensor(data), dim=dim)) + expected = Tensor(op(data, dim)) + assert_close(op(Tensor(data), dim), expected) + assert_close(op(Tensor(data), dim=dim), expected) + assert_close(getattr(Tensor(data), op_name)(dim), expected) + assert_close(getattr(Tensor(data), op_name)(dim=dim), expected) + for keepdim in KEEPDIMS: - assert_close( - Tensor(ops.sum(data, keepdim=keepdim)), - ops.sum(Tensor(data), keepdim=keepdim), - ) + expected = Tensor(op(data, keepdim=keepdim)) + assert_close(op(Tensor(data), keepdim=keepdim), expected) + assert_close(getattr(Tensor(data), op_name)(keepdim=keepdim), expected) + for dim in DIMS: + expected = Tensor(op(data, dim, keepdim)) + assert_close(op(Tensor(data), dim, keepdim), expected) + assert_close(op(Tensor(data), dim, keepdim=keepdim), expected) + assert_close(op(Tensor(data), dim=dim, keepdim=keepdim), expected) + assert_close(getattr(Tensor(data), op_name)(dim, keepdim), expected) + assert_close(getattr(Tensor(data), op_name)(dim, keepdim=keepdim), expected) assert_close( - Tensor(ops.sum(data, dim, keepdim)), - ops.sum(Tensor(data), dim, keepdim), - ) - assert_close( - Tensor(ops.sum(data, dim, keepdim=keepdim)), - ops.sum(Tensor(data), dim, keepdim=keepdim), - ) - assert_close( - Tensor(ops.sum(data, dim=dim, keepdim=keepdim)), - ops.sum(Tensor(data), dim=dim, keepdim=keepdim), + getattr(Tensor(data), op_name)(dim=dim, keepdim=keepdim), expected ) +@pytest.mark.parametrize( + "op", + [ + ops.std, + ops.var, + ], +) +@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str) +def test_std_var(op, event_shape): + data = randn(*event_shape) + DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] + DDOFS = [0, 1] + KEEPDIMS = [False, True] + + expected = Tensor(op(data)) + assert_close(op(Tensor(data)), expected) + assert_close(getattr(Tensor(data), op.name)(), expected) + + for dim in DIMS: + expected = Tensor(op(data, dim)) + assert_close(op(Tensor(data), dim), expected) + assert_close(op(Tensor(data), dim=dim), expected) + assert_close(getattr(Tensor(data), op.name)(dim), expected) + assert_close(getattr(Tensor(data), op.name)(dim=dim), expected) + + for keepdim in KEEPDIMS: + expected = Tensor(op(data, keepdim=keepdim)) + assert_close(op(Tensor(data), keepdim=keepdim), expected) + assert_close(getattr(Tensor(data), op.name)(keepdim=keepdim), expected) + + for ddof in DDOFS: + for dim in DIMS: + expected = Tensor(op(data, dim, ddof, keepdim)) + assert_close(op(Tensor(data), dim, ddof, keepdim), expected) + assert_close(op(Tensor(data), dim, ddof, keepdim=keepdim), expected) + assert_close( + op(Tensor(data), dim=dim, ddof=ddof, keepdim=keepdim), expected + ) + assert_close( + getattr(Tensor(data), op.name)(dim, ddof, keepdim), expected + ) + assert_close( + getattr(Tensor(data), op.name)(dim, ddof, keepdim=keepdim), expected + ) + assert_close( + getattr(Tensor(data), op.name)(dim=dim, ddof=ddof, keepdim=keepdim), + expected, + ) + + +@pytest.mark.parametrize( + "op", + [ + ops.any, + ops.all, + ops.amin, + ops.amax, + ops.sum, + ops.logsumexp, + ops.prod, + ops.mean, + ops.std, + ops.var, + ], +) @pytest.mark.parametrize("batch_shape", [(), (5,)], ids=str) @pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str) -def test_sum_batch(batch_shape, event_shape): +def test_reduction_batch(op, batch_shape, event_shape): inputs = OrderedDict((k, Bint[s]) for k, s in zip("abc", batch_shape)) data = randn(*batch_shape, *event_shape) DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] - def raw_sum(x, dim=None, keepdim=False, 
batch_ndims=len(batch_shape)): + def raw_reduction(x, dim=None, keepdim=False, batch_ndims=len(batch_shape)): if batch_ndims == 0: - return ops.sum(x, dim, keepdim) - return ops.stack([raw_sum(part, dim, keepdim, batch_ndims - 1) for part in x]) + return op(x, dim, keepdim=keepdim) + return ops.stack( + [raw_reduction(part, dim, keepdim, batch_ndims - 1) for part in x] + ) rtol = 1e-5 if get_backend() == "jax" else 1e-6 for keepdim in KEEPDIMS: for dim in DIMS: - actual = ops.sum(Tensor(data, inputs), dim, keepdim) - expected = Tensor(raw_sum(data, dim, keepdim), inputs) + actual = op(Tensor(data, inputs), dim, keepdim=keepdim) + expected = Tensor(raw_reduction(data, dim, keepdim), inputs) assert_close(actual, expected, rtol=rtol) diff --git a/test/test_terms.py b/test/test_terms.py index e52db29e4..50b768918 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -284,8 +284,6 @@ def test_unary(symbol, data): [ "all", "any", - "argmax", - "argmin", "max", "min", "sum", @@ -297,14 +295,6 @@ def test_unary(symbol, data): ], ) def test_reduce_event(name, event_shape, dim, keepdim): - if name in ("argmax", "argmin"): - if dim is None and keepdim: - pytest.xfail(reason="find_domain needs to be fixed") - elif dim is None and not keepdim: - pytest.xfail(reason="eager_reduction_tensor converts None to tuple") - elif isinstance(dim, tuple): - pytest.xfail(reason="argmax and argmin don't support tuple dim") - dtype = 2 if name in ("any", "all") else "real" x = random_tensor(OrderedDict(i=Bint[5]), output=Array[dtype, event_shape]) actual = getattr(x, name)(dim=dim, keepdim=keepdim) From b99c201ae529a12961dfd04c269ef8c19752a475 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 16:53:57 -0400 Subject: [PATCH 18/20] fix (all,any) dtype --- funsor/domains.py | 6 +++--- funsor/tensor.py | 5 +++-- test/test_tensor.py | 13 ++++++++----- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index c045c6694..219fdfb40 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -267,10 +267,10 @@ def _find_domain_reduction(op, domain): shape = tuple(domain.shape[i] for i in range(ndims) if i not in dims) # Compute domain. - if domain.dtype == "real": + if op.name in ("all", "any"): + dtype = 2 + elif domain.dtype == "real": dtype = "real" - elif op.name in ("all", "any"): - dtype = domain.dtype else: raise NotImplementedError("TODO") diff --git a/funsor/tensor.py b/funsor/tensor.py index 7bb3a6ea9..f7a913866 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -769,13 +769,14 @@ def eager_reshape_tensor(op, arg): @eager.register(Unary, ops.ReductionOp, Tensor) def eager_reduction_tensor(op, arg): + dtype = find_domain(op, arg.output).dtype + if not arg.output.shape: return arg if not arg.inputs: - return Tensor(op(arg.data), arg.inputs, arg.dtype) + return Tensor(op(arg.data), arg.inputs, dtype) - dtype = find_domain(op, arg.output).dtype # Work around batch inputs. 
dim = op.defaults.get("dim", None) keepdim = op.defaults.get("keepdim", False) diff --git a/test/test_tensor.py b/test/test_tensor.py index 489085c0d..a2455eb2b 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -675,6 +675,7 @@ def test_reduce_event(op, event_shape, dims): dtype = "real" if op in [ops.and_, ops.or_]: data = ops.astype(data, "uint8") + dtype = 2 expected_data = numeric_op(data.reshape(batch_shape + (-1,)), -1) x = Tensor(data, inputs, dtype=dtype) @@ -1326,25 +1327,26 @@ def test_reduction(op, event_shape): DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] op_name = op.name[1:] if op.name in {"amin", "amax"} else op.name + dtype = 2 if op.name in {"all", "any"} else "real" - expected = Tensor(op(data)) + expected = Tensor(op(data), dtype=dtype) assert_close(op(Tensor(data)), expected) assert_close(getattr(Tensor(data), op_name)(), expected) for dim in DIMS: - expected = Tensor(op(data, dim)) + expected = Tensor(op(data, dim), dtype=dtype) assert_close(op(Tensor(data), dim), expected) assert_close(op(Tensor(data), dim=dim), expected) assert_close(getattr(Tensor(data), op_name)(dim), expected) assert_close(getattr(Tensor(data), op_name)(dim=dim), expected) for keepdim in KEEPDIMS: - expected = Tensor(op(data, keepdim=keepdim)) + expected = Tensor(op(data, keepdim=keepdim), dtype=dtype) assert_close(op(Tensor(data), keepdim=keepdim), expected) assert_close(getattr(Tensor(data), op_name)(keepdim=keepdim), expected) for dim in DIMS: - expected = Tensor(op(data, dim, keepdim)) + expected = Tensor(op(data, dim, keepdim), dtype=dtype) assert_close(op(Tensor(data), dim, keepdim), expected) assert_close(op(Tensor(data), dim, keepdim=keepdim), expected) assert_close(op(Tensor(data), dim=dim, keepdim=keepdim), expected) @@ -1425,6 +1427,7 @@ def test_std_var(op, event_shape): def test_reduction_batch(op, batch_shape, event_shape): inputs = OrderedDict((k, Bint[s]) for k, s in zip("abc", batch_shape)) data = randn(*batch_shape, *event_shape) + dtype = 2 if op.name in {"all", "any"} else "real" DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] @@ -1439,5 +1442,5 @@ def raw_reduction(x, dim=None, keepdim=False, batch_ndims=len(batch_shape)): for keepdim in KEEPDIMS: for dim in DIMS: actual = op(Tensor(data, inputs), dim, keepdim=keepdim) - expected = Tensor(raw_reduction(data, dim, keepdim), inputs) + expected = Tensor(raw_reduction(data, dim, keepdim), inputs, dtype) assert_close(actual, expected, rtol=rtol) From 1bd5aa5a7356f065c8ee0e9dfbea947510ffd308 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 20:51:24 -0400 Subject: [PATCH 19/20] demote argmax,argmin to UnaryOp --- funsor/jax/ops.py | 28 ++++++++++++++-------------- funsor/ops/array.py | 28 ++++++++++++++-------------- funsor/tensor.py | 2 +- funsor/torch/ops.py | 24 ++++++++++++------------ 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 512583cfc..be5b314df 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -49,20 +49,6 @@ def _any(x, dim, keepdim): return np.any(x, dim, keepdims=keepdim) -@ops.argmax.register(array) -def _argmax(x, dim, keepdim): - if keepdim: - return np.expand_dims(np.argmax(x, dim), dim) - return np.argmax(x, dim) - - -@ops.argmin.register(array) -def _argmin(x, dim, keepdim): - if keepdim: - return np.expand_dims(np.argmin(x, dim), dim) - return np.argmin(x, dim) - - @ops.amax.register(array) def _amax(x, dim, keepdim): return np.amax(x, dim, keepdims=keepdim) @@ -106,6 
+92,20 @@ def _var(x, dim, ddof, keepdim): ########################################### +@ops.argmax.register(array) +def _argmax(x, dim, keepdim): + if keepdim: + return np.expand_dims(np.argmax(x, dim), dim) + return np.argmax(x, dim) + + +@ops.argmin.register(array) +def _argmin(x, dim, keepdim): + if keepdim: + return np.expand_dims(np.argmin(x, dim), dim) + return np.argmin(x, dim) + + @ops.astype.register(array) def _astype(x, dtype): return x.astype(np.result_type(dtype)) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index f55c55be0..04020ec90 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -65,20 +65,6 @@ def any(x, dim=None, keepdim=False): return np.any(x, dim, keepdims=keepdim) -@ReductionOp.make -def argmax(x, dim=None, keepdim=False): - if keepdim: - return np.expand_dims(np.argmax(x, dim), dim) - return np.argmax(x, dim) - - -@ReductionOp.make -def argmin(x, dim=None, keepdim=False): - if keepdim: - return np.expand_dims(np.argmin(x, dim), dim) - return np.argmin(x, dim) - - @ReductionOp.make def amax(x, dim=None, keepdim=False): return np.amax(x, dim, keepdims=keepdim) @@ -127,6 +113,20 @@ def var(x, dim=None, ddof=0, keepdim=False): ########################################### +@UnaryOp.make +def argmax(x, dim=None, keepdim=False): + if keepdim: + return np.expand_dims(np.argmax(x, dim), dim) + return np.argmax(x, dim) + + +@UnaryOp.make +def argmin(x, dim=None, keepdim=False): + if keepdim: + return np.expand_dims(np.argmin(x, dim), dim) + return np.argmin(x, dim) + + @UnaryOp.make def isnan(x): return np.isnan(x) diff --git a/funsor/tensor.py b/funsor/tensor.py index f7a913866..3bf8f215e 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -772,7 +772,7 @@ def eager_reduction_tensor(op, arg): dtype = find_domain(op, arg.output).dtype if not arg.output.shape: - return arg + return Tensor(op(ops.unsqueeze(arg.data, -1), -1), arg.inputs, dtype) if not arg.inputs: return Tensor(op(arg.data), arg.inputs, dtype) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 56d351b2b..a61b6024a 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -74,18 +74,6 @@ def _any(x, dim, keepdim): return torch.any(x_flattened, -1) -@ops.argmax.register(torch.Tensor) -def _argmax(x, dim, keepdim): - # FIXME find_domain - return torch.argmax(x, dim, keepdim=keepdim) - - -@ops.argmin.register(torch.Tensor) -def _argmin(x, dim, keepdim): - # FIXME find_domain - return torch.argmin(x, dim, keepdim=keepdim) - - @ops.amax.register(torch.Tensor) def _amax(x, dim, keepdim): if dim is None and not keepdim: @@ -163,6 +151,18 @@ def _var(x, dim, ddof, keepdim): ########################################### +@ops.argmax.register(torch.Tensor) +def _argmax(x, dim, keepdim): + # FIXME find_domain + return torch.argmax(x, dim, keepdim=keepdim) + + +@ops.argmin.register(torch.Tensor) +def _argmin(x, dim, keepdim): + # FIXME find_domain + return torch.argmin(x, dim, keepdim=keepdim) + + @ops.astype.register(torch.Tensor) def _astype(x, dtype): return x.type(getattr(torch, dtype)) From a0686f9ea1e5fff9109c4f7a1e1a794a8a1b90cf Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 19 Mar 2021 21:23:37 -0400 Subject: [PATCH 20/20] rename to axis,keepdims --- funsor/domains.py | 2 +- funsor/einsum/numpy_log.py | 2 +- funsor/jax/ops.py | 56 +++++++++++----------- funsor/ops/array.py | 60 +++++++++++------------ funsor/tensor.py | 16 +++---- funsor/terms.py | 48 +++++++++---------- funsor/torch/ops.py | 98 +++++++++++++++++++------------------- test/test_tensor.py | 69 
++++++++++++++------------- test/test_terms.py | 18 +++---- 9 files changed, 187 insertions(+), 182 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 219fdfb40..4785b73ee 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -261,7 +261,7 @@ def _find_domain_reduction(op, domain): dims = {i % ndims for i in dim} # Compute shape. - if op.defaults.get("keepdim", False): + if op.defaults.get("keepdims", False): shape = tuple(1 if i in dims else domain.shape[i] for i in range(ndims)) else: shape = tuple(domain.shape[i] for i in range(ndims) if i not in dims) diff --git a/funsor/einsum/numpy_log.py b/funsor/einsum/numpy_log.py index b71997e54..33d9b8b11 100644 --- a/funsor/einsum/numpy_log.py +++ b/funsor/einsum/numpy_log.py @@ -27,7 +27,7 @@ def einsum(equation, *operands): shift = ops.detach(operand) for i, dim in enumerate(dims): if dim not in output: - shift = ops.amax(shift, i, keepdim=True) + shift = ops.amax(shift, i, keepdims=True) # avoid nan due to -inf - -inf shift = ops.clamp(shift, ops.finfo(shift).min, None) exp_operands.append(ops.exp(operand - shift)) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index be5b314df..0b173959f 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -40,70 +40,70 @@ @ops.all.register(array) -def _all(x, dim, keepdim): - return np.all(x, dim, keepdims=keepdim) +def _all(x, axis, keepdims): + return np.all(x, axis, keepdims=keepdims) @ops.any.register(array) -def _any(x, dim, keepdim): - return np.any(x, dim, keepdims=keepdim) +def _any(x, axis, keepdims): + return np.any(x, axis, keepdims=keepdims) @ops.amax.register(array) -def _amax(x, dim, keepdim): - return np.amax(x, dim, keepdims=keepdim) +def _amax(x, axis, keepdims): + return np.amax(x, axis, keepdims=keepdims) @ops.amin.register(array) -def _amin(x, dim, keepdim): - return np.amax(x, dim, keepdims=keepdim) +def _amin(x, axis, keepdims): + return np.amax(x, axis, keepdims=keepdims) @ops.sum.register(array) -def _sum(x, dim, keepdim): - return np.sum(x, dim, keepdims=keepdim) +def _sum(x, axis, keepdims): + return np.sum(x, axis, keepdims=keepdims) @ops.prod.register(array) -def _prod(x, dim, keepdim): - return np.prod(x, dim, keepdims=keepdim) +def _prod(x, axis, keepdims): + return np.prod(x, axis, keepdims=keepdims) @ops.logsumexp.register(array) -def _logsumexp(x, dim, keepdim): - return logsumexp(x, dim, keepdims=keepdim) +def _logsumexp(x, axis, keepdims): + return logsumexp(x, axis, keepdims=keepdims) @ops.mean.register(array) -def _mean(x, dim, keepdim): - return np.mean(x, dim, keepdims=keepdim) +def _mean(x, axis, keepdims): + return np.mean(x, axis, keepdims=keepdims) @ops.std.register(array) -def _std(x, dim, ddof, keepdim): - return np.std(x, dim, ddof=ddof, keepdims=keepdim) +def _std(x, axis, ddof, keepdims): + return np.std(x, axis, ddof=ddof, keepdims=keepdims) @ops.var.register(array) -def _var(x, dim, ddof, keepdim): - return np.var(x, dim, ddof=ddof, keepdims=keepdim) +def _var(x, axis, ddof, keepdims): + return np.var(x, axis, ddof=ddof, keepdims=keepdims) ########################################### @ops.argmax.register(array) -def _argmax(x, dim, keepdim): - if keepdim: - return np.expand_dims(np.argmax(x, dim), dim) - return np.argmax(x, dim) +def _argmax(x, axis, keepdims): + if keepdims: + return np.expand_dims(np.argmax(x, axis), axis) + return np.argmax(x, axis) @ops.argmin.register(array) -def _argmin(x, dim, keepdim): - if keepdim: - return np.expand_dims(np.argmin(x, dim), dim) - return np.argmin(x, dim) +def _argmin(x, axis, 
keepdims): + if keepdims: + return np.expand_dims(np.argmin(x, axis), axis) + return np.argmin(x, axis) @ops.astype.register(array) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 04020ec90..fe51de5dd 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -56,75 +56,75 @@ @ReductionOp.make -def all(x, dim=None, keepdim=False): - return np.all(x, dim, keepdims=keepdim) +def all(x, axis=None, keepdims=False): + return np.all(x, axis, keepdims=keepdims) @ReductionOp.make -def any(x, dim=None, keepdim=False): - return np.any(x, dim, keepdims=keepdim) +def any(x, axis=None, keepdims=False): + return np.any(x, axis, keepdims=keepdims) @ReductionOp.make -def amax(x, dim=None, keepdim=False): - return np.amax(x, dim, keepdims=keepdim) +def amax(x, axis=None, keepdims=False): + return np.amax(x, axis, keepdims=keepdims) @ReductionOp.make -def amin(x, dim=None, keepdim=False): - return np.amin(x, dim, keepdims=keepdim) +def amin(x, axis=None, keepdims=False): + return np.amin(x, axis, keepdims=keepdims) @ReductionOp.make -def sum(x, dim=None, keepdim=False): - return np.sum(x, dim, keepdims=keepdim) +def sum(x, axis=None, keepdims=False): + return np.sum(x, axis, keepdims=keepdims) @ReductionOp.make -def prod(x, dim=None, keepdim=False): - return np.prod(x, dim, keepdims=keepdim) +def prod(x, axis=None, keepdims=False): + return np.prod(x, axis, keepdims=keepdims) @ReductionOp.make -def logsumexp(x, dim=None, keepdim=False): - amax = np.amax(x, axis=dim, keepdims=True) +def logsumexp(x, axis=None, keepdims=False): + amax = np.amax(x, axis=axis, keepdims=True) # treat the case x = -inf amax = np.where(np.isfinite(amax), amax, 0.0) - unnormalized_lse = log(np.sum(np.exp(x - amax), dim, keepdims=keepdim)) - amax = amax if keepdim else amax.squeeze(dim) + unnormalized_lse = log(np.sum(np.exp(x - amax), axis, keepdims=keepdims)) + amax = amax if keepdims else amax.squeeze(axis) return unnormalized_lse + amax @ReductionOp.make -def mean(x, dim=None, keepdim=False): - return np.mean(x, dim, keepdims=keepdim) +def mean(x, axis=None, keepdims=False): + return np.mean(x, axis, keepdims=keepdims) @ReductionOp.make -def std(x, dim=None, ddof=0, keepdim=False): - return np.std(x, dim, ddof=ddof, keepdims=keepdim) +def std(x, axis=None, ddof=0, keepdims=False): + return np.std(x, axis, ddof=ddof, keepdims=keepdims) @ReductionOp.make -def var(x, dim=None, ddof=0, keepdim=False): - return np.var(x, dim, ddof=ddof, keepdims=keepdim) +def var(x, axis=None, ddof=0, keepdims=False): + return np.var(x, axis, ddof=ddof, keepdims=keepdims) ########################################### @UnaryOp.make -def argmax(x, dim=None, keepdim=False): - if keepdim: - return np.expand_dims(np.argmax(x, dim), dim) - return np.argmax(x, dim) +def argmax(x, axis=None, keepdims=False): + if keepdims: + return np.expand_dims(np.argmax(x, axis), axis) + return np.argmax(x, axis) @UnaryOp.make -def argmin(x, dim=None, keepdim=False): - if keepdim: - return np.expand_dims(np.argmin(x, dim), dim) - return np.argmin(x, dim) +def argmin(x, axis=None, keepdims=False): + if keepdims: + return np.expand_dims(np.argmin(x, axis), axis) + return np.argmin(x, axis) @UnaryOp.make diff --git a/funsor/tensor.py b/funsor/tensor.py index 3bf8f215e..2dd61f1cb 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -778,16 +778,16 @@ def eager_reduction_tensor(op, arg): return Tensor(op(arg.data), arg.inputs, dtype) # Work around batch inputs. 
- dim = op.defaults.get("dim", None) - keepdim = op.defaults.get("keepdim", False) + axis = op.defaults.get("axis", None) + keepdims = op.defaults.get("keepdims", False) ndims = len(arg.output.shape) - if dim is None: - dim = tuple(range(-ndims, 0)) - elif isinstance(dim, int): - dim = dim % ndims - ndims + if axis is None: + axis = tuple(range(-ndims, 0)) + elif isinstance(axis, int): + axis = axis % ndims - ndims else: - dim = tuple(d % ndims - ndims for d in dim) - data = op(arg.data, dim=dim, keepdim=keepdim) + axis = tuple(d % ndims - ndims for d in axis) + data = op(arg.data, axis=axis, keepdims=keepdims) return Tensor(data, arg.inputs, dtype) diff --git a/funsor/terms.py b/funsor/terms.py index d2f08df4d..94a5196cb 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -559,41 +559,41 @@ def reshape(self, shape): # reduce over output shape while preserving all inputs. # To reduce over inputs, instead call .reduce(op, reduced_vars). - def all(self, dim=None, keepdim=False): - return Unary(ops.AllOp(dim, keepdim), self) + def all(self, axis=None, keepdims=False): + return Unary(ops.AllOp(axis, keepdims), self) - def any(self, dim=None, keepdim=False): - return Unary(ops.AnyOp(dim, keepdim), self) + def any(self, axis=None, keepdims=False): + return Unary(ops.AnyOp(axis, keepdims), self) - def argmax(self, dim=None, keepdim=False): - return Unary(ops.ArgmaxOp(dim, keepdim), self) + def argmax(self, axis=None, keepdims=False): + return Unary(ops.ArgmaxOp(axis, keepdims), self) - def argmin(self, dim=None, keepdim=False): - return Unary(ops.ArgminOp(dim, keepdim), self) + def argmin(self, axis=None, keepdims=False): + return Unary(ops.ArgminOp(axis, keepdims), self) - def max(self, dim=None, keepdim=False): - return Unary(ops.AmaxOp(dim, keepdim), self) + def max(self, axis=None, keepdims=False): + return Unary(ops.AmaxOp(axis, keepdims), self) - def min(self, dim=None, keepdim=False): - return Unary(ops.AminOp(dim, keepdim), self) + def min(self, axis=None, keepdims=False): + return Unary(ops.AminOp(axis, keepdims), self) - def sum(self, dim=None, keepdim=False): - return Unary(ops.SumOp(dim, keepdim), self) + def sum(self, axis=None, keepdims=False): + return Unary(ops.SumOp(axis, keepdims), self) - def prod(self, dim=None, keepdim=False): - return Unary(ops.ProdOp(dim, keepdim), self) + def prod(self, axis=None, keepdims=False): + return Unary(ops.ProdOp(axis, keepdims), self) - def logsumexp(self, dim=None, keepdim=False): - return Unary(ops.LogsumexpOp(dim, keepdim), self) + def logsumexp(self, axis=None, keepdims=False): + return Unary(ops.LogsumexpOp(axis, keepdims), self) - def mean(self, dim=None, keepdim=False): - return Unary(ops.MeanOp(dim, keepdim), self) + def mean(self, axis=None, keepdims=False): + return Unary(ops.MeanOp(axis, keepdims), self) - def std(self, dim=None, ddof=0, keepdim=False): - return Unary(ops.StdOp(dim, ddof, keepdim), self) + def std(self, axis=None, ddof=0, keepdims=False): + return Unary(ops.StdOp(axis, ddof, keepdims), self) - def var(self, dim=None, ddof=0, keepdim=False): - return Unary(ops.VarOp(dim, ddof, keepdim), self) + def var(self, axis=None, ddof=0, keepdims=False): + return Unary(ops.VarOp(axis, ddof, keepdims), self) def __add__(self, other): return Binary(ops.add, self, to_funsor(other)) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index a61b6024a..2a75e9d3c 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -43,108 +43,108 @@ def _flatten_reduced_dim(x, dim): @ops.all.register(torch.Tensor) -def _all(x, dim, 
keepdim): - if dim is None and not keepdim: +def _all(x, axis, keepdims): + if axis is None and not keepdims: return torch.all(x) - if isinstance(dim, int): - return torch.all(x, dim, keepdim=keepdim) + if isinstance(axis, int): + return torch.all(x, axis, keepdim=keepdims) # reduce over multiple dims. - x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) - if keepdim: + x_flattened, reduced_dim = _flatten_reduced_dim(x, axis) + if keepdims: shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) return torch.all(x_flattened, -1).view(shape) return torch.all(x_flattened, -1) @ops.any.register(torch.Tensor) -def _any(x, dim, keepdim): - if dim is None and not keepdim: +def _any(x, axis, keepdims): + if axis is None and not keepdims: return torch.any(x) - if isinstance(dim, int): - return torch.any(x, dim, keepdim=keepdim) + if isinstance(axis, int): + return torch.any(x, axis, keepdim=keepdims) # reduce over multiple dims. - x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) - if keepdim: + x_flattened, reduced_dim = _flatten_reduced_dim(x, axis) + if keepdims: shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) return torch.any(x_flattened, -1).view(shape) return torch.any(x_flattened, -1) @ops.amax.register(torch.Tensor) -def _amax(x, dim, keepdim): - if dim is None and not keepdim: +def _amax(x, axis, keepdims): + if axis is None and not keepdims: return torch.amax(x) - dim = tuple(range(x.dim())) if dim is None else dim - return torch.amax(x, dim, keepdim=keepdim) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.amax(x, axis, keepdim=keepdims) @ops.amin.register(torch.Tensor) -def _amin(x, dim, keepdim): - if dim is None and not keepdim: +def _amin(x, axis, keepdims): + if axis is None and not keepdims: return torch.amin(x) - dim = tuple(range(x.dim())) if dim is None else dim - return torch.amin(x, dim, keepdim=keepdim) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.amin(x, axis, keepdim=keepdims) @ops.sum.register(torch.Tensor) -def _sum(x, dim, keepdim): - if dim is None and not keepdim: +def _sum(x, axis, keepdims): + if axis is None and not keepdims: return torch.sum(x) - dim = tuple(range(x.dim())) if dim is None else dim - return torch.sum(x, dim, keepdim=keepdim) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.sum(x, axis, keepdim=keepdims) @ops.prod.register(torch.Tensor) -def _prod(x, dim, keepdim): - if dim is None and not keepdim: +def _prod(x, axis, keepdims): + if axis is None and not keepdims: return torch.prod(x) - if isinstance(dim, int): - return torch.prod(x, dim, keepdim=keepdim) + if isinstance(axis, int): + return torch.prod(x, axis, keepdim=keepdims) # reduce over multiple dims. 
- x_flattened, reduced_dim = _flatten_reduced_dim(x, dim) - if keepdim: + x_flattened, reduced_dim = _flatten_reduced_dim(x, axis) + if keepdims: shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) return torch.prod(x_flattened, -1).view(shape) return torch.prod(x_flattened, -1) @ops.logsumexp.register(torch.Tensor) -def _logsumexp(x, dim, keepdim): - dim = tuple(range(x.dim())) if dim is None else dim - return torch.logsumexp(x, dim, keepdim=keepdim) +def _logsumexp(x, axis, keepdims): + axis = tuple(range(x.dim())) if axis is None else axis + return torch.logsumexp(x, axis, keepdim=keepdims) @ops.mean.register(torch.Tensor) -def _mean(x, dim, keepdim): - if dim is None and not keepdim: +def _mean(x, axis, keepdims): + if axis is None and not keepdims: return torch.mean(x) - dim = tuple(range(x.dim())) if dim is None else dim - return torch.mean(x, dim, keepdim=keepdim) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.mean(x, axis, keepdim=keepdims) @ops.std.register(torch.Tensor) -def _std(x, dim, ddof, keepdim): - dim = tuple(range(x.dim())) if dim is None else dim +def _std(x, axis, ddof, keepdims): + axis = tuple(range(x.dim())) if axis is None else axis if ddof == 0: - return torch.std(x, dim, unbiased=False, keepdim=keepdim) + return torch.std(x, axis, unbiased=False, keepdim=keepdims) if ddof == 1: - return torch.std(x, dim, keepdim=keepdim) + return torch.std(x, axis, keepdim=keepdims) raise NotImplementedError @ops.var.register(torch.Tensor) -def _var(x, dim, ddof, keepdim): - dim = tuple(range(x.dim())) if dim is None else dim +def _var(x, axis, ddof, keepdims): + axis = tuple(range(x.dim())) if axis is None else axis if ddof == 0: - return torch.var(x, dim, unbiased=False, keepdim=keepdim) + return torch.var(x, axis, unbiased=False, keepdim=keepdims) if ddof == 1: - return torch.var(x, dim, keepdim=keepdim) + return torch.var(x, axis, keepdim=keepdims) raise NotImplementedError @@ -152,15 +152,15 @@ def _var(x, dim, ddof, keepdim): @ops.argmax.register(torch.Tensor) -def _argmax(x, dim, keepdim): +def _argmax(x, axis, keepdims): # FIXME find_domain - return torch.argmax(x, dim, keepdim=keepdim) + return torch.argmax(x, axis, keepdim=keepdims) @ops.argmin.register(torch.Tensor) -def _argmin(x, dim, keepdim): +def _argmin(x, axis, keepdims): # FIXME find_domain - return torch.argmin(x, dim, keepdim=keepdim) + return torch.argmin(x, axis, keepdim=keepdims) @ops.astype.register(torch.Tensor) diff --git a/test/test_tensor.py b/test/test_tensor.py index a2455eb2b..48ddb72d2 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1336,24 +1336,26 @@ def test_reduction(op, event_shape): for dim in DIMS: expected = Tensor(op(data, dim), dtype=dtype) assert_close(op(Tensor(data), dim), expected) - assert_close(op(Tensor(data), dim=dim), expected) + assert_close(op(Tensor(data), axis=dim), expected) assert_close(getattr(Tensor(data), op_name)(dim), expected) - assert_close(getattr(Tensor(data), op_name)(dim=dim), expected) + assert_close(getattr(Tensor(data), op_name)(axis=dim), expected) - for keepdim in KEEPDIMS: - expected = Tensor(op(data, keepdim=keepdim), dtype=dtype) - assert_close(op(Tensor(data), keepdim=keepdim), expected) - assert_close(getattr(Tensor(data), op_name)(keepdim=keepdim), expected) + for keepdims in KEEPDIMS: + expected = Tensor(op(data, keepdims=keepdims), dtype=dtype) + assert_close(op(Tensor(data), keepdims=keepdims), expected) + assert_close(getattr(Tensor(data), op_name)(keepdims=keepdims), expected) for 
dim in DIMS: - expected = Tensor(op(data, dim, keepdim), dtype=dtype) - assert_close(op(Tensor(data), dim, keepdim), expected) - assert_close(op(Tensor(data), dim, keepdim=keepdim), expected) - assert_close(op(Tensor(data), dim=dim, keepdim=keepdim), expected) - assert_close(getattr(Tensor(data), op_name)(dim, keepdim), expected) - assert_close(getattr(Tensor(data), op_name)(dim, keepdim=keepdim), expected) + expected = Tensor(op(data, dim, keepdims), dtype=dtype) + assert_close(op(Tensor(data), dim, keepdims), expected) + assert_close(op(Tensor(data), dim, keepdims=keepdims), expected) + assert_close(op(Tensor(data), axis=dim, keepdims=keepdims), expected) + assert_close(getattr(Tensor(data), op_name)(dim, keepdims), expected) assert_close( - getattr(Tensor(data), op_name)(dim=dim, keepdim=keepdim), expected + getattr(Tensor(data), op_name)(dim, keepdims=keepdims), expected + ) + assert_close( + getattr(Tensor(data), op_name)(axis=dim, keepdims=keepdims), expected ) @@ -1378,31 +1380,34 @@ def test_std_var(op, event_shape): for dim in DIMS: expected = Tensor(op(data, dim)) assert_close(op(Tensor(data), dim), expected) - assert_close(op(Tensor(data), dim=dim), expected) + assert_close(op(Tensor(data), axis=dim), expected) assert_close(getattr(Tensor(data), op.name)(dim), expected) - assert_close(getattr(Tensor(data), op.name)(dim=dim), expected) + assert_close(getattr(Tensor(data), op.name)(axis=dim), expected) - for keepdim in KEEPDIMS: - expected = Tensor(op(data, keepdim=keepdim)) - assert_close(op(Tensor(data), keepdim=keepdim), expected) - assert_close(getattr(Tensor(data), op.name)(keepdim=keepdim), expected) + for keepdims in KEEPDIMS: + expected = Tensor(op(data, keepdims=keepdims)) + assert_close(op(Tensor(data), keepdims=keepdims), expected) + assert_close(getattr(Tensor(data), op.name)(keepdims=keepdims), expected) for ddof in DDOFS: for dim in DIMS: - expected = Tensor(op(data, dim, ddof, keepdim)) - assert_close(op(Tensor(data), dim, ddof, keepdim), expected) - assert_close(op(Tensor(data), dim, ddof, keepdim=keepdim), expected) + expected = Tensor(op(data, dim, ddof, keepdims)) + assert_close(op(Tensor(data), dim, ddof, keepdims), expected) + assert_close(op(Tensor(data), dim, ddof, keepdims=keepdims), expected) assert_close( - op(Tensor(data), dim=dim, ddof=ddof, keepdim=keepdim), expected + op(Tensor(data), axis=dim, ddof=ddof, keepdims=keepdims), expected ) assert_close( - getattr(Tensor(data), op.name)(dim, ddof, keepdim), expected + getattr(Tensor(data), op.name)(dim, ddof, keepdims), expected ) assert_close( - getattr(Tensor(data), op.name)(dim, ddof, keepdim=keepdim), expected + getattr(Tensor(data), op.name)(dim, ddof, keepdims=keepdims), + expected, ) assert_close( - getattr(Tensor(data), op.name)(dim=dim, ddof=ddof, keepdim=keepdim), + getattr(Tensor(data), op.name)( + axis=dim, ddof=ddof, keepdims=keepdims + ), expected, ) @@ -1431,16 +1436,16 @@ def test_reduction_batch(op, batch_shape, event_shape): DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] - def raw_reduction(x, dim=None, keepdim=False, batch_ndims=len(batch_shape)): + def raw_reduction(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): if batch_ndims == 0: - return op(x, dim, keepdim=keepdim) + return op(x, dim, keepdims=keepdims) return ops.stack( - [raw_reduction(part, dim, keepdim, batch_ndims - 1) for part in x] + [raw_reduction(part, dim, keepdims, batch_ndims - 1) for part in x] ) rtol = 1e-5 if get_backend() == "jax" else 1e-6 - for keepdim in KEEPDIMS: + for keepdims 
in KEEPDIMS: for dim in DIMS: - actual = op(Tensor(data, inputs), dim, keepdim=keepdim) - expected = Tensor(raw_reduction(data, dim, keepdim), inputs, dtype) + actual = op(Tensor(data, inputs), dim, keepdims=keepdims) + expected = Tensor(raw_reduction(data, dim, keepdims), inputs, dtype) assert_close(actual, expected, rtol=rtol) diff --git a/test/test_terms.py b/test/test_terms.py index 50b768918..d0c6f742a 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -277,8 +277,8 @@ def test_unary(symbol, data): @pytest.mark.parametrize("event_shape", [(3, 2)], ids=str) -@pytest.mark.parametrize("dim", [None, 0, (1,), (0, 1)], ids=str) -@pytest.mark.parametrize("keepdim", [False, True], ids=str) +@pytest.mark.parametrize("axis", [None, 0, (1,), (0, 1)], ids=str) +@pytest.mark.parametrize("keepdims", [False, True], ids=str) @pytest.mark.parametrize( "name", [ @@ -294,20 +294,20 @@ def test_unary(symbol, data): "var", ], ) -def test_reduce_event(name, event_shape, dim, keepdim): +def test_reduce_event(name, event_shape, axis, keepdims): dtype = 2 if name in ("any", "all") else "real" x = random_tensor(OrderedDict(i=Bint[5]), output=Array[dtype, event_shape]) - actual = getattr(x, name)(dim=dim, keepdim=keepdim) + actual = getattr(x, name)(axis=axis, keepdims=keepdims) # compute expected shape - dim = (0, 1) if dim is None else dim - dim = (dim,) if isinstance(dim, int) else dim - if keepdim: + axis = (0, 1) if axis is None else axis + axis = (axis,) if isinstance(axis, int) else axis + if keepdims: shape = tuple( - 1 if i in dim else event_shape[i] for i in range(len(event_shape)) + 1 if i in axis else event_shape[i] for i in range(len(event_shape)) ) else: - shape = tuple(event_shape[i] for i in range(len(event_shape)) if i not in dim) + shape = tuple(event_shape[i] for i in range(len(event_shape)) if i not in axis) check_funsor(actual, x.inputs, Array[dtype, shape])
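
Usage sketch (illustrative only, not part of any patch above): the snippet below
exercises the reduction API that this series converges on, where reductions act
over the event (output) shape with NumPy-style ``axis`` and ``keepdims``
arguments while batch inputs are preserved. It assumes a configured backend
(e.g. ``funsor.set_backend("torch")``) and borrows the ``random_tensor``,
``Bint``, and ``Array`` helpers used by the test suite; the import paths are
assumptions, not taken from the patches.

    from collections import OrderedDict

    import funsor
    import funsor.ops as ops
    from funsor.domains import Array, Bint
    from funsor.testing import random_tensor

    funsor.set_backend("torch")  # assumed backend; "jax" should behave the same

    # One batch input ``i`` of size 5 and a (3, 2) real event shape.
    x = random_tensor(OrderedDict(i=Bint[5]), output=Array["real", (3, 2)])

    # With no ``axis`` the whole event shape is reduced; the batch input survives.
    assert x.mean().output == Array["real", ()]
    assert set(x.mean().inputs) == {"i"}

    # ``axis`` and ``keepdims`` follow NumPy conventions over the event shape.
    assert x.sum(axis=0, keepdims=True).output == Array["real", (1, 2)]

    # Ops can also be applied to a funsor directly with the same signature.
    assert ops.logsumexp(x, axis=(0, 1)).output == Array["real", ()]

    # ``std`` and ``var`` additionally accept ``ddof`` (0 or 1 on the torch backend).
    assert x.var(axis=1, ddof=1).output == Array["real", (3,)]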