diff --git a/funsor/affine.py b/funsor/affine.py
index 5c6b31535..0751f84b2 100644
--- a/funsor/affine.py
+++ b/funsor/affine.py
@@ -61,7 +61,7 @@ def _(fn):
 
 @affine_inputs.register(Unary)
 def _(fn):
-    if fn.op in (ops.neg, ops.add) or isinstance(fn.op, ops.ReshapeOp):
+    if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp):
         return affine_inputs(fn.arg)
     return frozenset()
 
diff --git a/funsor/domains.py b/funsor/domains.py
index 3dba4b34b..4785b73ee 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -248,8 +248,8 @@ def _find_domain_log_exp(op, domain):
     return Array["real", domain.shape]
 
 
-@find_domain.register(ops.SumOp)
-def _find_domain_sum(op, domain):
+@find_domain.register(ops.ReductionOp)
+def _find_domain_reduction(op, domain):
     # Canonicalize dim.
-    dim = op.defaults.get("dim", None)
+    dim = op.defaults.get("axis", None)
     ndims = len(domain.shape)
@@ -262,12 +262,14 @@ def _find_domain_sum(op, domain):
 
     # Compute shape.
     if op.defaults.get("keepdims", False):
-        shape = tuple(1 if i in dims else size for i, size in enumerate(domain.shape))
+        shape = tuple(1 if i in dims else domain.shape[i] for i in range(ndims))
     else:
-        shape = tuple(size for i, size in enumerate(domain.shape) if i not in dims)
+        shape = tuple(domain.shape[i] for i in range(ndims) if i not in dims)
 
     # Compute domain.
-    if domain.dtype == "real":
+    if op.name in ("all", "any"):
+        dtype = 2
+    elif domain.dtype == "real":
         dtype = "real"
     else:
         raise NotImplementedError("TODO")
 
diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py
index 76247725b..0b173959f 100644
--- a/funsor/jax/ops.py
+++ b/funsor/jax/ops.py
@@ -34,29 +34,76 @@
 ops.unsqueeze.register(array)(np.expand_dims)
 
 
+###########################################
+# Reduction Ops
+###########################################
+
+
 @ops.all.register(array)
-def _all(x, dim):
-    return np.all(x, axis=dim)
+def _all(x, axis, keepdims):
+    return np.all(x, axis, keepdims=keepdims)
+
+
+@ops.any.register(array)
+def _any(x, axis, keepdims):
+    return np.any(x, axis, keepdims=keepdims)
 
 
 @ops.amax.register(array)
-def _amax(x, dim, keepdims=False):
-    return np.amax(x, axis=dim, keepdims=keepdims)
+def _amax(x, axis, keepdims):
+    return np.amax(x, axis, keepdims=keepdims)
 
 
 @ops.amin.register(array)
-def _amin(x, dim, keepdims=False):
-    return np.amin(x, axis=dim, keepdims=keepdims)
+def _amin(x, axis, keepdims):
+    return np.amin(x, axis, keepdims=keepdims)
+
+
+@ops.sum.register(array)
+def _sum(x, axis, keepdims):
+    return np.sum(x, axis, keepdims=keepdims)
+
+
+@ops.prod.register(array)
+def _prod(x, axis, keepdims):
+    return np.prod(x, axis, keepdims=keepdims)
+
+
+@ops.logsumexp.register(array)
+def _logsumexp(x, axis, keepdims):
+    return logsumexp(x, axis, keepdims=keepdims)
+
+
+@ops.mean.register(array)
+def _mean(x, axis, keepdims):
+    return np.mean(x, axis, keepdims=keepdims)
+
+
+@ops.std.register(array)
+def _std(x, axis, ddof, keepdims):
+    return np.std(x, axis, ddof=ddof, keepdims=keepdims)
+
+
+@ops.var.register(array)
+def _var(x, axis, ddof, keepdims):
+    return np.var(x, axis, ddof=ddof, keepdims=keepdims)
+
+
+###########################################
 
 
 @ops.argmax.register(array)
-def _argmax(x, dim):
-    return np.argmax(x, dim)
+def _argmax(x, axis, keepdims):
+    if keepdims:
+        return np.expand_dims(np.argmax(x, axis), axis)
+    return np.argmax(x, axis)
 
 
-@ops.any.register(array)
-def _any(x, dim):
-    return np.any(x, axis=dim)
+@ops.argmin.register(array)
+def _argmin(x, axis, keepdims):
+    if keepdims:
+        return np.expand_dims(np.argmin(x, axis), axis)
+    return np.argmin(x, axis)
@ops.astype.register(array) @@ -161,11 +208,6 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.logsumexp.register(array) -def _logsumexp(x, dim): - return logsumexp(x, axis=dim) - - ops.max.register(array, array)(np.maximum) ops.min.register(array, array)(np.minimum) @@ -216,11 +258,6 @@ def _new_zeros(x, shape): return onp.zeros(shape, dtype=np.result_type(x)) -@ops.prod.register(array) -def _prod(x, dim): - return np.prod(x, axis=dim) - - @ops.reciprocal.register(array) def _reciprocal(x): result = np.clip(np.reciprocal(x), a_max=np.finfo(np.result_type(x)).max) @@ -257,11 +294,6 @@ def _stack(parts, dim=0): return np.stack(parts, axis=dim) -@ops.sum.register(array) -def _sum(x, dim, keepdims): - return np.sum(x, dim, keepdims=keepdims) - - @ops.triangular_solve.register(array, array) def _triangular_solve(x, y, upper=False, transpose=False): assert np.ndim(x) >= 2 and np.ndim(y) >= 2 diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 97831fa9b..fe51de5dd 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -30,6 +30,7 @@ FinitaryOp, Op, OpMeta, + ReductionOp, TernaryOp, UnaryOp, declare_op_types, @@ -49,34 +50,81 @@ atanh.register(array)(np.arctanh) -@UnaryOp.make -def all(x, dim=None): - return np.all(x, dim) +########################################### +# Reduction Ops +########################################### -@UnaryOp.make -def any(x, dim=None): - return np.any(x, dim) +@ReductionOp.make +def all(x, axis=None, keepdims=False): + return np.all(x, axis, keepdims=keepdims) -@UnaryOp.make -def amax(x, dim=None, keepdims=False): - return np.amax(x, dim, keepdims=keepdims) +@ReductionOp.make +def any(x, axis=None, keepdims=False): + return np.any(x, axis, keepdims=keepdims) -@UnaryOp.make -def amin(x, dim=None, keepdims=False): - return np.amax(x, dim, keepdims=keepdims) +@ReductionOp.make +def amax(x, axis=None, keepdims=False): + return np.amax(x, axis, keepdims=keepdims) + + +@ReductionOp.make +def amin(x, axis=None, keepdims=False): + return np.amin(x, axis, keepdims=keepdims) + + +@ReductionOp.make +def sum(x, axis=None, keepdims=False): + return np.sum(x, axis, keepdims=keepdims) + + +@ReductionOp.make +def prod(x, axis=None, keepdims=False): + return np.prod(x, axis, keepdims=keepdims) + + +@ReductionOp.make +def logsumexp(x, axis=None, keepdims=False): + amax = np.amax(x, axis=axis, keepdims=True) + # treat the case x = -inf + amax = np.where(np.isfinite(amax), amax, 0.0) + unnormalized_lse = log(np.sum(np.exp(x - amax), axis, keepdims=keepdims)) + amax = amax if keepdims else amax.squeeze(axis) + return unnormalized_lse + amax + + +@ReductionOp.make +def mean(x, axis=None, keepdims=False): + return np.mean(x, axis, keepdims=keepdims) + + +@ReductionOp.make +def std(x, axis=None, ddof=0, keepdims=False): + return np.std(x, axis, ddof=ddof, keepdims=keepdims) + + +@ReductionOp.make +def var(x, axis=None, ddof=0, keepdims=False): + return np.var(x, axis, ddof=ddof, keepdims=keepdims) + + +########################################### @UnaryOp.make -def sum(x, dim=None, keepdims=False): - return np.sum(x, dim, keepdims=keepdims) +def argmax(x, axis=None, keepdims=False): + if keepdims: + return np.expand_dims(np.argmax(x, axis), axis) + return np.argmax(x, axis) @UnaryOp.make -def prod(x, dim=None): - return np.prod(x, dim) +def argmin(x, axis=None, keepdims=False): + if keepdims: + return np.expand_dims(np.argmin(x, axis), axis) + return np.argmin(x, axis) @UnaryOp.make @@ -247,14 +295,6 @@ def 
_safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@UnaryOp.make -def logsumexp(x, dim): - amax = np.amax(x, axis=dim, keepdims=True) - # treat the case x = -inf - amax = np.where(np.isfinite(amax), amax, 0.0) - return log(np.sum(np.exp(x - amax), axis=dim)) + amax.squeeze(axis=dim) - - max.register(array, array)(np.maximum) min.register(array, array)(np.minimum) @@ -279,16 +319,6 @@ def _min(x, y): return np.clip(x, a_min=None, a_max=y) -@UnaryOp.make -def argmax(x, dim): - raise NotImplementedError - - -@argmax.register(array) -def _argmax(x, dim): - return np.argmax(x, dim) - - @UnaryOp.make def new_arange(x, start=None, stop=None, step=None): raise NotImplementedError @@ -413,6 +443,7 @@ def unsqueeze(x, dim): "amin", "any", "argmax", + "argmin", "astype", "cat", "cholesky", @@ -429,6 +460,7 @@ def unsqueeze(x, dim): "isnan", "logaddexp", "logsumexp", + "mean", "new_arange", "new_eye", "new_full", @@ -439,10 +471,12 @@ def unsqueeze(x, dim): "scatter", "scatter_add", "stack", + "std", "sum", "transpose", "triangular_solve", "unsqueeze", + "var", ] declare_op_types(globals(), __all__, __name__) diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 34bfe20b7..4c4b9b036 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -282,6 +282,15 @@ def _list_to_tuple(cls, arg, *args, **kwargs): return op(arg) +class ReductionOp(UnaryOp): + """ + Reduction operations are defined in a broad sense - not only + associative operations. This helps to unify find_domain logic. + """ + + pass + + class TransformOp(UnaryOp): def set_inv(self, fn): """ diff --git a/funsor/tensor.py b/funsor/tensor.py index ec86524c3..2dd61f1cb 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -297,11 +297,6 @@ def eager_subs(self, subs): def eager_unary(self, op): dtype = find_domain(op, self.output).dtype - if op in REDUCE_OP_TO_NUMERIC: - batch_dim = len(self.data.shape) - len(self.output.shape) - data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) - data = REDUCE_OP_TO_NUMERIC[op](data, -1) - return Tensor(data, self.inputs, dtype) return Tensor(op(self.data), self.inputs, dtype) def eager_reduce(self, op, reduced_vars): @@ -772,23 +767,28 @@ def eager_reshape_tensor(op, arg): return Tensor(data, arg.inputs, arg.dtype) -@eager.register(Unary, ops.SumOp, Tensor) -def eager_sum_tensor(op, arg): +@eager.register(Unary, ops.ReductionOp, Tensor) +def eager_reduction_tensor(op, arg): + dtype = find_domain(op, arg.output).dtype + + if not arg.output.shape: + return Tensor(op(ops.unsqueeze(arg.data, -1), -1), arg.inputs, dtype) + if not arg.inputs: - return Tensor(op(arg.data), arg.inputs, arg.dtype) + return Tensor(op(arg.data), arg.inputs, dtype) # Work around batch inputs. 
- dim = op.defaults.get("dim", None) + axis = op.defaults.get("axis", None) keepdims = op.defaults.get("keepdims", False) ndims = len(arg.output.shape) - if dim is None: - dim = tuple(range(-ndims, 0)) - elif isinstance(dim, int): - dim = dim % ndims - ndims + if axis is None: + axis = tuple(range(-ndims, 0)) + elif isinstance(axis, int): + axis = axis % ndims - ndims else: - dim = tuple(d % ndims - ndims for d in dim) - data = op(arg.data, dim, keepdims) - return Tensor(data, arg.inputs, arg.dtype) + axis = tuple(d % ndims - ndims for d in axis) + data = op(arg.data, axis=axis, keepdims=keepdims) + return Tensor(data, arg.inputs, dtype) @eager.register(Binary, GetitemOp, Tensor, Number) diff --git a/funsor/terms.py b/funsor/terms.py index 373822dcf..94a5196cb 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -559,26 +559,41 @@ def reshape(self, shape): # reduce over output shape while preserving all inputs. # To reduce over inputs, instead call .reduce(op, reduced_vars). - def sum(self): - return Unary(ops.add, self) + def all(self, axis=None, keepdims=False): + return Unary(ops.AllOp(axis, keepdims), self) - def prod(self): - return Unary(ops.mul, self) + def any(self, axis=None, keepdims=False): + return Unary(ops.AnyOp(axis, keepdims), self) - def logsumexp(self): - return Unary(ops.logaddexp, self) + def argmax(self, axis=None, keepdims=False): + return Unary(ops.ArgmaxOp(axis, keepdims), self) - def all(self): - return Unary(ops.and_, self) + def argmin(self, axis=None, keepdims=False): + return Unary(ops.ArgminOp(axis, keepdims), self) - def any(self): - return Unary(ops.or_, self) + def max(self, axis=None, keepdims=False): + return Unary(ops.AmaxOp(axis, keepdims), self) - def min(self): - return Unary(ops.min, self) + def min(self, axis=None, keepdims=False): + return Unary(ops.AminOp(axis, keepdims), self) - def max(self): - return Unary(ops.max, self) + def sum(self, axis=None, keepdims=False): + return Unary(ops.SumOp(axis, keepdims), self) + + def prod(self, axis=None, keepdims=False): + return Unary(ops.ProdOp(axis, keepdims), self) + + def logsumexp(self, axis=None, keepdims=False): + return Unary(ops.LogsumexpOp(axis, keepdims), self) + + def mean(self, axis=None, keepdims=False): + return Unary(ops.MeanOp(axis, keepdims), self) + + def std(self, axis=None, ddof=0, keepdims=False): + return Unary(ops.StdOp(axis, ddof, keepdims), self) + + def var(self, axis=None, ddof=0, keepdims=False): + return Unary(ops.VarOp(axis, ddof, keepdims), self) def __add__(self, other): return Binary(ops.add, self, to_funsor(other)) diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 52e03f64e..2a75e9d3c 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -26,29 +26,141 @@ ops.unsqueeze.register(torch.Tensor)(torch.unsqueeze) +########################################### +# Reduction Ops +########################################### + + +def _flatten_reduced_dim(x, dim): + # Canonicalize reduced dim. + reduced_dim = ( + tuple(range(x.dim())) if dim is None else tuple(d % x.dim() for d in dim) + ) + nonreduced_dim = tuple(i for i in range(x.dim()) if i not in reduced_dim) + # permute & flatten reduced dim. 
+ permutation = nonreduced_dim + reduced_dim + return x.permute(permutation).flatten(-len(reduced_dim), -1), reduced_dim + + @ops.all.register(torch.Tensor) -def _all(x, dim): - return x.all() if dim is None else x.all(dim=dim) +def _all(x, axis, keepdims): + if axis is None and not keepdims: + return torch.all(x) + + if isinstance(axis, int): + return torch.all(x, axis, keepdim=keepdims) + + # reduce over multiple dims. + x_flattened, reduced_dim = _flatten_reduced_dim(x, axis) + if keepdims: + shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) + return torch.all(x_flattened, -1).view(shape) + return torch.all(x_flattened, -1) + + +@ops.any.register(torch.Tensor) +def _any(x, axis, keepdims): + if axis is None and not keepdims: + return torch.any(x) + + if isinstance(axis, int): + return torch.any(x, axis, keepdim=keepdims) + + # reduce over multiple dims. + x_flattened, reduced_dim = _flatten_reduced_dim(x, axis) + if keepdims: + shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) + return torch.any(x_flattened, -1).view(shape) + return torch.any(x_flattened, -1) @ops.amax.register(torch.Tensor) -def _amax(x, dim, keepdims=False): - return x.max() if dim is None else x.max(dim, keepdims)[0] +def _amax(x, axis, keepdims): + if axis is None and not keepdims: + return torch.amax(x) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.amax(x, axis, keepdim=keepdims) @ops.amin.register(torch.Tensor) -def _amin(x, dim, keepdims=False): - return x.min() if dim is None else x.min(dim, keepdims)[0] +def _amin(x, axis, keepdims): + if axis is None and not keepdims: + return torch.amin(x) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.amin(x, axis, keepdim=keepdims) + + +@ops.sum.register(torch.Tensor) +def _sum(x, axis, keepdims): + if axis is None and not keepdims: + return torch.sum(x) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.sum(x, axis, keepdim=keepdims) + + +@ops.prod.register(torch.Tensor) +def _prod(x, axis, keepdims): + if axis is None and not keepdims: + return torch.prod(x) + + if isinstance(axis, int): + return torch.prod(x, axis, keepdim=keepdims) + + # reduce over multiple dims. 
+ x_flattened, reduced_dim = _flatten_reduced_dim(x, axis) + if keepdims: + shape = tuple(1 if i in reduced_dim else x.shape[i] for i in range(x.dim())) + return torch.prod(x_flattened, -1).view(shape) + return torch.prod(x_flattened, -1) + + +@ops.logsumexp.register(torch.Tensor) +def _logsumexp(x, axis, keepdims): + axis = tuple(range(x.dim())) if axis is None else axis + return torch.logsumexp(x, axis, keepdim=keepdims) + + +@ops.mean.register(torch.Tensor) +def _mean(x, axis, keepdims): + if axis is None and not keepdims: + return torch.mean(x) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.mean(x, axis, keepdim=keepdims) + + +@ops.std.register(torch.Tensor) +def _std(x, axis, ddof, keepdims): + axis = tuple(range(x.dim())) if axis is None else axis + if ddof == 0: + return torch.std(x, axis, unbiased=False, keepdim=keepdims) + if ddof == 1: + return torch.std(x, axis, keepdim=keepdims) + raise NotImplementedError + + +@ops.var.register(torch.Tensor) +def _var(x, axis, ddof, keepdims): + axis = tuple(range(x.dim())) if axis is None else axis + if ddof == 0: + return torch.var(x, axis, unbiased=False, keepdim=keepdims) + if ddof == 1: + return torch.var(x, axis, keepdim=keepdims) + raise NotImplementedError + + +########################################### @ops.argmax.register(torch.Tensor) -def _argmax(x, dim): - return x.max(dim).indices +def _argmax(x, axis, keepdims): + # FIXME find_domain + return torch.argmax(x, axis, keepdim=keepdims) -@ops.any.register(torch.Tensor) -def _any(x, dim): - return x.any() if dim is None else x.any(dim=dim) +@ops.argmin.register(torch.Tensor) +def _argmin(x, axis, keepdims): + # FIXME find_domain + return torch.argmin(x, axis, keepdim=keepdims) @ops.astype.register(torch.Tensor) @@ -146,11 +258,6 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.logsumexp.register(torch.Tensor) -def _logsumexp(x, dim): - return x.reshape(-1).logsumexp(0) if dim is None else x.logsumexp(dim) - - @ops.max.register(torch.Tensor, torch.Tensor) def _max(x, y): return torch.max(x, y) @@ -223,11 +330,6 @@ def _pow(x, y): return x ** y -@ops.prod.register(torch.Tensor) -def _prod(x, dim): - return x.prod() if dim is None else x.prod(dim=dim) - - @ops.reciprocal.register(torch.Tensor) def _reciprocal(x): result = x.reciprocal().clamp(max=torch.finfo(x.dtype).max) @@ -270,16 +372,6 @@ def _scatter_add(destin, indices, source): ops.stack.register(typing.Tuple[torch.Tensor, ...])(torch.stack) -@ops.sum.register(torch.Tensor) -def _sum(x, dim, keepdims): - if dim is None: - if keepdims: - dim = tuple(range(x.dim())) - return x.sum(dim, True) - return x.sum() - return x.sum(dim, keepdims) - - @ops.triangular_solve.register(torch.Tensor, torch.Tensor) def _triangular_solve(x, y, upper=False, transpose=False): return x.triangular_solve(y, upper, transpose).solution diff --git a/test/test_tensor.py b/test/test_tensor.py index 521c63864..48ddb72d2 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -675,6 +675,7 @@ def test_reduce_event(op, event_shape, dims): dtype = "real" if op in [ops.and_, ops.or_]: data = ops.astype(data, "uint8") + dtype = 2 expected_data = numeric_op(data.reshape(batch_shape + (-1,)), -1) x = Tensor(data, inputs, dtype=dtype) @@ -1307,52 +1308,144 @@ def test_scatter_pure_renaming(): assert ((actual - expected).abs() < 1e-4).data.all() +@pytest.mark.parametrize( + "op", + [ + ops.any, + ops.all, + ops.amin, + ops.amax, + ops.sum, + ops.logsumexp, + ops.prod, + ops.mean, + ], +) 
@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str) -def test_sum(event_shape): +def test_reduction(op, event_shape): data = randn(*event_shape) DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] + op_name = op.name[1:] if op.name in {"amin", "amax"} else op.name + dtype = 2 if op.name in {"all", "any"} else "real" + + expected = Tensor(op(data), dtype=dtype) + assert_close(op(Tensor(data)), expected) + assert_close(getattr(Tensor(data), op_name)(), expected) - assert_close(Tensor(ops.sum(data)), ops.sum(Tensor(data))) for dim in DIMS: - assert_close(Tensor(ops.sum(data, dim)), ops.sum(Tensor(data), dim)) - assert_close(Tensor(ops.sum(data, dim=dim)), ops.sum(Tensor(data), dim=dim)) + expected = Tensor(op(data, dim), dtype=dtype) + assert_close(op(Tensor(data), dim), expected) + assert_close(op(Tensor(data), axis=dim), expected) + assert_close(getattr(Tensor(data), op_name)(dim), expected) + assert_close(getattr(Tensor(data), op_name)(axis=dim), expected) + for keepdims in KEEPDIMS: - assert_close( - Tensor(ops.sum(data, keepdims=keepdims)), - ops.sum(Tensor(data), keepdims=keepdims), - ) + expected = Tensor(op(data, keepdims=keepdims), dtype=dtype) + assert_close(op(Tensor(data), keepdims=keepdims), expected) + assert_close(getattr(Tensor(data), op_name)(keepdims=keepdims), expected) + for dim in DIMS: + expected = Tensor(op(data, dim, keepdims), dtype=dtype) + assert_close(op(Tensor(data), dim, keepdims), expected) + assert_close(op(Tensor(data), dim, keepdims=keepdims), expected) + assert_close(op(Tensor(data), axis=dim, keepdims=keepdims), expected) + assert_close(getattr(Tensor(data), op_name)(dim, keepdims), expected) assert_close( - Tensor(ops.sum(data, dim, keepdims)), - ops.sum(Tensor(data), dim, keepdims), - ) - assert_close( - Tensor(ops.sum(data, dim, keepdims=keepdims)), - ops.sum(Tensor(data), dim, keepdims=keepdims), + getattr(Tensor(data), op_name)(dim, keepdims=keepdims), expected ) assert_close( - Tensor(ops.sum(data, dim=dim, keepdims=keepdims)), - ops.sum(Tensor(data), dim=dim, keepdims=keepdims), + getattr(Tensor(data), op_name)(axis=dim, keepdims=keepdims), expected ) +@pytest.mark.parametrize( + "op", + [ + ops.std, + ops.var, + ], +) +@pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str) +def test_std_var(op, event_shape): + data = randn(*event_shape) + DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] + DDOFS = [0, 1] + KEEPDIMS = [False, True] + + expected = Tensor(op(data)) + assert_close(op(Tensor(data)), expected) + assert_close(getattr(Tensor(data), op.name)(), expected) + + for dim in DIMS: + expected = Tensor(op(data, dim)) + assert_close(op(Tensor(data), dim), expected) + assert_close(op(Tensor(data), axis=dim), expected) + assert_close(getattr(Tensor(data), op.name)(dim), expected) + assert_close(getattr(Tensor(data), op.name)(axis=dim), expected) + + for keepdims in KEEPDIMS: + expected = Tensor(op(data, keepdims=keepdims)) + assert_close(op(Tensor(data), keepdims=keepdims), expected) + assert_close(getattr(Tensor(data), op.name)(keepdims=keepdims), expected) + + for ddof in DDOFS: + for dim in DIMS: + expected = Tensor(op(data, dim, ddof, keepdims)) + assert_close(op(Tensor(data), dim, ddof, keepdims), expected) + assert_close(op(Tensor(data), dim, ddof, keepdims=keepdims), expected) + assert_close( + op(Tensor(data), axis=dim, ddof=ddof, keepdims=keepdims), expected + ) + assert_close( + getattr(Tensor(data), op.name)(dim, ddof, keepdims), expected + ) + assert_close( + getattr(Tensor(data), op.name)(dim, ddof, 
keepdims=keepdims), + expected, + ) + assert_close( + getattr(Tensor(data), op.name)( + axis=dim, ddof=ddof, keepdims=keepdims + ), + expected, + ) + + +@pytest.mark.parametrize( + "op", + [ + ops.any, + ops.all, + ops.amin, + ops.amax, + ops.sum, + ops.logsumexp, + ops.prod, + ops.mean, + ops.std, + ops.var, + ], +) @pytest.mark.parametrize("batch_shape", [(), (5,)], ids=str) @pytest.mark.parametrize("event_shape", [(2, 3, 4)], ids=str) -def test_sum_batch(batch_shape, event_shape): +def test_reduction_batch(op, batch_shape, event_shape): inputs = OrderedDict((k, Bint[s]) for k, s in zip("abc", batch_shape)) data = randn(*batch_shape, *event_shape) + dtype = 2 if op.name in {"all", "any"} else "real" DIMS = [None, 0, 1, 2, -1, -2, -3, (0, 2)] KEEPDIMS = [False, True] - def raw_sum(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): + def raw_reduction(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): if batch_ndims == 0: - return ops.sum(x, dim, keepdims) - return ops.stack([raw_sum(part, dim, keepdims, batch_ndims - 1) for part in x]) + return op(x, dim, keepdims=keepdims) + return ops.stack( + [raw_reduction(part, dim, keepdims, batch_ndims - 1) for part in x] + ) rtol = 1e-5 if get_backend() == "jax" else 1e-6 for keepdims in KEEPDIMS: for dim in DIMS: - actual = ops.sum(Tensor(data, inputs), dim, keepdims) - expected = Tensor(raw_sum(data, dim, keepdims), inputs) + actual = op(Tensor(data, inputs), dim, keepdims=keepdims) + expected = Tensor(raw_reduction(data, dim, keepdims), inputs, dtype) assert_close(actual, expected, rtol=rtol) diff --git a/test/test_terms.py b/test/test_terms.py index 8e682f2f4..d0c6f742a 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -276,6 +276,42 @@ def test_unary(symbol, data): check_funsor(actual, {}, Array[dtype, ()], expected_data) +@pytest.mark.parametrize("event_shape", [(3, 2)], ids=str) +@pytest.mark.parametrize("axis", [None, 0, (1,), (0, 1)], ids=str) +@pytest.mark.parametrize("keepdims", [False, True], ids=str) +@pytest.mark.parametrize( + "name", + [ + "all", + "any", + "max", + "min", + "sum", + "prod", + "logsumexp", + "mean", + "std", + "var", + ], +) +def test_reduce_event(name, event_shape, axis, keepdims): + dtype = 2 if name in ("any", "all") else "real" + x = random_tensor(OrderedDict(i=Bint[5]), output=Array[dtype, event_shape]) + actual = getattr(x, name)(axis=axis, keepdims=keepdims) + + # compute expected shape + axis = (0, 1) if axis is None else axis + axis = (axis,) if isinstance(axis, int) else axis + if keepdims: + shape = tuple( + 1 if i in axis else event_shape[i] for i in range(len(event_shape)) + ) + else: + shape = tuple(event_shape[i] for i in range(len(event_shape)) if i not in axis) + + check_funsor(actual, x.inputs, Array[dtype, shape]) + + BINARY_OPS = ["+", "-", "*", "/", "**", "==", "!=", "<", "<=", ">", ">=", "min", "max"] BOOLEAN_OPS = ["&", "|", "^"]
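
A brief usage sketch of the reduction API introduced above (illustrative only, not part of the patch). It assumes the default numpy backend, so the implementations exercised are the ones registered in funsor/ops/array.py, and it uses funsor's existing test helper funsor.testing.randn; output shapes in the comments follow the find_domain rules added in funsor/domains.py.

    from collections import OrderedDict

    import funsor
    import funsor.ops as ops
    from funsor.tensor import Tensor
    from funsor.testing import randn

    # One named input "i" of size 5 plus an event (output) shape (2, 3).
    x = Tensor(randn(5, 2, 3), OrderedDict(i=funsor.Bint[5]))

    # Reductions act on the output shape and preserve all named inputs.
    y = x.sum(axis=0, keepdims=True)   # output shape (1, 3); input "i" is kept
    z = x.logsumexp()                  # axis=None reduces the whole event shape to a scalar
    m = ops.mean(x, axis=(0, 1))       # op form; equivalent to x.mean(axis=(0, 1))

    # all/any produce a boolean-valued (dtype 2) output per the new find_domain rule.
    b = (x > 0.0).all(axis=-1)         # output shape (2,), dtype 2

    # To reduce over named inputs instead, keep using .reduce(op, reduced_vars).
    total = x.reduce(ops.add, "i")

The op form and the method form are interchangeable, as exercised by test_reduction in test/test_tensor.py.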