From 847fc6baf0e95d92c68eaa904039513f74274793 Mon Sep 17 00:00:00 2001 From: Strasser-Pablo Date: Thu, 31 Mar 2022 18:19:34 +0200 Subject: [PATCH] Fixed get_shape Used code from https://github.com/waleedka/hiddenlayer/issues/83 . --- hiddenlayer/pytorch_builder.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..9af958b 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -53,13 +53,10 @@ def get_shape(torch_node): # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2 # TODO: find a better way to extract output shape # TODO: Assuming the node has one output. Update if we encounter a multi-output node. - m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs()))) - if m: - shape = m.group(1) - shape = shape.split(",") - shape = tuple(map(int, shape)) - else: - shape = None + try: + shape = torch_node.output().type().sizes() + except: + shape = None return shape