From 2fb47566af4185336fe5ccccf7d356d46d0e8225 Mon Sep 17 00:00:00 2001 From: anetczuk Date: Sat, 3 Feb 2024 22:17:16 +0100 Subject: [PATCH 1/8] Ignore Eclipse project files --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index 990fdc0..ad238c0 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,10 @@ htmlcov .idea/ .history/ .vscode/ +# Eclipse +/.settings +.project +.pydevproject + + +/tmp/ From a22254f49f1ef08817506b0bf8e2ff1ceb32da31 Mon Sep 17 00:00:00 2001 From: anetczuk Date: Sat, 3 Feb 2024 23:05:05 +0100 Subject: [PATCH 2/8] Fix 'get_module_name()' passing relative paths --- pyan/anutils.py | 6 ++++++ tests/test_anutils.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 tests/test_anutils.py diff --git a/pyan/anutils.py b/pyan/anutils.py index 8063e61..084c611 100644 --- a/pyan/anutils.py +++ b/pyan/anutils.py @@ -32,6 +32,9 @@ def get_module_name(filename, root: str = None): # otherwise it is the filename without extension module_path = filename.replace(".py", "") + # will enter infinite loop or exception in case of non-absolute path (starting with ".") + module_path = os.path.abspath(module_path) + # find the module root - walk up the tree and check if it contains .py files - if yes. it is the new root directories = [(module_path, True)] if root is None: @@ -47,6 +50,9 @@ def get_module_name(filename, root: str = None): else: # root is already known - just walk up until it is matched while directories[0][0] != root: potential_root = os.path.dirname(directories[0][0]) + if potential_root == directories[0][0]: + # root directory reached - stop iteration + break directories.insert(0, (potential_root, True)) mod_name = ".".join([os.path.basename(f[0]) for f in directories]) diff --git a/tests/test_anutils.py b/tests/test_anutils.py new file mode 100644 index 0000000..de4d1df --- /dev/null +++ b/tests/test_anutils.py @@ -0,0 +1,28 @@ +from glob import glob +import logging +import os + +import pytest + +from pyan.anutils import get_module_name + + +def test_get_module_name_filename_not_existing(): + mod_name = get_module_name("just_filename.py") + assert mod_name == "just_filename" + + +def test_get_module_name_absolute(): + mod_name = get_module_name(__file__) + assert mod_name == "test_anutils" + + +def test_get_module_name_absolute_not_existing(): + with pytest.raises(FileNotFoundError) as e_info: + get_module_name("/not/existing/abs_path/mod.py") + + mod_name = get_module_name("/not/existing/mod_dir/mod.py", "mod_dir") + assert mod_name == ".not.existing.mod_dir.mod" + + mod_name = get_module_name("/not/existing/mod_dir/mod.py", "invalid_root") + assert mod_name == ".not.existing.mod_dir.mod" From 81eaa94fd29e8f7eab22e3df84636e9ed8ae5a3d Mon Sep 17 00:00:00 2001 From: anetczuk Date: Sun, 4 Feb 2024 12:02:27 +0100 Subject: [PATCH 3/8] Handle chained imports and uses --- pyan/analyzer.py | 14 ++++++++++++++ tests/test_analyzer.py | 11 +++++++++++ tests/test_code/submodule3.py | 5 +++++ 3 files changed, 30 insertions(+) create mode 100644 tests/test_code/submodule3.py diff --git a/pyan/analyzer.py b/pyan/analyzer.py index 75e9cb1..ce66a88 100644 --- a/pyan/analyzer.py +++ b/pyan/analyzer.py @@ -611,6 +611,14 @@ def analyze_module_import(self, import_item, ast_node): alias_name = mod_node.name self.add_uses_edge(from_node, mod_node) self.logger.info("New edge added for Use import %s in %s" % (mod_node, from_node)) + + curr_scope = self.scope_stack[-1] + if alias_name not in curr_scope.defs: + # seems that symtable module does not handle following "import aaa.bbb" as expected in pyan + # it returns "aaa", but the analyzer expects full symbol "aaa.bbb" + # workaround is to add missing empty symbol + curr_scope.defs[alias_name] = None + self.set_value(alias_name, mod_node) # set node to be discoverable in module self.logger.info("From setting name %s to %s" % (alias_name, mod_node)) @@ -1189,6 +1197,12 @@ def resolve_attribute(self, ast_node): self.logger.debug("Resolved to attr %s of %s" % (ast_node.attr, sc.defs[attr_name])) return sc.defs[attr_name], ast_node.attr + attr_name = get_ast_node_name(ast_node.value) + obj_node = self.get_value(attr_name) # resolves chained calls and imports ("aaa.bbb") if needed + if obj_node is not None: + self.logger.debug("Resolved to attr %s of %s" % (attr_name, obj_node)) + return obj_node, attr_name + # It may happen that ast_node.value has no corresponding graph Node, # if this is a forward-reference, or a reference to a file # not in the analyzed set. diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index f1e1d57..b02ca0e 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -34,6 +34,12 @@ def test_resolve_import_as(callgraph): get_node(imports, "test_code.subpackage1") +def test_resolve_import(callgraph): + imports = get_in_dict(callgraph.uses_edges, "test_code.submodule3") + get_node(imports, "test_code.subpackage2.submodule_hidden1") + assert len(imports) == 1, "only one effective import" + + def test_import_relative(callgraph): imports = get_in_dict(callgraph.uses_edges, "test_code.subpackage1.submodule1") get_node(imports, "test_code.submodule2.test_2") @@ -50,6 +56,11 @@ def test_resolve_use_in_function(callgraph): get_node(uses, "test_code.submodule1.test_func2") +def test_resolve_use_in_function_02(callgraph): + uses = get_in_dict(callgraph.uses_edges, "test_code.submodule3.test_3") + get_node(uses, "test_code.subpackage2.submodule_hidden1.test_func1") + + def test_resolve_package_without___init__(callgraph): defines = get_in_dict(callgraph.defines_edges, "test_code.subpackage2.submodule_hidden1") get_node(defines, "test_code.subpackage2.submodule_hidden1.test_func1") diff --git a/tests/test_code/submodule3.py b/tests/test_code/submodule3.py new file mode 100644 index 0000000..571015b --- /dev/null +++ b/tests/test_code/submodule3.py @@ -0,0 +1,5 @@ +import test_code.subpackage2.submodule_hidden1 + + +def test_3(): + return test_code.subpackage2.submodule_hidden1.test_func1() From 3add09467823b4a7b54f8d038eeabc70a984b800 Mon Sep 17 00:00:00 2001 From: anetczuk Date: Sun, 4 Feb 2024 15:18:52 +0100 Subject: [PATCH 4/8] Function filter: include callers --- pyan/analyzer.py | 70 ++++++++++++++++++++++++++++++++++++------ pyan/main.py | 13 ++------ tests/test_analyzer.py | 7 +++++ 3 files changed, 70 insertions(+), 20 deletions(-) diff --git a/pyan/analyzer.py b/pyan/analyzer.py index ce66a88..ba48dac 100644 --- a/pyan/analyzer.py +++ b/pyan/analyzer.py @@ -238,6 +238,17 @@ def resolve_imports(self): if len(to_nodes) > 0 } + def filter_data(self, function: Union[None, str] = None, namespace: Union[None, str] = None, max_iter: int = 1000): + if function: + function_name = function.split(".")[-1] + function_namespace = ".".join(function.split(".")[:-1]) + node = self.get_node(function_namespace, function_name) + + else: + node = None + + self.filter(node=node, namespace=namespace) + def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000): """ filter callgraph nodes that related to `node` or are in `namespace` @@ -304,30 +315,62 @@ def get_related_nodes( # use queue system to search through nodes # essentially add a node to the queue and then search all connected nodes which are in turn added to the queue # until the queue itself is empty or the maximum limit of max_iter searches have been hit + downstream_new_nodes = new_nodes.copy() + downstream_queue = queue.copy() i = max_iter - while len(queue) > 0: - item = queue.pop() - if item not in new_nodes: - new_nodes.add(item) + while len(downstream_queue) > 0: + item = downstream_queue.pop() + if item not in downstream_new_nodes: + downstream_new_nodes.add(item) i -= 1 if i < 0: break - queue.extend( + # add used nodes that are not already added and are in desired namespace + downstream_queue.extend( [ n for n in self.uses_edges.get(item, []) - if n in self.uses_edges and n not in new_nodes and namespace in n.namespace + if n in self.uses_edges and n not in downstream_new_nodes and namespace in n.namespace ] ) - queue.extend( + # add defined nodes that are not already added and are in desired namespace + downstream_queue.extend( [ n for n in self.defines_edges.get(item, []) - if n in self.defines_edges and n not in new_nodes and namespace in n.namespace + if n in self.defines_edges and n not in downstream_new_nodes and namespace in n.namespace ] ) - return new_nodes + # get callers of node + upstream_new_nodes = new_nodes.copy() + upstream_queue = queue.copy() + i = max_iter + while len(upstream_queue) > 0: + item = upstream_queue.pop() + if item not in upstream_new_nodes: + upstream_new_nodes.add(item) + i -= 1 + if i < 0: + break + # add used nodes that are not already added and are in desired namespace + upstream_queue.extend( + [ + n + for n in self.get_callers(self.uses_edges, item) + if n in self.uses_edges and n not in upstream_new_nodes and namespace in n.namespace + ] + ) + # add defined nodes that are not already added and are in desired namespace + upstream_queue.extend( + [ + n + for n in self.get_callers(self.defines_edges, item) + if n in self.defines_edges and n not in upstream_new_nodes and namespace in n.namespace + ] + ) + + return downstream_new_nodes.union(upstream_new_nodes) def visit_Module(self, node): self.logger.debug("Module %s, %s" % (self.module_name, self.filename)) @@ -1776,3 +1819,12 @@ def collapse_inner(self): self.logger.info("Collapsing inner from %s to %s, uses %s" % (n, pn, n2)) self.add_uses_edge(pn, n2) n.defined = False + + + # return list of callers (keys) of given node in edges_dict + def get_callers(self, edges_dict, node): + ret_list = [] + for caller, callees_list in edges_dict.items(): + if node in callees_list: + ret_list.append(caller) + return ret_list diff --git a/pyan/main.py b/pyan/main.py index 5d07971..2650488 100644 --- a/pyan/main.py +++ b/pyan/main.py @@ -43,7 +43,7 @@ def main(cli_args=None): parser.add_argument("--namespace", dest="namespace", help="filter for NAMESPACE", metavar="NAMESPACE", default=None) - parser.add_argument("--function", dest="function", help="filter for FUNCTION", metavar="FUNCTION", default=None) + parser.add_argument("--function", dest="function", help="filter for FUNCTION (generates call subtree)", metavar="FUNCTION", default=None) parser.add_argument("-l", "--log", dest="logname", help="write log to LOG", metavar="LOG") @@ -206,16 +206,7 @@ def main(cli_args=None): v = CallGraphVisitor(filenames, logger=logger, root=root) if known_args.function or known_args.namespace: - - if known_args.function: - function_name = known_args.function.split(".")[-1] - namespace = ".".join(known_args.function.split(".")[:-1]) - node = v.get_node(namespace, function_name) - - else: - node = None - - v.filter(node=node, namespace=known_args.namespace) + v.filter_data(function=known_args.function, namespace=known_args.namespace) graph = VisualGraph.from_visitor(v, options=graph_options, logger=logger) diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index b02ca0e..d550193 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -73,3 +73,10 @@ def test_resolve_package_with_known_root(): dirname_base = os.path.basename(dirname) defines = get_in_dict(callgraph.defines_edges, f"{dirname_base}.test_code.subpackage2.submodule_hidden1") get_node(defines, f"{dirname_base}.test_code.subpackage2.submodule_hidden1.test_func1") + + +def test_filter_function_parent(callgraph): + callgraph.filter_data(function="test_code.submodule2.test_2") + # get parent of filtered function + uses = get_in_dict(callgraph.uses_edges, "test_code.subpackage1.submodule1.A.__init__") + get_node(uses, "test_code.submodule2.test_2") From a27044fc798fd3e3e830980e1331cd1894d5320d Mon Sep 17 00:00:00 2001 From: anetczuk Date: Sun, 4 Feb 2024 16:25:16 +0100 Subject: [PATCH 5/8] CLI option --no-packages: hide packages representation --- pyan/analyzer.py | 32 ++++++++++++++++++++++++++++++++ pyan/main.py | 11 +++++++++++ 2 files changed, 43 insertions(+) diff --git a/pyan/analyzer.py b/pyan/analyzer.py index ba48dac..e7187ee 100644 --- a/pyan/analyzer.py +++ b/pyan/analyzer.py @@ -1713,6 +1713,38 @@ def remove_wild(self, from_node, to_node, name): self.logger.info("Use from %s to %s resolves %s; removing wildcard" % (from_node, to_node, wild_node)) self.remove_uses_edge(from_node, wild_node) + def remove_node(self, node): + # remove from edges + if node in self.uses_edges: + del self.uses_edges[node] + # remove to edges + for key in list(self.uses_edges.keys()): + items = self.uses_edges[key] + if node in items: + items.remove(node) + if not items: + # empty list + del self.uses_edges[key] + + # remove nodes + for key in list(self.nodes.keys()): + items = self.nodes[key] + if node in items: + items.remove(node) + if not items: + # empty list + del self.nodes[key] + + def remove_packages(self): + node_type = Flavor.MODULE + for key in list(self.nodes.keys()): + if key not in self.nodes: + continue + nodes = self.nodes[key] + for n in nodes.copy(): + if n.flavor in [Flavor.MODULE]: + self.remove_node(n) + ########################################################################### # Postprocessing diff --git a/pyan/main.py b/pyan/main.py index 2650488..f6b2540 100644 --- a/pyan/main.py +++ b/pyan/main.py @@ -93,6 +93,14 @@ def main(cli_args=None): help="do not add edges for 'uses' relationships", ) + parser.add_argument( + "--no-packages", + action="store_false", + default=True, + dest="packages", + help="do not add packages and import relationships", + ) + parser.add_argument( "-c", "--colored", @@ -208,6 +216,9 @@ def main(cli_args=None): if known_args.function or known_args.namespace: v.filter_data(function=known_args.function, namespace=known_args.namespace) + if not known_args.packages: + v.remove_packages() + graph = VisualGraph.from_visitor(v, options=graph_options, logger=logger) writer = None From ce2b7e75bbe6005db7497ff464554d54a5a7e663 Mon Sep 17 00:00:00 2001 From: anetczuk Date: Mon, 5 Feb 2024 00:45:41 +0100 Subject: [PATCH 6/8] Handle class static fields and enum values --- pyan/analyzer.py | 20 +++++++++++++++++++- tests/test_analyzer.py | 23 +++++++++++++++++++++-- tests/test_code/submodule3.py | 5 +++++ tests/test_code/subpackage1/__init__.py | 4 +++- tests/test_code/subpackage1/enum.py | 6 ++++++ tests/test_code/subpackage1/submodule1.py | 10 ++++++++++ 6 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 tests/test_code/subpackage1/enum.py diff --git a/pyan/analyzer.py b/pyan/analyzer.py index e7187ee..0cc765a 100644 --- a/pyan/analyzer.py +++ b/pyan/analyzer.py @@ -185,6 +185,7 @@ def resolve_imports(self): while len(imports_to_resolve) > 0: from_node = imports_to_resolve.pop() if from_node in import_mapping: + # already marked to map - skip continue to_uses = self.uses_edges.get(from_node, set([from_node])) assert len(to_uses) == 1 @@ -218,6 +219,9 @@ def resolve_imports(self): f"{from_node.namespace}.{from_node.name}" == node.namespace and from_node.flavor == Flavor.IMPORTEDITEM ): + if to_node not in self.defines_edges: + self.logger.warning("unable to find define edges from %s", to_node) + continue # use define edges as potential candidates for candidate_to_node in self.defines_edges[to_node]: # if candidate_to_node.name == node.name: @@ -752,7 +756,7 @@ def visit_Attribute(self, node): ) if self.add_uses_edge(from_node, to_node): self.logger.info( - "New edge added for Use from {from_node} to {to_node} (target obj {obj_node} known but " + f"New edge added for Use from {from_node} to {to_node} (target obj {obj_node} known but " f"target attr {node.attr} not resolved; maybe fwd ref or unanalyzed import)" ) @@ -769,6 +773,19 @@ def visit_Attribute(self, node): def visit_Name(self, node): self.logger.debug("Name %s in context %s, %s:%s" % (node.id, type(node.ctx), self.filename, node.lineno)) + in_class_ns = self.context_stack[-1].startswith("ClassDef") + if in_class_ns: + # add enum value or class static field + #TODO: there should be additional Flavor and proper handling of static fields + tgt_name = node.id + val_node = self.get_value(tgt_name) # resolves "self" if needed + if val_node is None and tgt_name not in ["staticmethod", "classmethod"]: + # static field case + from_node = self.get_node_of_current_namespace() + ns = from_node.get_name() + to_node = self.get_node(ns, tgt_name, node, flavor=Flavor.ATTRIBUTE) + self.add_defines_edge(from_node, to_node) + # TODO: self.last_value is a hack. Handle names in store context (LHS) # in analyze_binding(), so that visit_Name() only needs to handle # the load context (i.e. detect uses of the name). @@ -1593,6 +1610,7 @@ def add_defines_edge(self, from_node, to_node): status = True from_node.defined = True if to_node is None or to_node in self.defines_edges[from_node]: + # edge already defined return status self.defines_edges[from_node].add(to_node) to_node.defined = True diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index d550193..450dbda 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -20,6 +20,11 @@ def get_node(nodes, name): return filtered_nodes[0] +def contains_node(nodes, name) -> bool: + filtered_nodes = [node for node in nodes if node.get_name() == name] + return len(filtered_nodes) > 0 + + def get_in_dict(node_dict, name): return node_dict[get_node(node_dict.keys(), name)] @@ -36,8 +41,7 @@ def test_resolve_import_as(callgraph): def test_resolve_import(callgraph): imports = get_in_dict(callgraph.uses_edges, "test_code.submodule3") - get_node(imports, "test_code.subpackage2.submodule_hidden1") - assert len(imports) == 1, "only one effective import" + assert contains_node(imports, "test_code.subpackage2.submodule_hidden1") def test_import_relative(callgraph): @@ -80,3 +84,18 @@ def test_filter_function_parent(callgraph): # get parent of filtered function uses = get_in_dict(callgraph.uses_edges, "test_code.subpackage1.submodule1.A.__init__") get_node(uses, "test_code.submodule2.test_2") + + +def test_staticmethod_decorator(callgraph): + members = get_in_dict(callgraph.uses_edges, "test_code.subpackage1.submodule1.A") + assert not contains_node(members, "test_code.subpackage1.submodule1.A.staticmethod") + + +def test_use_enum_value(callgraph): + imports = get_in_dict(callgraph.uses_edges, "test_code.submodule3.test_3") + assert contains_node(imports, "test_code.subpackage1.enum.EnumType.ENUM_1") + + +def test_use_class_static(callgraph): + imports = get_in_dict(callgraph.uses_edges, "test_code.submodule3.test_3") + assert contains_node(imports, "test_code.subpackage1.submodule1.A2.STATIC_VAL") diff --git a/tests/test_code/submodule3.py b/tests/test_code/submodule3.py index 571015b..cc3396f 100644 --- a/tests/test_code/submodule3.py +++ b/tests/test_code/submodule3.py @@ -1,5 +1,10 @@ import test_code.subpackage2.submodule_hidden1 +from test_code.subpackage1 import A2 +from test_code.subpackage1 import EnumType + def test_3(): + print(A2.STATIC_VAL) + print(EnumType.ENUM_1) return test_code.subpackage2.submodule_hidden1.test_func1() diff --git a/tests/test_code/subpackage1/__init__.py b/tests/test_code/subpackage1/__init__.py index d213d49..24df05e 100644 --- a/tests/test_code/subpackage1/__init__.py +++ b/tests/test_code/subpackage1/__init__.py @@ -1,3 +1,5 @@ from test_code.subpackage1.submodule1 import A +from test_code.subpackage1.submodule1 import A2 +from test_code.subpackage1.enum import EnumType -__all__ = ["A"] +__all__ = ["A", "A2", "EnumType"] diff --git a/tests/test_code/subpackage1/enum.py b/tests/test_code/subpackage1/enum.py new file mode 100644 index 0000000..cfad829 --- /dev/null +++ b/tests/test_code/subpackage1/enum.py @@ -0,0 +1,6 @@ +from enum import IntEnum + + +class EnumType(IntEnum): + ENUM_1 = 0 + ENUM_2 = 1 diff --git a/tests/test_code/subpackage1/submodule1.py b/tests/test_code/subpackage1/submodule1.py index 7798ee2..578d7e7 100644 --- a/tests/test_code/subpackage1/submodule1.py +++ b/tests/test_code/subpackage1/submodule1.py @@ -2,5 +2,15 @@ class A: + def __init__(self, b): + self.a = None self.b = test_2(b) + + @staticmethod + def test_static(): + pass + + +class A2: + STATIC_VAL = 123 From d94086adf86e1509c76c8db8a3fe84a932642cc3 Mon Sep 17 00:00:00 2001 From: anetczuk Date: Mon, 5 Feb 2024 23:58:24 +0100 Subject: [PATCH 7/8] remove_packages: remove "dangling" class node --- pyan/analyzer.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pyan/analyzer.py b/pyan/analyzer.py index 0cc765a..25c713f 100644 --- a/pyan/analyzer.py +++ b/pyan/analyzer.py @@ -1754,13 +1754,24 @@ def remove_node(self, node): del self.nodes[key] def remove_packages(self): - node_type = Flavor.MODULE for key in list(self.nodes.keys()): if key not in self.nodes: continue nodes = self.nodes[key] for n in nodes.copy(): - if n.flavor in [Flavor.MODULE]: + if n.flavor is Flavor.MODULE: + self.remove_node(n) + + # remove parent class nodes that are not used + for key in list(self.nodes.keys()): + if key not in self.nodes: + continue + nodes = self.nodes[key] + for n in nodes.copy(): + if n.flavor is not Flavor.CLASS: + continue + callers = self.get_callers(self.uses_edges, n) + if not callers: self.remove_node(n) ########################################################################### From 5611c7d6913da97e8c0d2330765e3df663e2a403 Mon Sep 17 00:00:00 2001 From: anetczuk Date: Tue, 6 Feb 2024 00:21:47 +0100 Subject: [PATCH 8/8] Control function filtering: downward or upward --- pyan/analyzer.py | 134 ++++++++++++++++++++++------------------- pyan/main.py | 8 ++- tests/test_analyzer.py | 11 +++- 3 files changed, 89 insertions(+), 64 deletions(-) diff --git a/pyan/analyzer.py b/pyan/analyzer.py index 25c713f..1eab7a0 100644 --- a/pyan/analyzer.py +++ b/pyan/analyzer.py @@ -242,7 +242,8 @@ def resolve_imports(self): if len(to_nodes) > 0 } - def filter_data(self, function: Union[None, str] = None, namespace: Union[None, str] = None, max_iter: int = 1000): + def filter_data(self, function: Union[None, str] = None, namespace: Union[None, str] = None, max_iter: int = 1000, + filter_down=True, filter_up=False): if function: function_name = function.split(".")[-1] function_namespace = ".".join(function.split(".")[:-1]) @@ -251,9 +252,10 @@ def filter_data(self, function: Union[None, str] = None, namespace: Union[None, else: node = None - self.filter(node=node, namespace=namespace) + self.filter(node=node, namespace=namespace, filter_down=filter_down, filter_up=filter_up) - def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000): + def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000, + filter_down: bool = True, filter_up: bool = False): """ filter callgraph nodes that related to `node` or are in `namespace` @@ -262,12 +264,15 @@ def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = N namespace: namespace to search in (name of top level module), if None, determines namespace from `node` max_iter: maximum number of iterations and nodes to iterate + filter_down: filter nodes in downward + filter_up: filter nodes in upward Returns: self """ # filter the nodes to avoid cluttering the callgraph with irrelevant information - filtered_nodes = self.get_related_nodes(node, namespace=namespace, max_iter=max_iter) + filtered_nodes = self.get_related_nodes(node, namespace=namespace, max_iter=max_iter, + find_downward=filter_down, find_upward=filter_up) self.nodes = {name: [node for node in nodes if node in filtered_nodes] for name, nodes in self.nodes.items()} self.uses_edges = { @@ -283,7 +288,8 @@ def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = N return self def get_related_nodes( - self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000 + self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000, + find_downward: bool = True, find_upward: bool = False ) -> set: """ get nodes that related to `node` or are in `namespace` @@ -293,6 +299,8 @@ def get_related_nodes( namespace: namespace to search in (name of top level module), if None, determines namespace from `node` max_iter: maximum number of iterations and nodes to iterate + find_downward: look for nodes in downward + find_upward: look for nodes in upward Returns: set: set of nodes related to `node` including `node` itself @@ -316,63 +324,67 @@ def get_related_nodes( namespace = node.namespace.strip(".").split(".", 1)[0] queue = [node] - # use queue system to search through nodes - # essentially add a node to the queue and then search all connected nodes which are in turn added to the queue - # until the queue itself is empty or the maximum limit of max_iter searches have been hit - downstream_new_nodes = new_nodes.copy() - downstream_queue = queue.copy() - i = max_iter - while len(downstream_queue) > 0: - item = downstream_queue.pop() - if item not in downstream_new_nodes: - downstream_new_nodes.add(item) - i -= 1 - if i < 0: - break - # add used nodes that are not already added and are in desired namespace - downstream_queue.extend( - [ - n - for n in self.uses_edges.get(item, []) - if n in self.uses_edges and n not in downstream_new_nodes and namespace in n.namespace - ] - ) - # add defined nodes that are not already added and are in desired namespace - downstream_queue.extend( - [ - n - for n in self.defines_edges.get(item, []) - if n in self.defines_edges and n not in downstream_new_nodes and namespace in n.namespace - ] - ) + downstream_new_nodes = set() + if find_downward: + # use queue system to search through nodes + # essentially add a node to the queue and then search all connected nodes which are in turn added to the queue + # until the queue itself is empty or the maximum limit of max_iter searches have been hit + downstream_new_nodes = new_nodes.copy() + downstream_queue = queue.copy() + i = max_iter + while len(downstream_queue) > 0: + item = downstream_queue.pop() + if item not in downstream_new_nodes: + downstream_new_nodes.add(item) + i -= 1 + if i < 0: + break + # add used nodes that are not already added and are in desired namespace + downstream_queue.extend( + [ + n + for n in self.uses_edges.get(item, []) + if n in self.uses_edges and n not in downstream_new_nodes and namespace in n.namespace + ] + ) + # add defined nodes that are not already added and are in desired namespace + downstream_queue.extend( + [ + n + for n in self.defines_edges.get(item, []) + if n in self.defines_edges and n not in downstream_new_nodes and namespace in n.namespace + ] + ) - # get callers of node - upstream_new_nodes = new_nodes.copy() - upstream_queue = queue.copy() - i = max_iter - while len(upstream_queue) > 0: - item = upstream_queue.pop() - if item not in upstream_new_nodes: - upstream_new_nodes.add(item) - i -= 1 - if i < 0: - break - # add used nodes that are not already added and are in desired namespace - upstream_queue.extend( - [ - n - for n in self.get_callers(self.uses_edges, item) - if n in self.uses_edges and n not in upstream_new_nodes and namespace in n.namespace - ] - ) - # add defined nodes that are not already added and are in desired namespace - upstream_queue.extend( - [ - n - for n in self.get_callers(self.defines_edges, item) - if n in self.defines_edges and n not in upstream_new_nodes and namespace in n.namespace - ] - ) + upstream_new_nodes = set() + if find_upward: + # get callers of node + upstream_new_nodes = new_nodes.copy() + upstream_queue = queue.copy() + i = max_iter + while len(upstream_queue) > 0: + item = upstream_queue.pop() + if item not in upstream_new_nodes: + upstream_new_nodes.add(item) + i -= 1 + if i < 0: + break + # add used nodes that are not already added and are in desired namespace + upstream_queue.extend( + [ + n + for n in self.get_callers(self.uses_edges, item) + if n in self.uses_edges and n not in upstream_new_nodes and namespace in n.namespace + ] + ) + # add defined nodes that are not already added and are in desired namespace + upstream_queue.extend( + [ + n + for n in self.get_callers(self.defines_edges, item) + if n in self.defines_edges and n not in upstream_new_nodes and namespace in n.namespace + ] + ) return downstream_new_nodes.union(upstream_new_nodes) diff --git a/pyan/main.py b/pyan/main.py index f6b2540..640cdce 100644 --- a/pyan/main.py +++ b/pyan/main.py @@ -45,6 +45,10 @@ def main(cli_args=None): parser.add_argument("--function", dest="function", help="filter for FUNCTION (generates call subtree)", metavar="FUNCTION", default=None) + parser.add_argument("--filterdown", dest="filterdown", help="filter downstream (FUNCTION will be root in call tree)", action="store", default=True) + + parser.add_argument("--filterup", dest="filterup", help="filter upstream (FUNCTION will be a leaf in call tree)", action="store", default=False) + parser.add_argument("-l", "--log", dest="logname", help="write log to LOG", metavar="LOG") parser.add_argument("-v", "--verbose", action="store_true", default=False, dest="verbose", help="verbose output") @@ -214,7 +218,9 @@ def main(cli_args=None): v = CallGraphVisitor(filenames, logger=logger, root=root) if known_args.function or known_args.namespace: - v.filter_data(function=known_args.function, namespace=known_args.namespace) + filter_down = known_args.filterdown in ["T", "t", "True", "true", True] + filter_up = known_args.filterup in ["T", "t", "True", "true", True] + v.filter_data(function=known_args.function, namespace=known_args.namespace, filter_down=filter_down, filter_up=filter_up) if not known_args.packages: v.remove_packages() diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index 450dbda..9c2f66d 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -79,8 +79,15 @@ def test_resolve_package_with_known_root(): get_node(defines, f"{dirname_base}.test_code.subpackage2.submodule_hidden1.test_func1") -def test_filter_function_parent(callgraph): - callgraph.filter_data(function="test_code.submodule2.test_2") +def test_filter_function_downward(callgraph): + callgraph.filter_data(function="test_code.subpackage1.submodule1.A.__init__", filter_down=True, filter_up=False) + # get parent of filtered function + uses = get_in_dict(callgraph.uses_edges, "test_code.submodule2.test_2") + get_node(uses, "test_code.submodule1.test_func1") + + +def test_filter_function_upward(callgraph): + callgraph.filter_data(function="test_code.submodule2.test_2", filter_down=False, filter_up=True) # get parent of filtered function uses = get_in_dict(callgraph.uses_edges, "test_code.subpackage1.submodule1.A.__init__") get_node(uses, "test_code.submodule2.test_2")