Skip to content

Commit

Permalink
Merge pull request #79 from Hynn01/main
Browse files Browse the repository at this point in the history
Fix hyperparameters-tensorflow checker with a library check and Update README
  • Loading branch information
Hynn01 authored May 9, 2022
2 parents 75c5125 + f613569 commit 90d8c08
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ git submodule update --init --recursive
```
2. `dslinter` uses `poetry` to manage dependencies, so you will need to install `poetry` first and then install the dependencies.
```
pip install poerty
pip install poetry
poetry install
```
- To install `dslinter` from source for development purposes, install it with:
Expand Down
4 changes: 3 additions & 1 deletion dslinter/checkers/hyperparameters_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pylint.lint import PyLinter
from dslinter.checkers.hyperparameters import HyperparameterChecker
from dslinter.utils.exception_handler import ExceptionHandler
from dslinter.utils.hyperparameters_helper import check_module_with_library


class HyperparameterTensorflowChecker(HyperparameterChecker):
Expand Down Expand Up @@ -82,7 +83,8 @@ def visit_call(self, node: astroid.Call):
):
self.hyperparameter_in_class(node, node.func.name)
if(
hasattr(node, "func")
check_module_with_library(node, "tensorflow")
and hasattr(node, "func")
and hasattr(node.func, "attrname")
and node.func.attrname == "fit"
):
Expand Down
23 changes: 23 additions & 0 deletions dslinter/tests/checkers/test_hyperparameters_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,26 @@ def test_learning_rate_not_set(self):
with self.assertAddsMessages(pylint.testutils.MessageTest(msg_id = "hyperparameters-tensorflow", node = call_node)):
self.checker.visit_importfrom(importfrom_node)
self.checker.visit_call(call_node)

def test_sklearn_fit(self):
script = """
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import PCA
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
RANDOM_STATE = 42
features, target = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
features, target, test_size=0.30, random_state=RANDOM_STATE
)
clf = make_pipeline(PCA(n_components=2), GaussianNB())
clf.fit(X_train, y_train) #@
pred_test = clf.predict(X_test)
ac = accuracy_score(y_test, pred_test)
"""
call_node = astroid.extract_node(script)
with self.assertNoMessages():
self.checker.visit_call(call_node)
15 changes: 15 additions & 0 deletions dslinter/utils/hyperparameters_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import astroid


def check_module_with_library(node, library_name: str):
while not isinstance(node.parent, astroid.Module):
node = node.parent
module = node

if isinstance(module, astroid.Module):
for node in module.body:
if isinstance(node, astroid.Import):
for name, _ in node.names:
if name == library_name:
return True
return False
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ skip = 'scripts'

[tool.poetry]
name = "dslinter"
version = "2.0.4"
version = "2.0.5"
description = "`dslinter` is a pylint plugin for linting data science and machine learning code. We plan to support the following Python libraries: TensorFlow, PyTorch, Scikit-Learn, Pandas, NumPy and SciPy."

license = "GPL-3.0 License"
Expand Down

0 comments on commit 90d8c08

Please sign in to comment.