Use direct theano.gof imports in theano.scan.op
brandonwillard committed Jan 3, 2021
1 parent 9c445e0 commit d5f6550
Showing 1 changed file with 25 additions and 23 deletions.
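
The refactor is mechanical: helpers previously reached as attributes of the top-level theano package are re-imported directly from the submodules that define them. A minimal before/after sketch of the pattern (toy variable names; the graph_inputs alias matches the diff below):

    # Before: attribute-style lookup through the top-level package
    import theano
    ins = theano.gof.graph.inputs(outputs, inputs)

    # After: direct import from the defining submodule
    from theano.gof.graph import inputs as graph_inputs
    ins = graph_inputs(outputs, inputs)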
48 changes: 25 additions & 23 deletions theano/scan/op.py
@@ -53,23 +53,27 @@
import numpy as np

import theano
-from theano import compile, gof, gradient, tensor
+from theano import tensor
from theano.compile.builders import infer_shape
from theano.compile.function import function
from theano.compile.io import In, Out
-from theano.compile.mode import AddFeatureOptimizer
-from theano.compile.profiling import ScanProfileStats
+from theano.compile.mode import AddFeatureOptimizer, get_mode
+from theano.compile.profiling import ScanProfileStats, register_profiler_printer
from theano.configdefaults import config
-from theano.gof import Apply, Op
-from theano.gof.graph import equal_computations, io_connection_pattern
+from theano.gof.fg import MissingInputError
+from theano.gof.graph import Apply, Variable, equal_computations
+from theano.gof.graph import inputs as graph_inputs
+from theano.gof.graph import io_connection_pattern
+from theano.gof.op import Op, ops_with_inner_function
from theano.gof.toolbox import NoOutputFromInplace
-from theano.gradient import DisconnectedType, NullType, grad_undefined
+from theano.gradient import DisconnectedType, NullType, grad, grad_undefined
from theano.link.c.basic import CLinker
from theano.link.c.exceptions import MissingGXX
from theano.link.utils import raise_with_op
from theano.scan.utils import Validator, forced_replace, hash_listsDictsTuples, safe_new
-from theano.tensor import TensorType, as_tensor_variable
+from theano.tensor.basic import as_tensor_variable
from theano.tensor.opt import Shape_i
+from theano.tensor.type import TensorType


__docformat__ = "restructedtext en"
@@ -169,7 +173,7 @@ def tensorConstructor(broadcastable, dtype):
if self.as_while:
self.output_types = self.output_types[:-1]

-mode_instance = compile.mode.get_mode(self.mode)
+mode_instance = get_mode(self.mode)
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
@@ -202,11 +206,9 @@ def tensorConstructor(broadcastable, dtype):
self._hash_inner_graph = self.info["gpu_hash"]
else:
# Do the missing inputs check here to have the error early.
-for var in theano.gof.graph.inputs(self.outputs, self.inputs):
+for var in graph_inputs(self.outputs, self.inputs):
if var not in self.inputs and not isinstance(var, theano.Constant):
-raise theano.gof.MissingInputError(
-    f"ScanOp is missing an input: {repr(var)}"
-)
+raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
self._cmodule_key = CLinker().cmodule_key_variables(
self.inputs, self.outputs, []
)
@@ -317,7 +319,7 @@ def make_node(self, *inputs):
the inner function)
"""
-assert np.all(isinstance(i, gof.Variable) for i in inputs)
+assert np.all(isinstance(i, Variable) for i in inputs)
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
@@ -2173,7 +2175,7 @@ def compute_all_gradients(known_grads):

wrt = [
x
-for x in theano.gof.graph.inputs(y_s)
+for x in graph_inputs(y_s)
if (x in diff_inputs)
and get_inp_idx(self_inputs.index(x)) in connected_inputs
]
@@ -2188,7 +2190,7 @@ def compute_all_gradients(known_grads):
# to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])

-grads = gradient.grad(
+grads = grad(
cost=None,
known_grads=known_grads,
wrt=wrt,
@@ -2238,7 +2240,7 @@ def compute_all_gradients(known_grads):
)

for pos, inp in enumerate(states):
-if inp in theano.gof.graph.inputs([Xt]):
+if inp in graph_inputs([Xt]):
# Get the index of the outer output that to which
# the state variable 'inp' corresponds.
outer_oidx = self.var_mappings["outer_out_from_inner_inp"][
@@ -2456,7 +2458,7 @@ def compute_all_gradients(known_grads):
disconnected = False

for _sh in self.inner_shared(self_inputs):
-if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
+if _sh in graph_inputs([dC_dinps_t[ins_pos]]):
through_shared = True

ins_pos += 1
@@ -2511,7 +2513,7 @@ def compute_all_gradients(known_grads):
if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False
for _sh in self.inner_shared(self_inputs):
-if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
+if _sh in graph_inputs([dC_dinps_t[ins_pos]]):
through_shared = True

n_mitmot_inps += 1
@@ -2559,7 +2561,7 @@ def compute_all_gradients(known_grads):
inner_out_mitmot.append(dC_dinps_t[ins_pos])

for _sh in self.inner_shared(self_inputs):
-if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
+if _sh in graph_inputs([dC_dinps_t[ins_pos]]):
through_shared = True

if isinstance(dC_dinps_t[ins_pos].type, NullType):
@@ -2583,7 +2585,7 @@ def compute_all_gradients(known_grads):
for _p, vl in enumerate(inner_out_sitsot):
through_shared = False
for _sh in self.inner_shared(self_inputs):
-if _sh in gof.graph.inputs([vl]):
+if _sh in graph_inputs([vl]):
through_shared = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
@@ -2602,7 +2604,7 @@ def compute_all_gradients(known_grads):
for _p, vl in enumerate(inner_out_nitsot):
through_shared = False
for _sh in self.inner_shared(self_inputs):
-if _sh in gof.graph.inputs([vl]):
+if _sh in graph_inputs([vl]):
through_shared = True
if isinstance(vl.type, NullType):
type_outs.append(vl.type.why_null)
@@ -3043,10 +3045,10 @@ def R_op(self, inputs, eval_points):

# Since Scan is an op that contains a Theano compiled function, it is
# useful to let DebugMode know about it.
-gof.ops_with_inner_function[Scan] = "fn"
+ops_with_inner_function[Scan] = "fn"


-@theano.compile.profiling.register_profiler_printer
+@register_profiler_printer
def profile_printer(
message, compile_time, fct_call_time, apply_time, apply_cimpl, outputs_size, file
):
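
For context on the gradient hunks above: grad is now the directly imported theano.gradient.grad, called with cost=None so that backpropagation starts from the precomputed known_grads rather than a scalar cost. A minimal standalone sketch of that calling pattern (toy variables, not from this file):

    import theano.tensor as tt
    from theano.gradient import grad

    x = tt.vector("x")
    y = x ** 2
    g_y = tt.ones_like(y)  # assumed upstream gradient flowing into y
    # With cost=None, grad backpropagates from known_grads instead of a cost
    (g_x,) = grad(cost=None, known_grads={y: g_y}, wrt=[x])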
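Similarly, the final hunks swap attribute lookups for the directly imported ops_with_inner_function registry and register_profiler_printer decorator. A hypothetical sketch of the registry's use for a custom op (class name and attribute invented for illustration):

    from theano.gof.op import Op, ops_with_inner_function

    class MyWrapperOp(Op):
        """Hypothetical op holding a compiled inner function on .fn"""

        def __init__(self, fn):
            self.fn = fn  # as Scan stores its compiled function under "fn"

    # Let DebugMode (and other tools) know which attribute holds the inner function
    ops_with_inner_function[MyWrapperOp] = "fn"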
