Add ReductionOp, ops.mean, ops.std, ops.var #482

Merged: 22 commits, Mar 20, 2021
2 changes: 1 addition & 1 deletion funsor/affine.py
@@ -61,7 +61,7 @@ def _(fn):

@affine_inputs.register(Unary)
def _(fn):
-    if fn.op in (ops.neg, ops.add) or isinstance(fn.op, ops.ReshapeOp):
+    if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp):
        return affine_inputs(fn.arg)
    return frozenset()

12 changes: 7 additions & 5 deletions funsor/domains.py
@@ -248,8 +248,8 @@ def _find_domain_log_exp(op, domain):
    return Array["real", domain.shape]


-@find_domain.register(ops.SumOp)
-def _find_domain_sum(op, domain):
+@find_domain.register(ops.ReductionOp)
+def _find_domain_reduction(op, domain):
    # Canonicalize dim.
    dim = op.defaults.get("dim", None)
    ndims = len(domain.shape)
@@ -261,14 +261,16 @@ def _find_domain_sum(op, domain):
    dims = {i % ndims for i in dim}

    # Compute shape.
-    if op.defaults.get("keepdims", False):
-        shape = tuple(1 if i in dims else size for i, size in enumerate(domain.shape))
+    if op.defaults.get("keepdim", False):
+        shape = tuple(1 if i in dims else domain.shape[i] for i in range(ndims))
    else:
-        shape = tuple(size for i, size in enumerate(domain.shape) if i not in dims)
+        shape = tuple(domain.shape[i] for i in range(ndims) if i not in dims)

    # Compute domain.
    if domain.dtype == "real":
        dtype = "real"
+    elif op.name in ("all", "any"):
+        dtype = domain.dtype
    else:
        raise NotImplementedError("TODO")

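As a quick illustration of the shape rule above, here is a standalone sketch of the same logic in plain Python (illustrative only, not funsor API), reducing an input of shape (3, 4, 5) over dims {0, 2}:

    shape = (3, 4, 5)
    ndims = len(shape)
    dims = {0, 2}

    # keepdim=True: reduced dims collapse to size 1.
    assert tuple(1 if i in dims else shape[i] for i in range(ndims)) == (1, 4, 1)

    # keepdim=False: reduced dims are dropped entirely.
    assert tuple(shape[i] for i in range(ndims) if i not in dims) == (4,)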
2 changes: 1 addition & 1 deletion funsor/einsum/numpy_log.py
@@ -27,7 +27,7 @@ def einsum(equation, *operands):
        shift = ops.detach(operand)
        for i, dim in enumerate(dims):
            if dim not in output:
-                shift = ops.amax(shift, i, keepdims=True)
+                shift = ops.amax(shift, i, keepdim=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))
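For intuition, the shift above implements the standard log-sum-exp stabilization: subtract a per-dim max (computed with keepdim=True so it broadcasts) before exponentiating, while the clamp guards against -inf - -inf when an operand is identically -inf. A standalone NumPy sketch of the idea (illustrative names, not funsor's einsum):

    import numpy as np

    operand = np.array([[1000.0, 999.0], [2.0, 1.0]])  # naive np.exp would overflow
    shift = np.amax(operand, 1, keepdims=True)         # cf. ops.amax(shift, i, keepdim=True)
    shift = np.clip(shift, np.finfo(operand.dtype).min, None)  # cf. ops.clamp(...)
    exp_operand = np.exp(operand - shift)              # safely within [0, 1]
    assert np.isfinite(exp_operand).all()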
86 changes: 59 additions & 27 deletions funsor/jax/ops.py
@@ -34,29 +34,76 @@
ops.unsqueeze.register(array)(np.expand_dims)


+###########################################
+# Reduction Ops
+###########################################


@ops.all.register(array)
-def _all(x, dim):
-    return np.all(x, axis=dim)
+def _all(x, dim, keepdim):
+    return np.all(x, dim, keepdims=keepdim)


+@ops.any.register(array)
+def _any(x, dim, keepdim):
+    return np.any(x, dim, keepdims=keepdim)


+@ops.argmax.register(array)
+def _argmax(x, dim, keepdim):
+    if keepdim:
+        return np.expand_dims(np.argmax(x, dim), dim)
+    return np.argmax(x, dim)


+@ops.argmin.register(array)
+def _argmin(x, dim, keepdim):
+    if keepdim:
+        return np.expand_dims(np.argmin(x, dim), dim)
+    return np.argmin(x, dim)


@ops.amax.register(array)
-def _amax(x, dim, keepdims=False):
-    return np.amax(x, axis=dim, keepdims=keepdims)
+def _amax(x, dim, keepdim):
+    return np.amax(x, dim, keepdims=keepdim)


@ops.amin.register(array)
-def _amin(x, dim, keepdims=False):
-    return np.amin(x, axis=dim, keepdims=keepdims)
+def _amin(x, dim, keepdim):
+    return np.amin(x, dim, keepdims=keepdim)


-@ops.argmax.register(array)
-def _argmax(x, dim):
-    return np.argmax(x, dim)
+@ops.sum.register(array)
+def _sum(x, dim, keepdim):
+    return np.sum(x, dim, keepdims=keepdim)


-@ops.any.register(array)
-def _any(x, dim):
-    return np.any(x, axis=dim)
+@ops.prod.register(array)
+def _prod(x, dim, keepdim):
+    return np.prod(x, dim, keepdims=keepdim)


+@ops.logsumexp.register(array)
+def _logsumexp(x, dim, keepdim):
+    return logsumexp(x, dim, keepdims=keepdim)


+@ops.mean.register(array)
+def _mean(x, dim, keepdim):
+    return np.mean(x, dim, keepdims=keepdim)


+@ops.std.register(array)
+def _std(x, dim, ddof, keepdim):
+    return np.std(x, dim, ddof=ddof, keepdims=keepdim)


+@ops.var.register(array)
+def _var(x, dim, ddof, keepdim):
+    return np.var(x, dim, ddof=ddof, keepdims=keepdim)


+###########################################


@ops.astype.register(array)
@@ -161,11 +208,6 @@ def _safe_logaddexp_tensor_number(x, y):
    return _safe_logaddexp_number_tensor(y, x)


-@ops.logsumexp.register(array)
-def _logsumexp(x, dim):
-    return logsumexp(x, axis=dim)


ops.max.register(array, array)(np.maximum)
ops.min.register(array, array)(np.minimum)

@@ -216,11 +258,6 @@ def _new_zeros(x, shape):
    return onp.zeros(shape, dtype=np.result_type(x))


-@ops.prod.register(array)
-def _prod(x, dim):
-    return np.prod(x, axis=dim)


@ops.reciprocal.register(array)
def _reciprocal(x):
    result = np.clip(np.reciprocal(x), a_max=np.finfo(np.result_type(x)).max)
@@ -257,11 +294,6 @@ def _stack(parts, dim=0):
    return np.stack(parts, axis=dim)


-@ops.sum.register(array)
-def _sum(x, dim, keepdims):
-    return np.sum(x, dim, keepdims=keepdims)


@ops.triangular_solve.register(array, array)
def _triangular_solve(x, y, upper=False, transpose=False):
    assert np.ndim(x) >= 2 and np.ndim(y) >= 2
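One detail worth flagging in the argmax/argmin registrations above: NumPy's argmax and argmin historically accepted no keepdims argument (it only appeared in a later NumPy release), so the reduced dim is re-inserted manually with np.expand_dims. A quick sanity check of that emulation:

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    flat = np.argmax(x, 1)                     # shape (3,)
    kept = np.expand_dims(np.argmax(x, 1), 1)  # keepdim emulation, shape (3, 1)
    assert flat.shape == (3,)
    assert kept.shape == (3, 1)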
106 changes: 70 additions & 36 deletions funsor/ops/array.py
@@ -30,6 +30,7 @@
    FinitaryOp,
    Op,
    OpMeta,
+    ReductionOp,
    TernaryOp,
    UnaryOp,
    declare_op_types,
@@ -49,34 +50,81 @@
atanh.register(array)(np.arctanh)


-@UnaryOp.make
-def all(x, dim=None):
-    return np.all(x, dim)
+###########################################
+# Reduction Ops
+###########################################


-@UnaryOp.make
-def any(x, dim=None):
-    return np.any(x, dim)
+@ReductionOp.make
+def all(x, dim=None, keepdim=False):
+    return np.all(x, dim, keepdims=keepdim)


-@UnaryOp.make
-def amax(x, dim=None, keepdims=False):
-    return np.amax(x, dim, keepdims=keepdims)
+@ReductionOp.make
+def any(x, dim=None, keepdim=False):
+    return np.any(x, dim, keepdims=keepdim)


-@UnaryOp.make
-def amin(x, dim=None, keepdims=False):
-    return np.amax(x, dim, keepdims=keepdims)
+@ReductionOp.make
+def argmax(x, dim=None, keepdim=False):
+    if keepdim:
+        return np.expand_dims(np.argmax(x, dim), dim)
+    return np.argmax(x, dim)


-@UnaryOp.make
-def sum(x, dim=None, keepdims=False):
-    return np.sum(x, dim, keepdims=keepdims)
+@ReductionOp.make
+def argmin(x, dim=None, keepdim=False):
+    if keepdim:
+        return np.expand_dims(np.argmin(x, dim), dim)
+    return np.argmin(x, dim)


-@UnaryOp.make
-def prod(x, dim=None):
-    return np.prod(x, dim)
+@ReductionOp.make
+def amax(x, dim=None, keepdim=False):
+    return np.amax(x, dim, keepdims=keepdim)


+@ReductionOp.make
+def amin(x, dim=None, keepdim=False):
+    return np.amin(x, dim, keepdims=keepdim)


+@ReductionOp.make
+def sum(x, dim=None, keepdim=False):
+    return np.sum(x, dim, keepdims=keepdim)


+@ReductionOp.make
+def prod(x, dim=None, keepdim=False):
+    return np.prod(x, dim, keepdims=keepdim)


+@ReductionOp.make
+def logsumexp(x, dim, keepdim=False):
+    amax = np.amax(x, axis=dim, keepdims=True)
+    # treat the case x = -inf
+    amax = np.where(np.isfinite(amax), amax, 0.0)
+    unnormalized_lse = log(np.sum(np.exp(x - amax), dim, keepdims=keepdim))
+    amax = amax if keepdim else amax.squeeze(dim)
+    return unnormalized_lse + amax


+@ReductionOp.make
+def mean(x, dim=None, keepdim=False):
+    return np.mean(x, dim, keepdims=keepdim)


+@ReductionOp.make
+def std(x, dim=None, ddof=0, keepdim=False):
+    return np.std(x, dim, ddof=ddof, keepdims=keepdim)


+@ReductionOp.make
+def var(x, dim=None, ddof=0, keepdim=False):
+    return np.var(x, dim, ddof=ddof, keepdims=keepdim)


+###########################################
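The rewritten logsumexp above threads keepdim through while keeping the max-shift trick. A standalone NumPy check of the same pattern, including the all--inf case the np.where line handles (illustrative, outside funsor):

    import numpy as np

    x = np.array([[1000.0, 1001.0], [-np.inf, -np.inf]])
    amax = np.amax(x, axis=1, keepdims=True)
    amax = np.where(np.isfinite(amax), amax, 0.0)  # avoid nan from -inf - -inf
    with np.errstate(divide="ignore"):             # log(0) -> -inf for the -inf row
        lse = np.log(np.sum(np.exp(x - amax), 1)) + amax.squeeze(1)
    assert np.isfinite(lse[0])    # ~1001.313, no overflow
    assert np.isneginf(lse[1])    # an all--inf slice reduces to -inf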


@UnaryOp.make
@@ -247,14 +295,6 @@ def _safe_logaddexp_tensor_number(x, y):
    return _safe_logaddexp_number_tensor(y, x)


-@UnaryOp.make
-def logsumexp(x, dim):
-    amax = np.amax(x, axis=dim, keepdims=True)
-    # treat the case x = -inf
-    amax = np.where(np.isfinite(amax), amax, 0.0)
-    return log(np.sum(np.exp(x - amax), axis=dim)) + amax.squeeze(axis=dim)


max.register(array, array)(np.maximum)
min.register(array, array)(np.minimum)

@@ -279,16 +319,6 @@ def _min(x, y):
    return np.clip(x, a_min=None, a_max=y)


-@UnaryOp.make
-def argmax(x, dim):
-    raise NotImplementedError
-
-
-@argmax.register(array)
-def _argmax(x, dim):
-    return np.argmax(x, dim)


@UnaryOp.make
def new_arange(x, start=None, stop=None, step=None):
    raise NotImplementedError
@@ -413,6 +443,7 @@ def unsqueeze(x, dim):
"amin",
"any",
"argmax",
"argmin",
"astype",
"cat",
"cholesky",
Expand All @@ -429,6 +460,7 @@ def unsqueeze(x, dim):
"isnan",
"logaddexp",
"logsumexp",
"mean",
"new_arange",
"new_eye",
"new_full",
Expand All @@ -439,10 +471,12 @@ def unsqueeze(x, dim):
"scatter",
"scatter_add",
"stack",
"std",
"sum",
"transpose",
"triangular_solve",
"unsqueeze",
"var",
]

declare_op_types(globals(), __all__, __name__)
9 changes: 9 additions & 0 deletions funsor/ops/op.py
@@ -282,6 +282,15 @@ def _list_to_tuple(cls, arg, *args, **kwargs):
        return op(arg)


+class ReductionOp(UnaryOp):
+    """
+    Reduction operations in a broad sense, not only associative ones.
+    A common base class helps unify the find_domain logic.
+    """
+
+    pass


class TransformOp(UnaryOp):
    def set_inv(self, fn):
        """
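To make the new base class concrete: ReductionOp.make turns a plain function into a dispatchable op, and backends then register array-specific implementations, as in funsor/ops/array.py and funsor/jax/ops.py above. A hedged sketch of defining a new reduction in the same style (the median op and the import path are illustrative, not part of this PR):

    import numpy as np
    from funsor.ops.op import ReductionOp

    @ReductionOp.make
    def median(x, dim=None, keepdim=False):
        # Default NumPy implementation; backends could register overrides.
        return np.median(x, axis=dim, keepdims=keepdim)

    # Since median is a ReductionOp (a UnaryOp subclass), generic machinery
    # such as the find_domain rule registered above should apply to it too.
    x = np.arange(6.0).reshape(2, 3)
    print(median(x, dim=1, keepdim=True))  # shape (2, 1)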