Skip to content

Commit

Permalink
Merge pull request #78 from Hynn01/main
Browse files Browse the repository at this point in the history
Fix randomness control checkers and update README
  • Loading branch information
Hynn01 authored May 7, 2022
2 parents 29718e4 + fec0690 commit caec5fb
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 156 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
[![PyPI - Downloads - Monthly](https://img.shields.io/pypi/dm/dslinter)](https://pypi.org/project/dslinter/)
[![Code Grade](https://api.codiga.io/project/33224/status/svg)](https://api.codiga.io/project/33224/status/svg)

> Hi! We’re currently researching code smells in machine learning projects in an industry context and are looking for feedback on `dslinter`! It would be a massive help if you could run `dslinter` on your machine learning project in an industry setting and send the text and JSON output to [email protected]. The steps and commands can be found [here](https://github.com/SERG-Delft/dslinter/blob/main/STEPS_TO_FOLLOW.md), and the whole process should take no more than 10 minutes. Feel free to send me an [email]([email protected]) if you would like to go through the process together. The process is anonymous, and we will remove any sensitive information before the results are published. Many thanks!

`dslinter` is a PyLint plugin for linting data science and machine learning code. It aims to help developers ensure the machine learning code quality and supports the following Python libraries: TensorFlow, PyTorch, Scikit-Learn, Pandas, NumPy and SciPy.

`dslinter` implements the detection rules for smells identified in [our previous work](https://arxiv.org/pdf/2203.13746.pdf). The smells are collected from papers, grey literature, GitHub commits, and Stack Overflow posts. They are also described in more detail on a [website](https://hynn01.github.io/ml-smells/) :)
Expand Down
63 changes: 33 additions & 30 deletions dslinter/checkers/deterministic_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,50 +34,53 @@ class DeterministicAlgorithmChecker(BaseChecker):
),
)

_import_pytorch = False
_has_deterministic_algorithm_option = False

def visit_import(self, node: astroid.Import):
"""
Check whether there is a pytorch import
:param node: import node
"""
try:
if self._import_pytorch is False:
self._import_pytorch = has_import(node, "torch")
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, node)

def visit_module(self, module: astroid.Module):
"""
Check whether use_deterministic_algorithms option is used.
:param module: call node
Check whether there is a rule violation.
:param module:
"""
try:
_import_pytorch = False
_has_deterministic_algorithm_option = False

# if the user wants to only check main module, but the current file is not main module, just return
_is_main_module = check_main_module(module)
if self.config.no_main_module_check_deterministic_pytorch is False and _is_main_module is False:
return

# if torch.use_deterministic_algorithm() is call and the argument is True,
# set _has_deterministic_algorithm_option to True
# traverse over the node in the module
for node in module.body:
if isinstance(node, astroid.Import):
if _import_pytorch is False:
_import_pytorch = has_import(node, "torch")

if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"):
call_node = node.value
if(
hasattr(call_node, "func")
and hasattr(call_node.func, "attrname")
and call_node.func.attrname == "use_deterministic_algorithms"
and hasattr(call_node, "args")
and len(call_node.args) > 0
and hasattr(call_node.args[0], "value")
and call_node.args[0].value is True
):
self._has_deterministic_algorithm_option = True
if _has_deterministic_algorithm_option is False:
_has_deterministic_algorithm_option = self._check_deterministic_algorithm_option(call_node)

# check if the rules are violated
if(
self._import_pytorch is True
and self._has_deterministic_algorithm_option is False
_import_pytorch is True
and _has_deterministic_algorithm_option is False
):
self.add_message("deterministic-pytorch", node=module)

except: # pylint: disable = bare-except
ExceptionHandler.handle(self, module)

@staticmethod
def _check_deterministic_algorithm_option(call_node: "astroid.Call"):
    """
    Check whether a call switches on the deterministic algorithm option.

    The call is recognised when the called attribute is named
    ``use_deterministic_algorithms`` and the literal ``True`` is passed
    either as the first positional argument or as the ``mode`` keyword
    argument.
    :param call_node: value of an expression statement of the module
    :return: True if the option is enabled by this call, False otherwise
    """
    func = getattr(call_node, "func", None)
    if getattr(func, "attrname", None) != "use_deterministic_algorithms":
        return False

    # positional form: torch.use_deterministic_algorithms(True)
    args = getattr(call_node, "args", None) or []
    if args:
        return getattr(args[0], "value", None) is True

    # keyword form: torch.use_deterministic_algorithms(mode=True);
    # without this, the keyword spelling was reported as a violation
    for keyword in getattr(call_node, "keywords", None) or []:
        if getattr(keyword, "arg", None) == "mode":
            return getattr(getattr(keyword, "value", None), "value", None) is True

    return False
6 changes: 3 additions & 3 deletions dslinter/checkers/randomness_control_dataloader_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RandomnessControlDataloaderPytorchChecker(BaseChecker):
}
options = ()

_import_DataLoader = False
_import_dataloader = False

def visit_importfrom(self, importfrom_node: astroid.ImportFrom):
"""
Expand All @@ -35,7 +35,7 @@ def visit_importfrom(self, importfrom_node: astroid.ImportFrom):
):
for name, _ in importfrom_node.names:
if name == "DataLoader":
self._import_DataLoader = True
self._import_dataloader = True
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, importfrom_node)

Expand All @@ -60,7 +60,7 @@ def visit_call(self, node: astroid.Call):
def _use_dataloader_from_import(self, node):
# Dataloader has been imported from torch.utils.data
if(
self._import_DataLoader is True
self._import_dataloader is True
and hasattr(node.func, "name")
and node.func.name == "DataLoader"
):
Expand Down
83 changes: 37 additions & 46 deletions dslinter/checkers/randomness_control_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ class RandomnessControlNumpyChecker(BaseChecker):
"The np.random.seed() should be set in numpy program for reproducible result."
)
}
options = ()

_import_numpy = False
_has_manual_seed = False
_import_ml_libraries = False

options = (
(
Expand All @@ -38,63 +33,59 @@ class RandomnessControlNumpyChecker(BaseChecker):
),
)

def visit_import(self, node: astroid.Import):
"""
Check whether there is a numpy import and ml library import.
:param node: import node
"""
try:
if self._import_numpy is False:
self._import_numpy = has_import(node, "numpy")
if self._import_ml_libraries is False:
self._import_ml_libraries = has_import(node, "sklearn") \
or has_import(node, "torch") \
or has_import(node, "tensorflow")
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, node)

def visit_importfrom(self, node: astroid.ImportFrom):
"""
Check whether there is a scikit-learn import.
:param node: import from node
"""
try:
if self._import_ml_libraries is False:
self._import_ml_libraries = has_importfrom_sklearn(node)
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, node)


def visit_module(self, module: astroid.Module):
"""
Check whether there is a rule violation.
:param module:
"""
try:
_import_numpy = False
_import_ml_libraries = False
_has_numpy_manual_seed = False

# if the user wants to only check main module, but the current file is not main module, just return
_is_main_module = check_main_module(module)
if self.config.no_main_module_check_randomness_control_numpy is False and _is_main_module is False:
return

# traverse over the node in the module
for node in module.body:
if isinstance(node, astroid.Import):
if _import_ml_libraries is False:
_import_ml_libraries = has_import(node, "tensorflow") or has_import(node, "torch") or has_import(node, "sklearn")
if _import_numpy is False:
_import_numpy = has_import(node, "numpy")

if isinstance(node, astroid.ImportFrom):
if _import_ml_libraries is False:
_import_ml_libraries = has_importfrom_sklearn(node)

if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"):
call_node = node.value
if(
hasattr(call_node, "func")
and hasattr(call_node.func, "attrname")
and call_node.func.attrname == "seed"
and hasattr(call_node.func.expr, "attrname")
and call_node.func.expr.attrname == "random"
and hasattr(call_node.func.expr, "expr")
and hasattr(call_node.func.expr.expr, "name")
and call_node.func.expr.expr.name in ["np", "numpy"]
):
self._has_manual_seed = True
if _has_numpy_manual_seed is False:
_has_numpy_manual_seed = self._check_numpy_manual_seed(call_node)

# check if the rules are violated
if(
self._import_numpy is True
and self._import_ml_libraries is True
and self._has_manual_seed is False
_import_numpy is True
and _import_ml_libraries is True
and _has_numpy_manual_seed is False
):
self.add_message("randomness-control-numpy", node=module)
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, module)

@staticmethod
def _check_numpy_manual_seed(call_node: astroid.Call):
    """
    Check whether a call has the shape ``np.random.seed(...)`` or
    ``numpy.random.seed(...)``.
    :param call_node: value of an expression statement of the module
    :return: True if the call sets the numpy seed, False otherwise
    """
    if not hasattr(call_node, "func"):
        return False

    # outermost attribute access: <something>.seed
    seed_access = call_node.func
    if not hasattr(seed_access, "attrname") or seed_access.attrname != "seed":
        return False

    # middle attribute access: <something>.random
    random_access = seed_access.expr
    if not hasattr(random_access, "attrname") or random_access.attrname != "random":
        return False

    # innermost name: np / numpy
    if not hasattr(random_access, "expr"):
        return False
    root = random_access.expr
    return hasattr(root, "name") and root.name in ["np", "numpy"]
45 changes: 23 additions & 22 deletions dslinter/checkers/randomness_control_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,45 +33,46 @@ class RandomnessControlPytorchChecker(BaseChecker):
),
)

_import_pytorch = False
_has_manual_seed = False

def visit_import(self, node: astroid.Import):
"""
Check whether there is a pytorch import.
:param node: import node
"""
try:
if self._import_pytorch is False:
self._import_pytorch = has_import(node, "torch")
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, node)

def visit_module(self, module: astroid.Module):
"""
Check whether there is a rule violation.
:param module:
"""
try:
_import_pytorch = False
_has_pytorch_manual_seed = False

# if the user wants to only check main module, but the current file is not main module, just return
_is_main_module = check_main_module(module)
if self.config.no_main_module_check_randomness_control_pytorch is False and _is_main_module is False:
return

# traverse over the node in the module
for node in module.body:
if isinstance(node, astroid.Import):
if _import_pytorch is False:
_import_pytorch = has_import(node, "torch")

if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"):
call_node = node.value
if(
hasattr(call_node, "func")
and hasattr(call_node.func, "attrname")
and call_node.func.attrname == "manual_seed"
):
self._has_manual_seed = True
if _has_pytorch_manual_seed is False:
_has_pytorch_manual_seed = self._check_pytorch_manual_seed(call_node)

# check if the rules are violated
if(
self._import_pytorch is True
and self._has_manual_seed is False
_import_pytorch is True
and _has_pytorch_manual_seed is False
):
self.add_message("randomness-control-pytorch", node=module)
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, module)

@staticmethod
def _check_pytorch_manual_seed(call_node: astroid.Call):
    """
    Check whether a call targets an attribute named ``manual_seed``.
    :param call_node: value of an expression statement of the module
    :return: True if the called attribute is ``manual_seed``, False otherwise
    """
    if not hasattr(call_node, "func"):
        return False
    callee = call_node.func
    # only attribute-style calls (``x.manual_seed(...)``) carry ``attrname``
    return hasattr(callee, "attrname") and callee.attrname == "manual_seed"
3 changes: 1 addition & 2 deletions dslinter/checkers/randomness_control_scikitlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def visit_call(self, node: astroid.Call):
if _has_random_state_keyword is False:
self.add_message("randomness-control-scikitlearn", node=node)

# pylint: disable = W0702
except:
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, node)
traceback.print_exc()
57 changes: 30 additions & 27 deletions dslinter/checkers/randomness_control_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class RandomnessControlTensorflowChecker(BaseChecker):
"The tf.random.set_seed() should be set in TensorFlow program for reproducible result"
)
}

options = (
(
"no_main_module_check_randomness_control_tensorflow",
Expand All @@ -32,49 +33,51 @@ class RandomnessControlTensorflowChecker(BaseChecker):
),
)

_import_tensorflow = False
_has_manual_seed = False

def visit_import(self, node: astroid.Import):
"""
Check whether there is a tensorflow import.
:param node: import node
"""
try:
if self._import_tensorflow is False:
self._import_tensorflow = has_import(node, "tensorflow")
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, node)

def visit_module(self, module: astroid.Module):
"""
Check whether there is a rule violation.
:param module:
"""
try:
_import_tensorflow = False
_has_tensorflow_manual_seed = False

# if the user wants to only check main module, but the current file is not main module, just return
_is_main_module = check_main_module(module)
if self.config.no_main_module_check_randomness_control_tensorflow is False and _is_main_module is False:
return

# traverse over the node in the module
for node in module.body:
if isinstance(node, astroid.Import):
if _import_tensorflow is False:
_import_tensorflow = has_import(node, "tensorflow")

if isinstance(node, astroid.nodes.Expr) and hasattr(node, "value"):
call_node = node.value
if(
hasattr(call_node, "func")
and hasattr(call_node.func, "attrname")
and call_node.func.attrname == "set_seed"
and hasattr(call_node.func.expr, "attrname")
and call_node.func.expr.attrname == "random"
and hasattr(call_node.func.expr, "expr")
and hasattr(call_node.func.expr.expr, "name")
and call_node.func.expr.expr.name in ["tf", "tensorflow"]
):
self._has_manual_seed = True
if _has_tensorflow_manual_seed is False:
_has_tensorflow_manual_seed = self._check_tensorflow_manual_seed(call_node)

# check if the rules are violated
if(
self._import_tensorflow is True
and self._has_manual_seed is False
_import_tensorflow is True
and _has_tensorflow_manual_seed is False
):
self.add_message("randomness-control-tensorflow", node=module)
except: # pylint: disable = bare-except
ExceptionHandler.handle(self, module)

@staticmethod
def _check_tensorflow_manual_seed(call_node: astroid.Call):
    """
    Check whether a call has the shape ``tf.random.set_seed(...)`` or
    ``tensorflow.random.set_seed(...)``.
    :param call_node: value of an expression statement of the module
    :return: True if the call sets the TensorFlow seed, False otherwise
    """
    if not hasattr(call_node, "func"):
        return False

    # outermost attribute access: <something>.set_seed
    set_seed_access = call_node.func
    if not hasattr(set_seed_access, "attrname"):
        return False
    if set_seed_access.attrname != "set_seed":
        return False

    # middle attribute access: <something>.random
    random_access = set_seed_access.expr
    if not hasattr(random_access, "attrname"):
        return False
    if random_access.attrname != "random":
        return False

    # innermost name: tf / tensorflow
    if not hasattr(random_access, "expr"):
        return False
    package = random_access.expr
    if not hasattr(package, "name"):
        return False
    return package.name in ["tf", "tensorflow"]
Loading

0 comments on commit caec5fb

Please sign in to comment.