Skip to content

Commit

Permalink
Fix LATS select to choose the best child at each level until reaching…
Browse files Browse the repository at this point in the history
… leaf child (#1723)
  • Loading branch information
eranhirs authored Sep 16, 2024
1 parent 55d79b1 commit 6b861d6
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions docs/docs/tutorials/lats/lats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,6 @@
" return not self.children\n",
"\n",
" @property\n",
" def best_child(self):\n",
" \"\"\"Select the child with the highest UCT to search next.\"\"\"\n",
" if not self.children:\n",
" return None\n",
" all_nodes = self._get_all_children()\n",
" return max(all_nodes, key=lambda child: child.upper_confidence_bound())\n",
"\n",
" @property\n",
" def best_child_score(self):\n",
" \"\"\"Return the child with the highest value.\"\"\"\n",
" if not self.children:\n",
Expand Down Expand Up @@ -589,11 +581,23 @@
"source": [
"from collections import defaultdict\n",
"\n",
"def select(root: Node) -> dict:\n",
" \"\"\"Starting from the root node a child node is selected at each tree level until a leaf node is reached.\"\"\"\n",
"\n",
" if not root.children:\n",
" return root\n",
" \n",
" node = root\n",
" while node.children:\n",
" max_child = max(node.children, key=lambda child: child.upper_confidence_bound())\n",
" node = max_child\n",
"\n",
" return node\n",
"\n",
"def expand(state: TreeState, config: RunnableConfig) -> dict:\n",
" \"\"\"Starting from the \"best\" node in the tree, generate N candidates for the next step.\"\"\"\n",
" root = state[\"root\"]\n",
" best_candidate: Node = root.best_child if root.children else root\n",
" best_candidate: Node = select(root)\n",
" messages = best_candidate.get_trajectory()\n",
" # Generate N candidates from the single child candidate\n",
" new_candidates = expansion_chain.invoke(\n",
Expand Down

0 comments on commit 6b861d6

Please sign in to comment.