Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace null -> ops.null in a few places #498

Merged
merged 1 commit into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions funsor/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from collections.abc import Hashable

from funsor.cnf import Contraction, null
from funsor.cnf import Contraction
from funsor.interpretations import Interpretation, reflect
from funsor.interpreter import stack_reinterpret
from funsor.ops import AssociativeOp
Expand Down Expand Up @@ -233,7 +233,7 @@ def adjoint_contract_generic(
def adjoint_contract(
adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs
):
if prod_op is adj_prod_op and sum_op in (null, adj_sum_op):
if prod_op is adj_prod_op and sum_op in (ops.null, adj_sum_op):

# the only change is here:
out_adj = Approximate(
Expand Down
34 changes: 17 additions & 17 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from funsor.gaussian import Gaussian
from funsor.interpretations import eager, normalize, reflect
from funsor.interpreter import children, recursion_reinterpret
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, null
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp
from funsor.tensor import Tensor
from funsor.terms import (
_INFIX,
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, red_op, bin_op, reduced_vars, terms):
for v in terms:
inputs.update((k, d) for k, d in v.inputs.items() if k not in bound)

if bin_op is null:
if bin_op is ops.null:
output = terms[0].output
else:
output = reduce(
Expand Down Expand Up @@ -107,7 +107,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
if not sampled_vars:
return self

if self.red_op in (ops.logaddexp, null):
if self.red_op in (ops.null, ops.logaddexp):
if self.bin_op in (ops.null, ops.logaddexp):
if rng_key is not None and get_backend() == "jax":
import jax
Expand Down Expand Up @@ -277,7 +277,7 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
if unique_vars:
result = term.reduce(red_op, unique_vars)
if result is not normalize.interpret(
Contraction, red_op, null, unique_vars, (term,)
Contraction, red_op, ops.null, unique_vars, (term,)
):
terms[i] = result
reduced_vars -= unique_vars
Expand Down Expand Up @@ -432,7 +432,7 @@ def normalize_contraction_commutative_canonical_order(
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other):
return Contraction(
mixture.red_op if red_op is null else red_op,
mixture.red_op if red_op is ops.null else red_op,
bin_op,
reduced_vars | mixture.reduced_vars,
*(mixture.terms + (other,))
Expand All @@ -444,7 +444,7 @@ def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, o
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture):
return Contraction(
mixture.red_op if red_op is null else red_op,
mixture.red_op if red_op is ops.null else red_op,
bin_op,
reduced_vars | mixture.reduced_vars,
*(mixture.terms + (other,))
Expand All @@ -467,13 +467,13 @@ def normalize_trivial(red_op, bin_op, reduced_vars, term):
@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

if not reduced_vars and red_op is not null:
return Contraction(null, bin_op, reduced_vars, *terms)
if not reduced_vars and red_op is not ops.null:
return Contraction(ops.null, bin_op, reduced_vars, *terms)

if len(terms) == 1 and bin_op is not null:
return Contraction(red_op, null, reduced_vars, *terms)
if len(terms) == 1 and bin_op is not ops.null:
return Contraction(red_op, ops.null, reduced_vars, *terms)

if red_op is null and bin_op is null:
if red_op is ops.null and bin_op is ops.null:
return terms[0]

if red_op is bin_op:
Expand All @@ -498,11 +498,11 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
continue

# fuse operations without distributing
if (v.red_op is null and bin_op is v.bin_op) or (
bin_op is null and v.red_op in (red_op, null)
if (v.red_op is ops.null and bin_op is v.bin_op) or (
bin_op is ops.null and v.red_op in (red_op, ops.null)
):
red_op = v.red_op if red_op is null else red_op
bin_op = v.bin_op if bin_op is null else bin_op
red_op = v.red_op if red_op is ops.null else red_op
bin_op = v.bin_op if bin_op is ops.null else bin_op
new_terms = terms[:i] + v.terms + terms[i + 1 :]
return Contraction(
red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms
Expand All @@ -519,12 +519,12 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

@normalize.register(Binary, AssociativeOp, Funsor, Funsor)
def binary_to_contract(op, lhs, rhs):
return Contraction(null, op, frozenset(), lhs, rhs)
return Contraction(ops.null, op, frozenset(), lhs, rhs)


@normalize.register(Reduce, AssociativeOp, Funsor, frozenset)
def reduce_funsor(op, arg, reduced_vars):
return Contraction(op, null, reduced_vars, arg)
return Contraction(op, ops.null, reduced_vars, arg)


@normalize.register(
Expand Down
21 changes: 12 additions & 9 deletions funsor/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from opt_einsum.paths import greedy

import funsor.interpreter as interpreter
from funsor.cnf import Contraction, null
from funsor.cnf import Contraction
from funsor.interpretations import (
DispatchedInterpretation,
PrioritizedInterpretation,
Expand All @@ -19,6 +19,8 @@
from funsor.terms import Funsor
from funsor.typing import Variadic

from . import ops

unfold_base = DispatchedInterpretation()
unfold = PrioritizedInterpretation(unfold_base, normalize_base, lazy)

Expand All @@ -31,7 +33,7 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
if not isinstance(v, Contraction):
continue

if v.red_op is null and (v.bin_op, bin_op) in DISTRIBUTIVE_OPS:
if v.red_op is ops.null and (v.bin_op, bin_op) in DISTRIBUTIVE_OPS:
# a * e * (b + c + d) -> (a * e * b) + (a * e * c) + (a * e * d)
new_terms = tuple(
Contraction(
Expand All @@ -44,7 +46,7 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
)
return Contraction(red_op, v.bin_op, reduced_vars, *new_terms)

if red_op in (v.red_op, null) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS:
if red_op in (v.red_op, ops.null) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS:
new_terms = (
terms[:i]
+ (Contraction(v.red_op, v.bin_op, frozenset(), *v.terms),)
Expand All @@ -54,9 +56,9 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
red_op, reduced_vars
)

if v.red_op in (red_op, null) and bin_op in (v.bin_op, null):
red_op = v.red_op if red_op is null else red_op
bin_op = v.bin_op if bin_op is null else bin_op
if v.red_op in (red_op, ops.null) and bin_op in (v.bin_op, ops.null):
red_op = v.red_op if red_op is ops.null else red_op
bin_op = v.bin_op if bin_op is ops.null else bin_op
new_terms = terms[:i] + v.terms + terms[i + 1 :]
return Contraction(
red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms
Expand Down Expand Up @@ -93,8 +95,9 @@ def eager_contract_base(red_op, bin_op, reduced_vars, *terms):

@optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms):

if red_op is null or bin_op is null or not (red_op, bin_op) in DISTRIBUTIVE_OPS:
if red_op is ops.null or bin_op is ops.null:
return None
if (red_op, bin_op) not in DISTRIBUTIVE_OPS:
return None

# build opt_einsum optimizer IR
Expand Down Expand Up @@ -140,7 +143,7 @@ def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms):
)

path_end = Contraction(
red_op if path_end_reduced_vars else null,
red_op if path_end_reduced_vars else ops.null,
bin_op,
path_end_reduced_vars,
ta,
Expand Down