Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_module_name infinite loop #95

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,10 @@ htmlcov
.idea/
.history/
.vscode/
# Eclipse
/.settings
.project
.pydevproject


/tmp/
199 changes: 169 additions & 30 deletions pyan/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def resolve_imports(self):
while len(imports_to_resolve) > 0:
from_node = imports_to_resolve.pop()
if from_node in import_mapping:
# already marked to map - skip
continue
to_uses = self.uses_edges.get(from_node, set([from_node]))
assert len(to_uses) == 1
Expand Down Expand Up @@ -218,6 +219,9 @@ def resolve_imports(self):
f"{from_node.namespace}.{from_node.name}" == node.namespace
and from_node.flavor == Flavor.IMPORTEDITEM
):
if to_node not in self.defines_edges:
self.logger.warning("unable to find define edges from %s", to_node)
continue
# use define edges as potential candidates
for candidate_to_node in self.defines_edges[to_node]: #
if candidate_to_node.name == node.name:
Expand All @@ -238,7 +242,20 @@ def resolve_imports(self):
if len(to_nodes) > 0
}

def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000):
def filter_data(self, function: Union[None, str] = None, namespace: Union[None, str] = None, max_iter: int = 1000,
                filter_down=True, filter_up=False):
    """
    filter the callgraph by a dotted function name and/or a namespace

    Convenience wrapper around `filter` that takes the function as a string
    instead of a node: the string is split at its last dot and looked up with
    `get_node`.

    Args:
        function: dotted name of the function to filter for; the part after the
            last dot is used as the node name, everything before it as the
            node namespace. If None or empty, no node is passed to `filter`
        namespace: forwarded to `filter` unchanged
        max_iter: forwarded to `filter` unchanged
        filter_down: forwarded to `filter` as `filter_down`
        filter_up: forwarded to `filter` as `filter_up`

    Returns:
        the value returned by `filter`
    """
    node = None
    if function:
        # "pkg.mod.func" -> ("pkg.mod", "func"); a name without any dot yields
        # an empty namespace, exactly like split(".") / join(".") would
        function_namespace, _, function_name = function.rpartition(".")
        node = self.get_node(function_namespace, function_name)

    # max_iter has to be handed over explicitly, otherwise the caller's limit
    # is silently replaced by the default of `filter`
    return self.filter(
        node=node, namespace=namespace, max_iter=max_iter, filter_down=filter_down, filter_up=filter_up
    )

def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000,
filter_down: bool = True, filter_up: bool = False):
"""
filter callgraph nodes that related to `node` or are in `namespace`

Expand All @@ -247,12 +264,15 @@ def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = N
namespace: namespace to search in (name of top level module),
if None, determines namespace from `node`
max_iter: maximum number of iterations and nodes to iterate
filter_down: filter nodes in downward
filter_up: filter nodes in upward

Returns:
self
"""
# filter the nodes to avoid cluttering the callgraph with irrelevant information
filtered_nodes = self.get_related_nodes(node, namespace=namespace, max_iter=max_iter)
filtered_nodes = self.get_related_nodes(node, namespace=namespace, max_iter=max_iter,
find_downward=filter_down, find_upward=filter_up)

