From 317a286cf39f24bc4cda89e44fec8df3e4a3dba8 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Wed, 17 Mar 2021 20:55:00 -0400
Subject: [PATCH] Replace null -> ops.null in a few places

---
 funsor/adjoint.py   |  4 ++--
 funsor/cnf.py       | 34 +++++++++++++++++-----------------
 funsor/optimizer.py | 21 ++++++++++++---------
 3 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/funsor/adjoint.py b/funsor/adjoint.py
index 1df57a406..65b1565a5 100644
--- a/funsor/adjoint.py
+++ b/funsor/adjoint.py
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from collections.abc import Hashable
 
-from funsor.cnf import Contraction, null
+from funsor.cnf import Contraction
 from funsor.interpretations import Interpretation, reflect
 from funsor.interpreter import stack_reinterpret
 from funsor.ops import AssociativeOp
@@ -233,7 +233,7 @@ def adjoint_contract_generic(
 def adjoint_contract(
     adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs
 ):
-    if prod_op is adj_prod_op and sum_op in (null, adj_sum_op):
+    if prod_op is adj_prod_op and sum_op in (ops.null, adj_sum_op):
 
         # the only change is here:
         out_adj = Approximate(
diff --git a/funsor/cnf.py b/funsor/cnf.py
index b42b08ee5..1ddb3aaa2 100644
--- a/funsor/cnf.py
+++ b/funsor/cnf.py
@@ -17,7 +17,7 @@
 from funsor.gaussian import Gaussian
 from funsor.interpretations import eager, normalize, reflect
 from funsor.interpreter import children, recursion_reinterpret
-from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, null
+from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp
 from funsor.tensor import Tensor
 from funsor.terms import (
     _INFIX,
@@ -69,7 +69,7 @@ def __init__(self, red_op, bin_op, reduced_vars, terms):
         for v in terms:
             inputs.update((k, d) for k, d in v.inputs.items() if k not in bound)
 
-        if bin_op is null:
+        if bin_op is ops.null:
             output = terms[0].output
         else:
             output = reduce(
@@ -107,7 +107,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
         if not sampled_vars:
             return self
 
-        if self.red_op in (ops.logaddexp, null):
+        if self.red_op in (ops.null, ops.logaddexp):
             if self.bin_op in (ops.null, ops.logaddexp):
                 if rng_key is not None and get_backend() == "jax":
                     import jax
@@ -277,7 +277,7 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
         if unique_vars:
             result = term.reduce(red_op, unique_vars)
             if result is not normalize.interpret(
-                Contraction, red_op, null, unique_vars, (term,)
+                Contraction, red_op, ops.null, unique_vars, (term,)
             ):
                 terms[i] = result
                 reduced_vars -= unique_vars
@@ -432,7 +432,7 @@ def normalize_contraction_commutative_canonical_order(
 )
 def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other):
     return Contraction(
-        mixture.red_op if red_op is null else red_op,
+        mixture.red_op if red_op is ops.null else red_op,
         bin_op,
         reduced_vars | mixture.reduced_vars,
         *(mixture.terms + (other,))
@@ -444,7 +444,7 @@
 )
 def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture):
     return Contraction(
-        mixture.red_op if red_op is null else red_op,
+        mixture.red_op if red_op is ops.null else red_op,
         bin_op,
         reduced_vars | mixture.reduced_vars,
         *(mixture.terms + (other,))
@@ -467,13 +467,13 @@ def normalize_trivial(red_op, bin_op, reduced_vars, term):
 
 @normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
 def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
-    if not reduced_vars and red_op is not null:
-        return Contraction(null, bin_op, reduced_vars, *terms)
+    if not reduced_vars and red_op is not ops.null:
+        return Contraction(ops.null, bin_op, reduced_vars, *terms)
 
-    if len(terms) == 1 and bin_op is not null:
-        return Contraction(red_op, null, reduced_vars, *terms)
+    if len(terms) == 1 and bin_op is not ops.null:
+        return Contraction(red_op, ops.null, reduced_vars, *terms)
 
-    if red_op is null and bin_op is null:
+    if red_op is ops.null and bin_op is ops.null:
         return terms[0]
 
     if red_op is bin_op:
@@ -498,11 +498,11 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
             continue
 
         # fuse operations without distributing
-        if (v.red_op is null and bin_op is v.bin_op) or (
-            bin_op is null and v.red_op in (red_op, null)
+        if (v.red_op is ops.null and bin_op is v.bin_op) or (
+            bin_op is ops.null and v.red_op in (red_op, ops.null)
         ):
-            red_op = v.red_op if red_op is null else red_op
-            bin_op = v.bin_op if bin_op is null else bin_op
+            red_op = v.red_op if red_op is ops.null else red_op
+            bin_op = v.bin_op if bin_op is ops.null else bin_op
             new_terms = terms[:i] + v.terms + terms[i + 1 :]
             return Contraction(
                 red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms
@@ -519,12 +519,12 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
 
 @normalize.register(Binary, AssociativeOp, Funsor, Funsor)
 def binary_to_contract(op, lhs, rhs):
-    return Contraction(null, op, frozenset(), lhs, rhs)
+    return Contraction(ops.null, op, frozenset(), lhs, rhs)
 
 
 @normalize.register(Reduce, AssociativeOp, Funsor, frozenset)
 def reduce_funsor(op, arg, reduced_vars):
-    return Contraction(op, null, reduced_vars, arg)
+    return Contraction(op, ops.null, reduced_vars, arg)
 
 
 @normalize.register(
diff --git a/funsor/optimizer.py b/funsor/optimizer.py
index b401a1cdd..ed5975652 100644
--- a/funsor/optimizer.py
+++ b/funsor/optimizer.py
@@ -6,7 +6,7 @@
 from opt_einsum.paths import greedy
 
 import funsor.interpreter as interpreter
-from funsor.cnf import Contraction, null
+from funsor.cnf import Contraction
 from funsor.interpretations import (
     DispatchedInterpretation,
     PrioritizedInterpretation,
@@ -19,6 +19,8 @@
 from funsor.terms import Funsor
 from funsor.typing import Variadic
 
+from . import ops
+
 unfold_base = DispatchedInterpretation()
 unfold = PrioritizedInterpretation(unfold_base, normalize_base, lazy)
 
@@ -31,7 +33,7 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
         if not isinstance(v, Contraction):
             continue
 
-        if v.red_op is null and (v.bin_op, bin_op) in DISTRIBUTIVE_OPS:
+        if v.red_op is ops.null and (v.bin_op, bin_op) in DISTRIBUTIVE_OPS:
             # a * e * (b + c + d) -> (a * e * b) + (a * e * c) + (a * e * d)
             new_terms = tuple(
                 Contraction(
@@ -44,7 +46,7 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
             )
             return Contraction(red_op, v.bin_op, reduced_vars, *new_terms)
 
-        if red_op in (v.red_op, null) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS:
+        if red_op in (v.red_op, ops.null) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS:
             new_terms = (
                 terms[:i]
                 + (Contraction(v.red_op, v.bin_op, frozenset(), *v.terms),)
@@ -54,9 +56,9 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
                 red_op, reduced_vars
             )
 
-        if v.red_op in (red_op, null) and bin_op in (v.bin_op, null):
-            red_op = v.red_op if red_op is null else red_op
-            bin_op = v.bin_op if bin_op is null else bin_op
+        if v.red_op in (red_op, ops.null) and bin_op in (v.bin_op, ops.null):
+            red_op = v.red_op if red_op is ops.null else red_op
+            bin_op = v.bin_op if bin_op is ops.null else bin_op
             new_terms = terms[:i] + v.terms + terms[i + 1 :]
             return Contraction(
                 red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms
@@ -93,8 +95,9 @@ def eager_contract_base(red_op, bin_op, reduced_vars, *terms):
 
 @optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
 def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms):
-
-    if red_op is null or bin_op is null or not (red_op, bin_op) in DISTRIBUTIVE_OPS:
+    if red_op is ops.null or bin_op is ops.null:
+        return None
+    if (red_op, bin_op) not in DISTRIBUTIVE_OPS:
         return None
 
     # build opt_einsum optimizer IR
@@ -140,7 +143,7 @@ def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms):
         )
 
         path_end = Contraction(
-            red_op if path_end_reduced_vars else null,
+            red_op if path_end_reduced_vars else ops.null,
            bin_op,
            path_end_reduced_vars,
            ta,
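
The short sketch below is illustration only and not part of the patch. Assuming funsor is installed, it shows the naming convention the patch standardizes on: the null op, the identity placeholder that fills a Contraction's red_op or bin_op slot when there is nothing to reduce or combine, is referenced through the ops module as ops.null rather than as a bare null name re-exported from funsor.cnf. Only ops.null, ops.logaddexp, and NullOp are taken from the diff above; everything else in the sketch is hypothetical example code.

    import funsor.ops as ops
    from funsor.ops import NullOp

    # A hypothetical red_op value: ops.null marks "no reduction" in a Contraction.
    red_op = ops.null

    # The patched modules test identity against the qualified name, e.g.
    # `red_op is ops.null` or `red_op in (ops.null, ops.logaddexp)`.
    if red_op is ops.null:
        # NullOp is the class the normalize/eager rules dispatch on; ops.null
        # is expected (but not verified here) to be its singleton instance.
        print("null op:", red_op, "| dispatch class:", NullOp.__name__)
    assert red_op in (ops.null, ops.logaddexp)

Referencing the op as ops.null keeps these checks unambiguous and avoids importing null out of funsor.cnf's namespace, which is what the import changes in adjoint.py and optimizer.py accomplish.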