Skip to content

Commit

Permalink
New: print out trees in compact form
Browse files Browse the repository at this point in the history
  • Loading branch information
andreas-zeller committed Jan 5, 2025
1 parent 57e23a4 commit 4341a4d
Showing 1 changed file with 52 additions and 3 deletions.
55 changes: 52 additions & 3 deletions notebooks/Alhazen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,8 @@
" e.g., 'num(<integer>)'\n",
" rule : The production rule.\n",
" '''\n",
" def __init__(self, name: str, rule: str, friendly_name: str = None) -> None:\n",
" def __init__(self, name: str, rule: str, /, \n",
" friendly_name: str = None) -> None:\n",
" super().__init__(name, rule, rule, friendly_name=friendly_name)\n",
"\n",
" def name_rep(self) -> str:\n",
Expand Down Expand Up @@ -2213,7 +2214,7 @@
"source": [
"class InputSpecification:\n",
" '''\n",
" This class represents a complet input specification of a new input. A input specification\n",
" This class represents a complete input specification of a new input. A input specification\n",
" consists of one or more requirements.\n",
" requirements : Is a list of all requirements that must be used.\n",
" '''\n",
Expand Down Expand Up @@ -2669,7 +2670,8 @@
"\n",
" self._all_features = extract_existence(self._grammar) + extract_numeric(self._grammar)\n",
" self._feature_names = [f.name for f in self._all_features]\n",
" print(f\"Features: {self._feature_names}\")\n",
" print(\"Features:\", \", \".join(f.friendly_name() \n",
" for f in self._all_features))\n",
"\n",
" def _add_new_data(self, exec_data, feature_data):\n",
" joined_data = exec_data.join(feature_data.drop(['sample'], axis=1))\n",
Expand Down Expand Up @@ -2821,6 +2823,53 @@
"show_decision_tree(remove_unequal_decisions(final_tree), all_feature_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def friendly_decision_tree(clf, feature_names, class_names = ['NO_BUG', 'BUG']):\n",
" def _tree(index, indent=0):\n",
" s = \"\"\n",
" feature = clf.tree_.feature[index]\n",
" feature_name = feature_names[feature]\n",
" threshold = clf.tree_.threshold[index]\n",
" value = clf.tree_.value[index]\n",
" class_ = int(value[0][0])\n",
" class_name = class_names[class_]\n",
" left = clf.tree_.children_left[index]\n",
" right = clf.tree_.children_right[index]\n",
" if left == right:\n",
" # Leaf node\n",
" s += \" \" * indent + class_name + \"\\n\"\n",
" else:\n",
" if math.isclose(threshold, 0.5):\n",
" s += \" \" * indent + f\"if {feature_name}:\\n\"\n",
" s += _tree(right, indent + 2)\n",
" s += \" \" * indent + f\"else:\\n\"\n",
" s += _tree(left, indent + 2)\n",
" else:\n",
" s += \" \" * indent + f\"if {feature_name} <= {threshold:.4f}:\\n\"\n",
" s += _tree(left, indent + 2)\n",
" s += \" \" * indent + f\"else:\\n\"\n",
" s += _tree(right, indent + 2)\n",
" return s\n",
"\n",
" return _tree(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(friendly_decision_tree(final_tree, all_feature_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit 4341a4d

Please sign in to comment.