self.nodes = {name: [node for node in nodes if node in filtered_nodes] for name, nodes in self.nodes.items()}
self.uses_edges = {
Expand All @@ -268,7 +288,8 @@ def filter(self, node: Union[None, Node] = None, namespace: Union[str, None] = N
return self

def get_related_nodes(
self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000
self, node: Union[None, Node] = None, namespace: Union[str, None] = None, max_iter: int = 1000,
find_downward: bool = True, find_upward: bool = False
) -> set:
"""
get nodes that related to `node` or are in `namespace`
Expand All @@ -278,6 +299,8 @@ def get_related_nodes(
namespace: namespace to search in (name of top level module),
if None, determines namespace from `node`
max_iter: maximum number of iterations and nodes to iterate
find_downward: look for nodes in downward
find_upward: look for nodes in upward

Returns:
set: set of nodes related to `node` including `node` itself
Expand All @@ -301,33 +324,69 @@ def get_related_nodes(
namespace = node.namespace.strip(".").split(".", 1)[0]
queue = [node]

# use queue system to search through nodes
# essentially add a node to the queue and then search all connected nodes which are in turn added to the queue
# until the queue itself is empty or the maximum limit of max_iter searches have been hit
i = max_iter
while len(queue) > 0:
item = queue.pop()
if item not in new_nodes:
new_nodes.add(item)
i -= 1
if i < 0:
break
queue.extend(
[
n
for n in self.uses_edges.get(item, [])
if n in self.uses_edges and n not in new_nodes and namespace in n.namespace
]
)
queue.extend(
[
n
for n in self.defines_edges.get(item, [])
if n in self.defines_edges and n not in new_nodes and namespace in n.namespace
]
)
downstream_new_nodes = set()
if find_downward:
# use queue system to search through nodes
# essentially add a node to the queue and then search all connected nodes which are in turn added to the queue
# until the queue itself is empty or the maximum limit of max_iter searches have been hit
downstream_new_nodes = new_nodes.copy()
downstream_queue = queue.copy()
i = max_iter
while len(downstream_queue) > 0:
item = downstream_queue.pop()
if item not in downstream_new_nodes:
downstream_new_nodes.add(item)
i -= 1
if i < 0:
break
# add used nodes that are not already added and are in desired namespace
downstream_queue.extend(
[
n
for n in self.uses_edges.get(item, [])
if n in self.uses_edges and n not in downstream_new_nodes and namespace in n.namespace
]
)
# add defined nodes that are not already added and are in desired namespace
downstream_queue.extend(
[
n
for n in self.defines_edges.get(item, [])
if n in self.defines_edges and n not in downstream_new_nodes and namespace in n.namespace
]
)

upstream_new_nodes = set()
if find_upward:
# get callers of node
upstream_new_nodes = new_nodes.copy()
upstream_queue = queue.copy()
i = max_iter
while len(upstream_queue) > 0:
item = upstream_queue.pop()
if item not in upstream_new_nodes:
upstream_new_nodes.add(item)
i -= 1
if i < 0:
break
# add used nodes that are not already added and are in desired namespace
upstream_queue.extend(
[
n
for n in self.get_callers(self.uses_edges, item)
if n in self.uses_edges and n not in upstream_new_nodes and namespace in n.namespace
]
)
# add defined nodes that are not already added and are in desired namespace
upstream_queue.extend(
[
n
for n in self.get_callers(self.defines_edges, item)
if n in self.defines_edges and n not in upstream_new_nodes and namespace in n.namespace
]
)

return new_nodes
return downstream_new_nodes.union(upstream_new_nodes)

def visit_Module(self, node):
self.logger.debug("Module %s, %s" % (self.module_name, self.filename))
Expand Down Expand Up @@ -611,6 +670,14 @@ def analyze_module_import(self, import_item, ast_node):
alias_name = mod_node.name
self.add_uses_edge(from_node, mod_node)
self.logger.info("New edge added for Use import %s in %s" % (mod_node, from_node))

curr_scope = self.scope_stack[-1]
if alias_name not in curr_scope.defs:
# seems that symtable module does not handle following "import aaa.bbb" as expected in pyan
# it returns "aaa", but the analyzer expects full symbol "aaa.bbb"
# workaround is to add missing empty symbol
curr_scope.defs[alias_name] = None

self.set_value(alias_name, mod_node) # set node to be discoverable in module
self.logger.info("From setting name %s to %s" % (alias_name, mod_node))

Expand Down Expand Up @@ -701,7 +768,7 @@ def visit_Attribute(self, node):
)
if self.add_uses_edge(from_node, to_node):
self.logger.info(
"New edge added for Use from {from_node} to {to_node} (target obj {obj_node} known but "
f"New edge added for Use from {from_node} to {to_node} (target obj {obj_node} known but "
f"target attr {node.attr} not resolved; maybe fwd ref or unanalyzed import)"
)

Expand All @@ -718,6 +785,19 @@ def visit_Attribute(self, node):
def visit_Name(self, node):
self.logger.debug("Name %s in context %s, %s:%s" % (node.id, type(node.ctx), self.filename, node.lineno))

in_class_ns = self.context_stack[-1].startswith("ClassDef")
if in_class_ns:
# add enum value or class static field
#TODO: there should be additional Flavor and proper handling of static fields
tgt_name = node.id
val_node = self.get_value(tgt_name) # resolves "self" if needed
if val_node is None and tgt_name not in ["staticmethod", "classmethod"]:
# static field case
from_node = self.get_node_of_current_namespace()
ns = from_node.get_name()
to_node = self.get_node(ns, tgt_name, node, flavor=Flavor.ATTRIBUTE)
self.add_defines_edge(from_node, to_node)

# TODO: self.last_value is a hack. Handle names in store context (LHS)
# in analyze_binding(), so that visit_Name() only needs to handle
# the load context (i.e. detect uses of the name).
Expand Down Expand Up @@ -1189,6 +1269,12 @@ def resolve_attribute(self, ast_node):
self.logger.debug("Resolved to attr %s of %s" % (ast_node.attr, sc.defs[attr_name]))
return sc.defs[attr_name], ast_node.attr

attr_name = get_ast_node_name(ast_node.value)
obj_node = self.get_value(attr_name) # resolves chained calls and imports ("aaa.bbb") if needed
if obj_node is not None:
self.logger.debug("Resolved to attr %s of %s" % (attr_name, obj_node))
return obj_node, attr_name

# It may happen that ast_node.value has no corresponding graph Node,
# if this is a forward-reference, or a reference to a file
# not in the analyzed set.
Expand Down Expand Up @@ -1536,6 +1622,7 @@ def add_defines_edge(self, from_node, to_node):
status = True
from_node.defined = True
if to_node is None or to_node in self.defines_edges[from_node]:
# edge already defined
return status
self.defines_edges[from_node].add(to_node)
to_node.defined = True
Expand Down Expand Up @@ -1656,6 +1743,49 @@ def remove_wild(self, from_node, to_node, name):
self.logger.info("Use from %s to %s resolves %s; removing wildcard" % (from_node, to_node, wild_node))
self.remove_uses_edge(from_node, wild_node)

def remove_node(self, node):
    """
    remove `node` from the uses edges and from the node registry

    The node is dropped as a source of uses edges, as a target of uses edges
    and from every entry of `self.nodes`. An entry whose collection becomes
    empty because of the removal is deleted as well.

    Args:
        node: node to remove
    """
    # NOTE(review): `self.defines_edges` is left untouched here - confirm that
    # consumers of the graph cope with a removed node still being listed there.

    def prune(mapping):
        # drop `node` from every collection stored in `mapping` and forget the
        # keys whose collection got emptied by that
        for key in list(mapping):  # snapshot, entries are deleted on the way
            members = mapping[key]
            if node in members:
                members.remove(node)
                if not members:
                    del mapping[key]

    # edges starting at the node
    self.uses_edges.pop(node, None)
    # edges pointing to the node
    prune(self.uses_edges)
    # the node itself
    prune(self.nodes)

def remove_packages(self):
    """
    remove module nodes and, afterwards, class nodes that nothing uses

    Works in two passes over `self.nodes`: the first one removes every node
    of flavor MODULE, the second one removes every node of flavor CLASS that
    has no caller left in `self.uses_edges`.
    """
    # first pass: modules
    for name in list(self.nodes):
        bucket = self.nodes.get(name)
        if bucket is None:
            # entry vanished while removing an earlier node
            continue
        for candidate in list(bucket):  # iterate a copy, remove_node mutates the bucket
            if candidate.flavor is Flavor.MODULE:
                self.remove_node(candidate)

    # second pass: parent class nodes that are not used
    for name in list(self.nodes):
        bucket = self.nodes.get(name)
        if bucket is None:
            continue
        for candidate in list(bucket):
            if candidate.flavor is Flavor.CLASS and not self.get_callers(self.uses_edges, candidate):
                self.remove_node(candidate)

###########################################################################
# Postprocessing

Expand Down Expand Up @@ -1762,3 +1892,12 @@ def collapse_inner(self):
self.logger.info("Collapsing inner from %s to %s, uses %s" % (n, pn, n2))
self.add_uses_edge(pn, n2)
n.defined = False


def get_callers(self, edges_dict, node):
    """
    return list of callers (keys) of given node in edges_dict

    Args:
        edges_dict: mapping of a caller to the collection of its callees
        node: callee to look for

    Returns:
        list: every key of `edges_dict` whose collection contains `node`,
            in the iteration order of `edges_dict`; empty if there is none
    """
    return [caller for caller, callees in edges_dict.items() if node in callees]
6 changes: 6 additions & 0 deletions pyan/anutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def get_module_name(filename, root: str = None):
# otherwise it is the filename without extension
module_path = filename.replace(".py", "")

# will enter infinite loop or exception in case of non-absolute path (starting with ".")
module_path = os.path.abspath(module_path)

# find the module root - walk up the tree and check if it contains .py files - if yes. it is the new root
directories = [(module_path, True)]
if root is None:
Expand All @@ -47,6 +50,9 @@ def get_module_name(filename, root: str = None):
else: # root is already known - just walk up until it is matched
while directories[0][0] != root:
potential_root = os.path.dirname(directories[0][0])
if potential_root == directories[0][0]:
# root directory reached - stop iteration
break
directories.insert(0, (potential_root, True))

mod_name = ".".join([os.path.basename(f[0]) for f in directories])
Expand Down
28 changes: 18 additions & 10 deletions pyan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def main(cli_args=None):

parser.add_argument("--namespace", dest="namespace", help="filter for NAMESPACE", metavar="NAMESPACE", default=None)

parser.add_argument("--function", dest="function", help="filter for FUNCTION", metavar="FUNCTION", default=None)
parser.add_argument("--function", dest="function", help="filter for FUNCTION (generates call subtree)", metavar="FUNCTION", default=None)

parser.add_argument("--filterdown", dest="filterdown", help="filter downstream (FUNCTION will be root in call tree)", action="store", default=True)

parser.add_argument("--filterup", dest="filterup", help="filter upstream (FUNCTION will be a leaf in call tree)", action="store", default=False)

parser.add_argument("-l", "--log", dest="logname", help="write log to LOG", metavar="LOG")

Expand Down Expand Up @@ -93,6 +97,14 @@ def main(cli_args=None):
help="do not add edges for 'uses' relationships",
)

parser.add_argument(
"--no-packages",
action="store_false",
default=True,
dest="packages",
help="do not add packages and import relationships",
)

parser.add_argument(
"-c",
"--colored",
Expand Down Expand Up @@ -206,16 +218,12 @@ def main(cli_args=None):
v = CallGraphVisitor(filenames, logger=logger, root=root)

if known_args.function or known_args.namespace:
filter_down = known_args.filterdown in ["T", "t", "True", "true", True]
filter_up = known_args.filterup in ["T", "t", "True", "true", True]
v.filter_data(function=known_args.function, namespace=known_args.namespace, filter_down=filter_down, filter_up=filter_up)

if known_args.function:
function_name = known_args.function.split(".")[-1]
namespace = ".".join(known_args.function.split(".")[:-1])
node = v.get_node(namespace, function_name)

else:
node = None

v.filter(node=node, namespace=known_args.namespace)
if not known_args.packages:
v.remove_packages()

graph = VisualGraph.from_visitor(v, options=graph_options, logger=logger)

Expand Down
Loading