From 6b861d6c642808145819b517ea1d359cdad195a2 Mon Sep 17 00:00:00 2001 From: Eran Hirsch Date: Mon, 16 Sep 2024 06:16:27 +0300 Subject: [PATCH] Fix LATS select to choose the best child at each level until reaching leaf child (#1723) --- docs/docs/tutorials/lats/lats.ipynb | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/docs/docs/tutorials/lats/lats.ipynb b/docs/docs/tutorials/lats/lats.ipynb index f32c8ccf5..2595f3d95 100644 --- a/docs/docs/tutorials/lats/lats.ipynb +++ b/docs/docs/tutorials/lats/lats.ipynb @@ -152,14 +152,6 @@ " return not self.children\n", "\n", " @property\n", - " def best_child(self):\n", - " \"\"\"Select the child with the highest UCT to search next.\"\"\"\n", - " if not self.children:\n", - " return None\n", - " all_nodes = self._get_all_children()\n", - " return max(all_nodes, key=lambda child: child.upper_confidence_bound())\n", - "\n", - " @property\n", " def best_child_score(self):\n", " \"\"\"Return the child with the highest value.\"\"\"\n", " if not self.children:\n", @@ -589,11 +581,23 @@ "source": [ "from collections import defaultdict\n", "\n", + "def select(root: Node) -> dict:\n", + " \"\"\"Starting from the root node a child node is selected at each tree level until a leaf node is reached.\"\"\"\n", + "\n", + " if not root.children:\n", + " return root\n", + " \n", + " node = root\n", + " while node.children:\n", + " max_child = max(node.children, key=lambda child: child.upper_confidence_bound())\n", + " node = max_child\n", + "\n", + " return node\n", "\n", "def expand(state: TreeState, config: RunnableConfig) -> dict:\n", " \"\"\"Starting from the \"best\" node in the tree, generate N candidates for the next step.\"\"\"\n", " root = state[\"root\"]\n", - " best_candidate: Node = root.best_child if root.children else root\n", + " best_candidate: Node = select(root)\n", " messages = best_candidate.get_trajectory()\n", " # Generate N candidates from the single child candidate\n", " new_candidates = expansion_chain.invoke(\n",