From 4341a4dd19c6f0bf18eb6b8eb57870259fc671ec Mon Sep 17 00:00:00 2001 From: Andreas Zeller Date: Sun, 5 Jan 2025 20:28:50 +0100 Subject: [PATCH] New: print out trees in compact form --- notebooks/Alhazen.ipynb | 55 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/notebooks/Alhazen.ipynb b/notebooks/Alhazen.ipynb index 532dad40..d7b7188c 100644 --- a/notebooks/Alhazen.ipynb +++ b/notebooks/Alhazen.ipynb @@ -715,7 +715,8 @@ " e.g., 'num()'\n", " rule : The production rule.\n", " '''\n", - " def __init__(self, name: str, rule: str, friendly_name: str = None) -> None:\n", + " def __init__(self, name: str, rule: str, /, \n", + " friendly_name: str = None) -> None:\n", " super().__init__(name, rule, rule, friendly_name=friendly_name)\n", "\n", " def name_rep(self) -> str:\n", @@ -2213,7 +2214,7 @@ "source": [ "class InputSpecification:\n", " '''\n", - " This class represents a complet input specification of a new input. A input specification\n", + " This class represents a complete input specification of a new input. 
An input specification\n", " consists of one or more requirements.\n", " requirements : Is a list of all requirements that must be used.\n", " '''\n", @@ -2669,7 +2670,8 @@ "\n", " self._all_features = extract_existence(self._grammar) + extract_numeric(self._grammar)\n", " self._feature_names = [f.name for f in self._all_features]\n", - " print(f\"Features: {self._feature_names}\")\n", + " print(\"Features:\", \", \".join(f.friendly_name() \n", + " for f in self._all_features))\n", "\n", " def _add_new_data(self, exec_data, feature_data):\n", " joined_data = exec_data.join(feature_data.drop(['sample'], axis=1))\n", @@ -2821,6 +2823,53 @@ "show_decision_tree(remove_unequal_decisions(final_tree), all_feature_names)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "def friendly_decision_tree(clf, feature_names, class_names = ['NO_BUG', 'BUG']):\n", + " def _tree(index, indent=0):\n", + " s = \"\"\n", + " feature = clf.tree_.feature[index]\n", + " feature_name = feature_names[feature]\n", + " threshold = clf.tree_.threshold[index]\n", + " value = clf.tree_.value[index]\n", + " class_ = int(value[0][0])\n", + " class_name = class_names[class_]\n", + " left = clf.tree_.children_left[index]\n", + " right = clf.tree_.children_right[index]\n", + " if left == right:\n", + " # Leaf node\n", + " s += \" \" * indent + class_name + \"\\n\"\n", + " else:\n", + " if math.isclose(threshold, 0.5):\n", + " s += \" \" * indent + f\"if {feature_name}:\\n\"\n", + " s += _tree(right, indent + 2)\n", + " s += \" \" * indent + f\"else:\\n\"\n", + " s += _tree(left, indent + 2)\n", + " else:\n", + " s += \" \" * indent + f\"if {feature_name} <= {threshold:.4f}:\\n\"\n", + " s += _tree(left, indent + 2)\n", + " s += \" \" * indent + f\"else:\\n\"\n", + " s += _tree(right, indent + 2)\n", + " return s\n", + "\n", + " return _tree(0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "print(friendly_decision_tree(final_tree, all_feature_names))" + ] + }, { "cell_type": "markdown", "metadata": {},