forked from nouiz/ccw_tutorial_theano
-
Notifications
You must be signed in to change notification settings - Fork 14
/
opt.py
29 lines (23 loc) · 851 Bytes
/
opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from scalmulop import ScalMulV1
from doubleop import DoubleOp
from doublecop import DoubleCOp
from doublec import DoubleC
from doublecgpu import DoubleCGpu
from theano.gof import local_optimizer
from theano.tensor.opt import register_specialize
from theano.gpuarray.opt import (register_opt, op_lifter,
register_opt2)
@register_specialize
@local_optimizer([ScalMulV1])
def local_scalmul_double(node):
    """Rewrite a ``ScalMulV1`` whose ``scal`` equals 2 into ``DoubleOp``.

    Parameters
    ----------
    node : Apply
        Graph node handed to this local optimizer. The decorator lists
        ``ScalMulV1`` as the tracked op, and the body re-checks the op type
        explicitly before touching ``scal``.

    Returns
    -------
    list or False
        ``[DoubleOp()(node.inputs[0])]`` when ``node.op`` is a ``ScalMulV1``
        with ``scal == 2``; ``False`` (meaning "no replacement") otherwise.
    """
    op = node.op
    # Only a multiplication by exactly 2 can be expressed with DoubleOp.
    multiplies_by_two = isinstance(op, ScalMulV1) and op.scal == 2
    if multiplies_by_two:
        # DoubleOp takes a single input: the tensor being scaled.
        return [DoubleOp()(node.inputs[0])]
    return False
@register_opt('fast_compile')
@op_lifter([DoubleOp, DoubleC, DoubleCOp])
@register_opt2([DoubleOp, DoubleC, DoubleCOp], 'fast_compile')
def local_scalmul_double_gpu(op, context_name, inputs, outputs):
    """Lift the CPU doubling ops (``DoubleOp``, ``DoubleC``, ``DoubleCOp``) to the GPU.

    Parameters
    ----------
    op : Op
        The CPU op being lifted; one of the ops listed in the decorators.
    context_name : str
        GPU context name chosen by the lifting machinery (unused here).
    inputs : list of Variable
        Inputs of the node being lifted (unused here).
    outputs : list of Variable
        Outputs of the node being lifted (unused here).

    Returns
    -------
    DoubleCGpu
        A new ``DoubleCGpu`` Op *instance*. The lifting machinery applies it
        to the GPU inputs itself.
    """
    # Must be an instance, not the bare class. op_lifter / GraphToGPU only
    # call the result on the inputs when it is a theano.Op instance. A class
    # object falls through to the "GPU variable" path (or no path at all) and
    # fails. NOTE(review): assumes DoubleCGpu takes no constructor arguments,
    # so confirm against doublecgpu.py.
    return DoubleCGpu()