Skip to content

Commit

Permalink
fix node merge (#387)
Browse files Browse the repository at this point in the history
# Description

Now dictionaries are merged in a way that localized processing functions
are called first.
(node > local > global)

# Checklist

- [x] I have performed a self-review of the changes

*List here tasks to complete in order to mark this PR as ready for
review.*

# To Consider

- Add tests (if functionality is changed)
- Update API reference / tutorials / guides
- Update CONTRIBUTING.md (if devel workflow is changed)
- Update `.ignore` files, scripts (such as `lint`), distribution
manifest (if files are added/deleted)
- Search for references to changed entities in the codebase
  • Loading branch information
RLKRo authored Sep 7, 2024
1 parent 40981ef commit a7df04a
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 69 deletions.
52 changes: 35 additions & 17 deletions chatsky/core/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,28 @@ class Node(BaseModel, extra="forbid"):
Can be accessed at runtime via :py:attr:`~chatsky.core.context.Context.current_node`.
"""

def merge(self, other: Node):
def inherit_from_other(self, other: Node):
"""
Merge another node into this one:
Inherit properties from another node into this one:
- Prepend :py:attr:`transitions` of the other node;
- Replace response if ``other.response`` is not ``None``;
- Update :py:attr:`pre_transition`, :py:attr:`pre_response` and :py:attr:`misc` dictionaries.
- Extend ``self.transitions`` with :py:attr:`transitions` of the other node;
- Replace response with ``other.response`` if ``self.response`` is ``None``;
- Dictionaries (:py:attr:`pre_transition`, :py:attr:`pre_response` and :py:attr:`misc`)
are appended to this node's dictionaries except for the repeating keys.
For example, ``inherit_from_other({1: 1, 3: 3}, {1: 0, 2: 2}) == {1: 1, 3: 3, 2: 2}``.
Basically, only non-conflicting properties of ``other`` are inherited.
"""
self.transitions = [*other.transitions, *self.transitions]
if other.response is not None:

def merge_dicts(first: dict, second: dict):
first.update({k: v for k, v in second.items() if k not in first})

self.transitions.extend(other.transitions)
if self.response is None:
self.response = other.response
self.pre_transition.update(**other.pre_transition)
self.pre_response.update(**other.pre_response)
self.misc.update(**other.misc)
merge_dicts(self.pre_transition, other.pre_transition)
merge_dicts(self.pre_response, other.pre_response)
merge_dicts(self.misc, other.misc)
return self


Expand All @@ -81,7 +89,10 @@ class Flow(BaseModel, extra="allow"):
local_node: Node = Field(
validation_alias=AliasChoices("local", "LOCAL", "local_node", "LOCAL_NODE"), default_factory=Node
)
"""Node from which all other nodes in this Flow inherit properties according to :py:meth:`Node.merge`."""
"""
Node from which all other nodes in this Flow inherit properties
according to :py:meth:`Node.inherit_from_other`.
"""
__pydantic_extra__: Dict[str, Node]

@property
Expand Down Expand Up @@ -111,7 +122,10 @@ class Script(BaseModel, extra="allow"):
global_node: Node = Field(
validation_alias=AliasChoices("global", "GLOBAL", "global_node", "GLOBAL_NODE"), default_factory=Node
)
"""Node from which all other nodes in this Script inherit properties according to :py:meth:`Node.merge`."""
"""
Node from which all other nodes in this Script inherit properties
according to :py:meth:`Node.inherit_from_other`.
"""
__pydantic_extra__: Dict[str, Flow]

@property
Expand Down Expand Up @@ -144,14 +158,14 @@ def get_node(self, label: AbsoluteNodeLabel) -> Optional[Node]:

def get_inherited_node(self, label: AbsoluteNodeLabel) -> Optional[Node]:
"""
Return a new node that inherits (using :py:meth:`Node.merge`)
properties from :py:attr:`Script.global_node`, :py:attr:`Flow.local_node`
and :py:class`Node`.
Return a new node that inherits (using :py:meth:`Node.inherit_from_other`)
properties from :py:class:`Node`, :py:attr:`Flow.local_node`
and :py:attr:`Script.global_node` (in that order).
Flow and node are determined by ``label``.
This is essentially a copy of the node specified by ``label``,
that inherits properties from `global_node` and `local_node`.
that inherits properties from ``local_node`` and ``global_node``.
:return: A new node or ``None`` if it doesn't exist.
"""
Expand All @@ -164,7 +178,11 @@ def get_inherited_node(self, label: AbsoluteNodeLabel) -> Optional[Node]:

inheritant_node = Node()

return inheritant_node.merge(self.global_node).merge(flow.local_node).merge(node)
return (
inheritant_node.inherit_from_other(node)
.inherit_from_other(flow.local_node)
.inherit_from_other(self.global_node)
)


GLOBAL = "GLOBAL"
Expand Down
95 changes: 56 additions & 39 deletions tests/core/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,54 @@ async def call(self, ctx: Context) -> None:
return


@pytest.mark.parametrize(
"first,second,result",
[
(
Node(transitions=[Tr(dst="node1"), Tr(dst="node2")]),
Node(transitions=[Tr(dst="node3"), Tr(dst="node4")]),
Node(transitions=[Tr(dst="node3"), Tr(dst="node4"), Tr(dst="node1"), Tr(dst="node2")]),
),
(
Node(response="msg1"),
Node(response="msg2"),
Node(response="msg2"),
),
(
Node(response="msg1"),
Node(),
Node(response="msg1"),
),
(
Node(
pre_response={"key": MyProcessing(value="1")},
pre_transition={"key": MyProcessing(value="3")},
misc={"k1": "v1"},
class TestNodeMerge:
@pytest.mark.parametrize(
"first,second,result",
[
(
Node(transitions=[Tr(dst="node3"), Tr(dst="node4")]),
Node(transitions=[Tr(dst="node1"), Tr(dst="node2")]),
Node(transitions=[Tr(dst="node3"), Tr(dst="node4"), Tr(dst="node1"), Tr(dst="node2")]),
),
Node(pre_response={"key": MyProcessing(value="2")}, pre_transition={}, misc={"k2": "v2"}),
Node(
pre_response={"key": MyProcessing(value="2")},
pre_transition={"key": MyProcessing(value="3")},
misc={"k1": "v1", "k2": "v2"},
(
Node(response="msg2"),
Node(response="msg1"),
Node(response="msg2"),
),
),
],
)
def test_node_merge(first, second, result):
assert first.merge(second) == result
(
Node(),
Node(response="msg1"),
Node(response="msg1"),
),
(
Node(pre_response={"key": MyProcessing(value="2")}, pre_transition={}, misc={"k2": "v2"}),
Node(
pre_response={"key": MyProcessing(value="1")},
pre_transition={"key": MyProcessing(value="3")},
misc={"k1": "v1"},
),
Node(
pre_response={"key": MyProcessing(value="2")},
pre_transition={"key": MyProcessing(value="3")},
misc={"k1": "v1", "k2": "v2"},
),
),
],
)
def test_node_merge(self, first, second, result):
assert first.inherit_from_other(second) == result

def test_dict_key_order(self):
global_node_dict = {"1": MyProcessing(value="1"), "3": MyProcessing(value="3")}
global_node = Node(pre_response=global_node_dict, pre_transition=global_node_dict, misc=global_node_dict)
local_node_dict = {"1": MyProcessing(value="1*"), "2": MyProcessing(value="2")}
local_node = Node(pre_response=local_node_dict, pre_transition=local_node_dict, misc=local_node_dict)

result_node = local_node.model_copy().inherit_from_other(global_node)

assert list(result_node.pre_response.keys()) == ["1", "2", "3"]
assert list(result_node.pre_transition.keys()) == ["1", "2", "3"]
assert list(result_node.misc.keys()) == ["1", "2", "3"]


def test_flow_get_node():
Expand All @@ -71,14 +84,18 @@ def test_get_inherited_node():
global_node = Node(misc={"k1": "g1", "k2": "g2", "k3": "g3"})
local_node = Node(misc={"k2": "l1", "k3": "l2", "k4": "l3"})
node = Node(misc={"k3": "n1", "k4": "n2", "k5": "n3"})
global_node_copy = global_node.model_copy(deep=True)
local_node_copy = local_node.model_copy(deep=True)
node_copy = node.model_copy(deep=True)

script = Script.model_validate({"global": global_node, "flow": {"local": local_node, "node": node}})

assert script.get_inherited_node(AbsoluteNodeLabel(flow_name="", node_name="")) is None
assert script.get_inherited_node(AbsoluteNodeLabel(flow_name="flow", node_name="")) is None
assert script.get_inherited_node(AbsoluteNodeLabel(flow_name="flow", node_name="node")) == Node(
misc={"k1": "g1", "k2": "l1", "k3": "n1", "k4": "n2", "k5": "n3"}
)
inherited_node = script.get_inherited_node(AbsoluteNodeLabel(flow_name="flow", node_name="node"))
assert inherited_node == Node(misc={"k1": "g1", "k2": "l1", "k3": "n1", "k4": "n2", "k5": "n3"})
assert list(inherited_node.misc.keys()) == ["k3", "k4", "k5", "k2", "k1"]
# assert not changed
assert script.global_node == global_node
assert script.get_flow("flow").local_node == local_node
assert script.get_node(AbsoluteNodeLabel(flow_name="flow", node_name="node")) == node
assert script.global_node == global_node_copy
assert script.get_flow("flow").local_node == local_node_copy
assert script.get_node(AbsoluteNodeLabel(flow_name="flow", node_name="node")) == node_copy
30 changes: 25 additions & 5 deletions tutorials/script/core/7_pre_response_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,34 @@ async def modified_response(self, original_response, ctx):
}


# %% [markdown]
"""
The order of execution for processing functions is as follows:
1. All node-specific functions are executed in the order of definition;
2. All local functions are executed in the order of definition except those with
keys matching to previously executed functions;
3. All global functions are executed in the order of definition
except those with keys matching to previously executed functions.
That means that if both global and local nodes
define a processing function with key "processing_name",
only the one inside the local node will be executed.
This demonstrated in the happy path below
(the first prefix in the text is the last one to execute):
"""


# %%
# testing
happy_path = (
(Message(), "l3_local: l2_local: l1_global: first"),
(Message(), "l1_global: l3_local: l2_local: first"),
(Message(), "l3_local: l2_local: l1_step_1: second"),
(Message(), "l3_local: l2_step_2: l1_global: third"),
(Message(), "l3_step_3: l2_local: l1_global: fourth"),
(Message(), "l4_step_4: l3_local: l2_local: l1_global: fifth"),
(Message(), "l3_local: l2_local: l1_global: first"),
(Message(), "l1_global: l3_local: l2_step_2: third"),
(Message(), "l1_global: l2_local: l3_step_3: fourth"),
(Message(), "l1_global: l3_local: l2_local: l4_step_4: fifth"),
(Message(), "l1_global: l3_local: l2_local: first"),
)


Expand Down
16 changes: 8 additions & 8 deletions tutorials/script/core/8_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,30 +92,30 @@ async def call(self, ctx: Context) -> MessageInitTypes:
(
Message(),
"node_name=step_0: current_node.misc="
"{'var1': 'global_data', "
"{'var3': 'this overwrites local values - step_0', "
"'var2': 'global data is overwritten by local', "
"'var3': 'this overwrites local values - step_0'}",
"'var1': 'global_data'}",
),
(
Message(),
"node_name=step_1: current_node.misc="
"{'var1': 'global_data', "
"{'var3': 'this overwrites local values - step_1', "
"'var2': 'global data is overwritten by local', "
"'var3': 'this overwrites local values - step_1'}",
"'var1': 'global_data'}",
),
(
Message(),
"node_name=step_2: current_node.misc="
"{'var1': 'global_data', "
"{'var3': 'this overwrites local values - step_2', "
"'var2': 'global data is overwritten by local', "
"'var3': 'this overwrites local values - step_2'}",
"'var1': 'global_data'}",
),
(
Message(),
"node_name=step_0: current_node.misc="
"{'var1': 'global_data', "
"{'var3': 'this overwrites local values - step_0', "
"'var2': 'global data is overwritten by local', "
"'var3': 'this overwrites local values - step_0'}",
"'var1': 'global_data'}",
),
)

Expand Down

0 comments on commit a7df04a

Please sign in to comment.