diff --git a/visualdl/component/graph/exporter.py b/visualdl/component/graph/exporter.py index bd6d3e25..c04fdadb 100644 --- a/visualdl/component/graph/exporter.py +++ b/visualdl/component/graph/exporter.py @@ -15,7 +15,6 @@ import json import os import tempfile -import paddle from .graph_component import analyse_model from .graph_component import analyse_pir @@ -24,6 +23,13 @@ def translate_graph(model, input_spec, verbose=True, **kwargs): + try: + import paddle + except Exception: + print("Paddlepaddle is required to use add_graph interface.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") is_pir = kwargs.get('is_pir', False) with tempfile.TemporaryDirectory() as tmp: if (not is_pir): diff --git a/visualdl/component/graph/graph_component.py b/visualdl/component/graph/graph_component.py index d432e8fe..775da25c 100644 --- a/visualdl/component/graph/graph_component.py +++ b/visualdl/component/graph/graph_component.py @@ -16,7 +16,6 @@ import os.path import pathlib import re -import paddle from . import utils @@ -444,9 +443,16 @@ def safe_get_persistable(op): def get_sub_ops(op, op_name, all_ops, all_vars): + try: + from paddle.utils.unique_name import generate + except Exception: + print("Paddlepaddle is required to use add_graph interface.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") for sub_block in op.blocks(): for sub_op in sub_block.ops: - sub_op_name0 = paddle.utils.unique_name.generate(sub_op.name()) + sub_op_name0 = generate(sub_op.name()) sub_op_name = op_name + '/' + sub_op_name0 all_ops[sub_op_name] = {} all_ops[sub_op_name]['name'] = sub_op_name @@ -561,7 +567,13 @@ def update_node_connections(all_vars, all_ops): def analyse_pir(program): - from paddle.utils.unique_name import generate + try: + from paddle.utils.unique_name import generate + except Exception: + print("Paddlepaddle is required to use add_graph interface.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") all_ops = {} all_vars = {} diff --git a/visualdl/reader/graph_reader.py b/visualdl/reader/graph_reader.py index e1b1df83..4e8035bd 100644 --- a/visualdl/reader/graph_reader.py +++ b/visualdl/reader/graph_reader.py @@ -20,7 +20,6 @@ from visualdl.component.graph import analyse_pir from visualdl.component.graph import Model from visualdl.io import bfile -from paddle.jit import load def is_VDLGraph_file(path): @@ -140,6 +139,13 @@ def get_graph(self, if 'pdmodel' in self.walks[run]: graph_model = Model(analyse_model(data)) elif 'json' in self.walks[run]: + try: + from paddle.jit import load + except Exception: + print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") json_object = json.loads(data) with tempfile.TemporaryDirectory() as tmp: with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: @@ -174,6 +180,13 @@ def search_graph_node(self, run, nodeid, keep_state=False, is_node=True): if 'pdmodel' in self.walks[run]: graph_model = Model(analyse_model(data)) elif 'json' in self.walks[run]: + try: + from paddle.jit import load + except Exception: + print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") json_object = json.loads(data) with tempfile.TemporaryDirectory() as tmp: with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: @@ -202,6 +215,13 @@ def get_all_nodes(self, run): if 'pdmodel' in self.walks[run]: graph_model = Model(analyse_model(data)) elif 'json' in self.walks[run]: + try: + from paddle.jit import load + except Exception: + print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") json_object = json.loads(data) with tempfile.TemporaryDirectory() as tmp: with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: @@ -241,6 +261,13 @@ def set_input_graph(self, content, file_type='pdmodel'): self.graph_buffer['manual_input_model'] = Model(data) elif file_type == 'json': + try: + from paddle.jit import load + except Exception: + print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") json_object = json.loads(content) with tempfile.TemporaryDirectory() as tmp: with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: