Skip to content

Commit

Permalink
Fix quant fused conv op snippet bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dboyliao committed Jul 11, 2019
1 parent f56b40b commit 40f846f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions utensor_cgen/backend/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,12 +504,12 @@ def __init__(self, op_info, **kwargs):
padding = op_info.op_attr['_utensor_conv']["padding"].value.decode('utf8')
parser = NamescopedKWArgsParser(RefCntOptimizer.KWARGS_NAMESCOPE,
op_info.op_attr)
ref_count = parser.get('ref_counts', [0])[0]
ref_counts = parser.get('ref_counts', None)
to_eval = parser.get('to_eval', False)
self._snippet = QuantizedFusedConv2DMaxpoolOpSnippet(
inputs, outputs, strides, ksize, padding,
in_dtype=in_dtype, filter_dtype=filter_dtype, out_dtypes=out_dtypes,
ref_count=ref_count, to_eval=to_eval
ref_counts=ref_counts, to_eval=to_eval
)


Expand Down
11 changes: 6 additions & 5 deletions utensor_cgen/backend/snippets/_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,18 +671,19 @@ class QuantizedFusedConv2DMaxpoolOpSnippet(Snippet):

def __init__(self, inputs, outputs, strides, ksize, padding,
in_dtype, filter_dtype, out_dtypes,
ref_count=0,
ref_counts=None,
to_eval=False):
# import pdb; pdb.set_trace()
Snippet.__init__(self)
if ref_count:
self.template_vars["ref_count"] = ref_count
if ref_counts:
self.template_vars["ref_counts"] = ref_counts
self.template_vars["inputs"] = inputs
self.template_vars["outputs"] = outputs
self.template_vars["in_dtype"] = NP_TYPES_MAP[in_dtype].tensor_type_str
self.template_vars["filter_dtype"] = NP_TYPES_MAP[filter_dtype].tensor_type_str
self.template_vars["out_dtypes"] = [
NP_TYPES_MAP[out_dtype].tensor_type_str
for out_dtype in out_dtypes
NP_TYPES_MAP[dtype].tensor_type_str
for dtype in out_dtypes
]
self.template_vars["strides"] = strides
self.template_vars["ksize"] = ksize
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ctx.add(new RamTensor<{{out_dtypes[1]}}>({1}), "{{outputs[1]}}");
ctx.add(new RamTensor<{{out_dtypes[2]}}>({1}), "{{outputs[2]}}");
{% endif %}
ctx.push(new QuantizedFusedConvMaxpoolOp<{{in_dtype}}, {{filter_dtype}}, {{out_dtype}}>({ {% for s in strides[:-1]%}{{s}}, {%endfor%}{{strides[-1]}} }, { {% for s in ksize[:-1]%}{{s}}, {%endfor%}{{ksize[-1]}} },{{padding}}),
ctx.push(new QuantizedFusedConvMaxpoolOp<{{in_dtype}}, {{filter_dtype}}, {{out_dtypes[0]}}>({ {% for s in strides[:-1]%}{{s}}, {%endfor%}{{strides[-1]}} }, { {% for s in ksize[:-1]%}{{s}}, {%endfor%}{{ksize[-1]}} },{{padding}}),
{ {% for tname in inputs[:-1]%}"{{tname}}", {%endfor%}"{{inputs[-1]}}" },
{ {% for tname in outputs[:-1]%}"{{tname}}", {%endfor%}"{{outputs[-1]}}" });
{% if to_eval %}
Expand Down

0 comments on commit 40f846f

Please sign in to comment.