
Add ReductionOp, ops.mean, ops.std, ops.var #482

Merged (22 commits) Mar 20, 2021

Conversation

@ordabayevy (Member)

No description provided.

@fritzo (Member) commented Mar 6, 2021

Could you register implementations in funsor/jax/ops.py and add the following test to test/test_terms.py:

diff --git a/test/test_terms.py b/test/test_terms.py
index e495e3e..88ed47a 100644
--- a/test/test_terms.py
+++ b/test/test_terms.py
@@ -287,6 +287,29 @@ def test_unary(symbol, data):
     check_funsor(actual, {}, Array[dtype, ()], expected_data)


+@pytest.mark.parametrize("event_shape", [(4,), (3, 2)], ids=str)
+@pytest.mark.parametrize(
+    "name",
+    [
+        "all",
+        "any",
+        "logsumexp",
+        "max",
+        "mean",
+        "min",
+        "prod",
+        "std",
+        "sum",
+        "var",
+    ],
+)
+def test_reduce_event(name, event_shape):
+    dtype = 2 if name in ("any", "all") else "real"
+    x = random_tensor(OrderedDict(i=Bint[5]), output=Array[dtype, event_shape])
+    actual = getattr(x, name)()
+    check_funsor(actual, x.inputs, Array[dtype, ()])
+
+
 BINARY_OPS = [
     "+",
     "-",

@fritzo fritzo added the WIP label Mar 6, 2021
@ordabayevy ordabayevy changed the title ops.mean, ops.std, ops.var Add ops.mean, ops.std, ops.var Mar 6, 2021
@ordabayevy (Member, Author)

Is there a need to allow a ddof option for ops.std and ops.var? If so, how would it be passed to Unary? Similar to the shape option in .reshape?

    def reshape(self, shape):
        return Unary(ops.ReshapeOp(shape), self)
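For context on the question above: ddof is numpy's "delta degrees of freedom" parameter to np.var and np.std, which changes the divisor from n to n - ddof. A minimal pure-Python sketch of its semantics (the var helper here is illustrative, not funsor or numpy code):

```python
def var(data, ddof=0):
    """Variance with numpy-style ddof: divisor is n - ddof.

    ddof=0 gives the population variance (numpy's default);
    ddof=1 gives the unbiased sample variance.
    """
    n = len(data)
    mean = sum(data) / n
    return sum((x - mean) ** 2 for x in data) / (n - ddof)


data = [1.0, 2.0, 3.0, 4.0]
print(var(data, ddof=0))  # population variance: 1.25
print(var(data, ddof=1))  # sample variance: 5/3 ≈ 1.6667
```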

@fritzo (Member) commented Mar 6, 2021

Is there a need to allow ddof option

I'm not sure what the ddof option is, but it would be reasonable to create, in addition to these default mean, var, std ops, more flexible op classes MeanOp, VarOp, and StdOp that take non-default arguments (axis, keepdims) (in numpy parlance) or (dim, keepdim) (in pytorch parlance). You're correct that these could follow ReshapeOp or GetitemOp by using CachedOpMeta. However, I think those ops would be best kept separate from and complementary to the default ops, and we would also want to provide them for sum etc.:

class Funsor(...):
    ...
    def mean(self, axis=None, keepdims=False):
        if axis is None and keepdims is False:
            op = ops.mean  # default version from this PR
        else:
            op = ops.MeanOp(axis, keepdims)  # fancy version from a future PR
        return Unary(op, self)

Or if you refactored such that ops.mean = MeanOp(), you could simplify to

class Funsor(...):
    ...
    def mean(self, axis=None, keepdims=False):
        return Unary(ops.MeanOp(axis, keepdims), self)
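The CachedOpMeta pattern mentioned above can be sketched as follows. This is a minimal illustration, not funsor's actual implementation: op instances are memoized on their constructor arguments, so constructing an op twice with equal arguments yields the same object (which lets funsor's cons-hashed terms compare ops by identity).

```python
class CachedOpMeta(type):
    """Metaclass that caches op instances keyed by constructor arguments."""

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        cls._cache = {}

    def __call__(cls, *args):
        # Return the cached instance if these args were seen before.
        if args not in cls._cache:
            cls._cache[args] = super().__call__(*args)
        return cls._cache[args]


class MeanOp(metaclass=CachedOpMeta):
    def __init__(self, axis=None, keepdims=False):
        self.axis = axis
        self.keepdims = keepdims


assert MeanOp(0, False) is MeanOp(0, False)      # cache hit: same instance
assert MeanOp(0, True) is not MeanOp(0, False)   # different args, different op
```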

Review thread on funsor/jax/ops.py (outdated, resolved)
@fritzo fritzo added the enhancement New feature or request label Mar 7, 2021
@ordabayevy ordabayevy changed the title Add ops.mean, ops.std, ops.var Add ReductionOp, ops.mean, ops.std, ops.var Mar 19, 2021
@ordabayevy ordabayevy marked this pull request as ready for review March 19, 2021 06:44
@fritzo (Member) left a comment

Looks great! I'm still uncertain of argmin, argmax...

Review threads on funsor/tensor.py and test/test_tensor.py (outdated, resolved)
@ordabayevy (Member, Author)

I'm still uncertain of argmin, argmax...

There were some rough edges, and argmin and argmax are definitely among them. They should probably go in a separate ArgreductionOp, as you suggested.
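One reason argmin/argmax sit awkwardly in a value-reducing op family: unlike sum or max, they return integer indices rather than values of the input's dtype, so a hypothetical ArgreductionOp would need its own output-type rule. A small pure-Python illustration of the dtype mismatch (helper names here are illustrative):

```python
def reduce_max(xs):
    # Value reduction: the result has the same dtype as the elements.
    return max(xs)


def argmax(xs):
    # Arg-reduction: the result is a position (an int), not an element.
    return max(range(len(xs)), key=xs.__getitem__)


vals = [0.3, 2.5, 1.1]
print(reduce_max(vals))  # 2.5 (a float, like the inputs)
print(argmax(vals))      # 1 (an int index)
```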

@ordabayevy ordabayevy requested a review from fritzo March 20, 2021 01:38
@fritzo (Member) left a comment

🎉

@fritzo fritzo merged commit 4446a3d into pyro-ppl:master Mar 20, 2021