Skip to content

Commit

Permalink
Fixed get_shape
Browse files Browse the repository at this point in the history
Used code from waleedka#83 .
  • Loading branch information
Strasser-Pablo authored Mar 31, 2022
1 parent 45243d5 commit 847fc6b
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions hiddenlayer/pytorch_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,10 @@ def get_shape(torch_node):
# https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2
# TODO: find a better way to extract output shape
# TODO: Assuming the node has one output. Update if we encounter a multi-output node.
m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
if m:
shape = m.group(1)
shape = shape.split(",")
shape = tuple(map(int, shape))
else:
shape = None
try:
shape = torch_node.output().type().sizes()
except:
shape = None
return shape


Expand Down

0 comments on commit 847fc6b

Please sign in to comment.