Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Root Node to Graph #192

Merged
merged 20 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions thicket/tests/test_add_root_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2022 Lawrence Livermore National Security, LLC and other
# Thicket Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

from hatchet.node import Node


def test_add_root_node(literal_thickets):
tk, _, _ = literal_thickets

assert len(tk.graph) == 4

# Call add_root_node
tk.add_root_node({"name": "Test", "type": "function"})
# Get node variable
test_node = tk.get_node("Test")

# Check if node was inserted in all components
assert isinstance(test_node, Node)
assert test_node._hatchet_nid == 3
assert test_node._depth == 0
assert len(tk.graph) == 5
assert len(tk.statsframe.graph) == 5
assert test_node in tk.dataframe.index.get_level_values("node")
assert test_node in tk.statsframe.dataframe.index.get_level_values("node")

assert tk.dataframe.loc[test_node, "name"].values[0] == "Test"
assert tk.statsframe.dataframe.loc[test_node, "name"] == "Test"
20 changes: 20 additions & 0 deletions thicket/tests/test_get_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2022 Lawrence Livermore National Security, LLC and other
# Thicket Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

import pytest


def test_get_node(literal_thickets):
tk, _, _ = literal_thickets

with pytest.raises(KeyError):
tk.get_node("Foo")

baz = tk.get_node("Baz")

# Check node properties
assert baz.frame["name"] == "Baz"
assert baz.frame["type"] == "function"
assert baz._hatchet_nid == 0
53 changes: 53 additions & 0 deletions thicket/thicket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import pandas as pd
import numpy as np
from hatchet import GraphFrame
from hatchet.frame import Frame
from hatchet.graph import Graph
from hatchet.node import Node
from hatchet.query import QueryEngine
from thicket.query import (
Query,
Expand Down Expand Up @@ -1514,6 +1516,57 @@ def get_unique_metadata(self):

return sorted_meta

def add_root_node(self, attrs):
"""Add node at root level with given attributes.

Arguments:
attrs (dict): attributes for the new node which will be used to initilize the
node.frame.
"""

new_node = Node(frame_obj=Frame(attrs=attrs))

# graph and statsframe.graph
self.graph.roots.append(new_node)

# Set hatchet nid and depth
self.graph.enumerate_traverse()

# dataframe
idx_levels = self.dataframe.index.names
new_idx = [[new_node]] + [self.profile]
new_node_df = pd.DataFrame(
index=pd.MultiIndex.from_product(new_idx, names=idx_levels)
)
new_node_df["name"] = attrs["name"]
self.dataframe = pd.concat([self.dataframe, new_node_df])

# statsframe.dataframe
self.statsframe.dataframe = helpers._new_statsframe_df(self.dataframe)
# Reapply stats operations after clearing statsframe dataframe
self.reapply_stats_operations()

# Check Thicket state
validate_nodes(self)

def get_node(self, name):
"""Get a node object in the Thicket by its Node.frame['name']. If more than one
node has the same name, a list of nodes is returned.

Arguments:
name (str): name of the node (Node.frame['name']).

Returns:
(Node or list(Node)): Node object with the given name or list of Node objects
with the given name.
"""
node = [n for n in self.graph.traverse() if n.frame["name"] == name]

if len(node) == 0:
raise KeyError(f'Node with name "{name}" not found.')

return node[0] if len(node) == 1 else node

def _sync_profile_components(self, component):
"""Synchronize the Performance DataFrame, Metadata Dataframe, profile and
profile mapping objects based on the component's index or a list of profiles.
Expand Down
Loading