Skip to content

Commit

Permalink
Everything is running :-)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreas-zeller committed Jan 5, 2025
1 parent 3164f33 commit c2e3d37
Showing 1 changed file with 75 additions and 48 deletions.
123 changes: 75 additions & 48 deletions notebooks/Alhazen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@
"outputs": [],
"source": [
"# Load initial input files\n",
"sample_list = ['sqrt(-16)', 'sqrt(4)']"
"initial_sample_list = ['sqrt(-16)', 'sqrt(4)']"
]
},
{
Expand Down Expand Up @@ -604,7 +604,11 @@
" @abstractmethod\n",
" def get_feature_value(self, derivation_tree) -> float:\n",
" '''Returns the feature value for a given derivation tree of an input.'''\n",
" pass"
" pass\n",
"\n",
" def replace(self, new_key: str) -> 'Feature':\n",
" '''Returns a new feature with the same name but a different key.'''\n",
" return self.__class__(self.name, self.rule, new_key)"
]
},
{
Expand Down Expand Up @@ -1113,7 +1117,6 @@
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.feature_extraction import DictVectorizer\n",
"from sklearn import tree\n",
"\n",
"import graphviz"
]
Expand Down Expand Up @@ -1220,11 +1223,15 @@
"metadata": {},
"outputs": [],
"source": [
"dot_data = tree.export_graphviz(clf, out_file=None,\n",
" feature_names=vec.get_feature_names_out(),\n",
" class_names=[\"BUG\", \"NO BUG\"],\n",
" filled=True, rounded=True)\n",
"graph = graphviz.Source(dot_data)"
"import graphviz\n",
"import sklearn\n",
"\n",
"def show_decision_tree(clf, feature_names):\n",
" dot_data = sklearn.tree.export_graphviz(clf, out_file=None, \n",
" feature_names=feature_names,\n",
" class_names=[\"BUG\", \"NO_BUG\"], \n",
" filled=True, rounded=True) \n",
" return graphviz.Source(dot_data)"
]
},
{
Expand All @@ -1233,7 +1240,7 @@
"metadata": {},
"outputs": [],
"source": [
"display(graph)"
"show_decision_tree(clf, vec.get_feature_names_out())"
]
},
{
Expand Down Expand Up @@ -1309,6 +1316,7 @@
"source": [
"def train_tree(data):\n",
" sample_bug_count = len(data[(data[\"oracle\"].astype(str) == \"BUG\")])\n",
" assert sample_bug_count > 0, \"No bug samples found\"\n",
" sample_count = len(data)\n",
"\n",
" clf = DecisionTreeClassifier(min_samples_leaf=1,\n",
Expand Down Expand Up @@ -1392,7 +1400,7 @@
"clf = clf.fit(X_data, oracle)\n",
"\n",
"import graphviz\n",
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
"dot_data = sklearn.tree.export_graphviz(clf, out_file=None, \n",
" feature_names=feature_names,\n",
" class_names=[\"BUG\", \"NO BUG\"], \n",
" filled=True, rounded=True) \n",
Expand Down Expand Up @@ -1697,13 +1705,19 @@
" mini = row['min']\n",
" maxi = row['max']\n",
" if (not np.isinf(mini)) or (not np.isinf(maxi)):\n",
" requirements.append(Requirement(feature, mini, maxi))\n",
" requirements.append(TreeRequirement(feature, mini, maxi))\n",
" paths.append(TreePath(None, is_bug, requirements))\n",
"\n",
" return paths\n",
"\n",
"\n",
"class Requirement:\n",
" return paths\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TreeRequirement:\n",
"\n",
" def __init__(self, feature: Feature, mini, maxi):\n",
" self.__feature: Feature = feature\n",
Expand Down Expand Up @@ -1788,10 +1802,10 @@
"\n",
"class TreePath:\n",
"\n",
" def __init__(self, samplefile: Optional[Path], is_bug: bool, requirements: List[Requirement]):\n",
" def __init__(self, samplefile: Optional[Path], is_bug: bool, requirements: List[TreeRequirement]):\n",
" self.__sample = samplefile\n",
" self.__is_bug = is_bug\n",
" self.__requirements: List[Requirement] = requirements\n",
" self.__requirements: List[TreeRequirement] = requirements\n",
"\n",
" def is_bug(self) -> bool:\n",
" return self.__is_bug\n",
Expand Down Expand Up @@ -2152,7 +2166,7 @@
"from typing import List\n",
"from fuzzingbook.GrammarFuzzer import DerivationTree\n",
"\n",
"class Requirement:\n",
"class SpecRequirement:\n",
" '''\n",
" This class represents a requirement for a new input sample that should be generated.\n",
"    This class contains the feature that should be fulfilled (Feature), a quantifier\n",
Expand All @@ -2171,18 +2185,24 @@
" self.value = value\n",
"\n",
" def __str__(self):\n",
" return f\"Requirement({self.feature.name} {self.quant} {self.value})\"\n",
"\n",
"\n",
" return f\"Requirement({self.feature.name} {self.quant} {self.value})\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class InputSpecification:\n",
" '''\n",
"    This class represents a complete input specification of a new input. An input specification\n",
" consists of one or more requirements.\n",
" requirements : Is a list of all requirements that must be used.\n",
" '''\n",
"\n",
" def __init__(self, requirements: List[Requirement]):\n",
" self.requirements: List[Reqirement] = requirements\n",
" def __init__(self, requirements: List[SpecRequirement]):\n",
" self.requirements: List[SpecRequirement] = requirements\n",
"\n",
" def __str__(self):\n",
" # Handle first element\n",
Expand Down Expand Up @@ -2227,7 +2247,7 @@
" if f.name == feature_name:\n",
" feature_class = f\n",
"\n",
" requirement_list.append(Requirement(feature_class, quant, value))\n",
" requirement_list.append(SpecRequirement(feature_class, quant, value))\n",
"\n",
" return InputSpecification(requirement_list)\n",
"\n",
Expand Down Expand Up @@ -2489,6 +2509,15 @@
" return final_samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generate_samples = generate_samples_advanced"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -2524,16 +2553,16 @@
"exsqrt = ExistenceFeature('exists(<function>@0)', '<function>', 'sqrt')\n",
"exdigit = ExistenceFeature('exists(<digit>)', '<digit>', '<digit>')\n",
"\n",
"reqDigit = Requirement(exdigit, '>', '0.5')\n",
"fbdDigit = Requirement(exdigit, '<=', '0.5')\n",
"reqDigit = SpecRequirement(exdigit, '>', '0.5')\n",
"fbdDigit = SpecRequirement(exdigit, '<=', '0.5')\n",
"\n",
"req0 = Requirement(exsqrt, '>', '-6.0')\n",
"req0 = SpecRequirement(exsqrt, '>', '-6.0')\n",
"testspec0 = InputSpecification([req0, reqDigit])\n",
"req1 = Requirement(exsqrt, '<=', '-6.0')\n",
"req1 = SpecRequirement(exsqrt, '<=', '-6.0')\n",
"testspec1 = InputSpecification([req1, fbdDigit])\n",
"\n",
"numterm = NumericInterpretation('num(<term>)', '<term>')\n",
"req2 = Requirement(numterm, '<', '-31.0')\n",
"req2 = SpecRequirement(numterm, '<', '-31.0')\n",
"testspec2 = InputSpecification([req2, req0, reqDigit])\n",
"\n",
"print('--generating samples--')\n",
Expand Down Expand Up @@ -2663,7 +2692,8 @@
"\n",
"# let's initialize Alhazen\n",
"# let's use the previously defined initial_sample_list (['sqrt(-16)', 'sqrt(4)'])\n",
"alhazen = Alhazen(sample_list, CALC_GRAMMAR, MAX_ITERATION, GENERATOR_TIMEOUT)\n",
"alhazen = Alhazen(initial_sample_list,\n",
" CALC_GRAMMAR, MAX_ITERATION, GENERATOR_TIMEOUT)\n",
"\n",
"# and run it\n",
"# Alhazen returns a list of all the iteratively learned decision trees\n",
Expand All @@ -2674,10 +2704,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"</hr>\n",
"\n",
"Let's display the final decision tree learned by Alhazen. You can use the function `show_tree(decison_tree, features)` to display the final tree."
"Let's display the final decision tree learned by Alhazen. You can use the function `show_decision_tree(decision_tree, feature_names)` to display the final tree."
]
},
{
Expand All @@ -2686,12 +2713,8 @@
"metadata": {},
"outputs": [],
"source": [
"def show_tree(clf, feature_names):\n",
" dot_data = tree.export_graphviz(clf, out_file=None, \n",
" feature_names= feature_names,\n",
" class_names=[\"BUG\", \"NO_BUG\"], \n",
" filled=True, rounded=True) \n",
" return graphviz.Source(dot_data)"
"final_tree = trees[MAX_ITERATION-1]\n",
"final_tree"
]
},
{
Expand All @@ -2701,14 +2724,23 @@
"outputs": [],
"source": [
"all_features = extract_existence(CALC_GRAMMAR) + extract_numeric(CALC_GRAMMAR)\n",
"# show_tree(trees[MAX_ITERATION-1], all_features)"
"all_feature_names = [f.name for f in all_features]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_decision_tree(final_tree, all_feature_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Info:** The decision tree may contain unnecessary long paths, where the bug-class does not change. You can use the function 'remove_unequal_decisions(decision_tree)' to remove those nodes."
"**Info:** The decision tree may contain unnecessarily long paths, where the bug-class does not change. You can use the function `remove_unequal_decisions(decision_tree)` to remove those nodes."
]
},
{
Expand All @@ -2717,14 +2749,9 @@
"metadata": {},
"outputs": [],
"source": [
"show_tree(remove_unequal_decisions(trees[MAX_ITERATION-1]), all_features)"
"show_decision_tree(remove_unequal_decisions(final_tree), all_feature_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down

0 comments on commit c2e3d37

Please sign in to comment.