Skip to content

Commit

Permalink
Refactored the baseclasses for more logical inheritance for the two a…
Browse files Browse the repository at this point in the history
…bstract classes
  • Loading branch information
Esben Jannik Bjerrum committed Nov 22, 2024
1 parent f7b20f1 commit f092feb
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 112 deletions.
89 changes: 57 additions & 32 deletions scikit_mol/fingerprints/baseclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,16 @@
)


class FpsTransformer(ABC, BaseEstimator, TransformerMixin):
class BaseFpsTransformer(ABC, BaseEstimator, TransformerMixin):
def __init__(
self,
parallel: Union[bool, int] = False,
start_method: str = None,
safe_inference_mode: bool = False,
dtype: np.dtype = np.int8,
):
self.parallel = parallel
self.start_method = start_method
self.safe_inference_mode = safe_inference_mode
self.dtype = dtype

@property
def nBits(self):
Expand Down Expand Up @@ -98,34 +96,25 @@ def get_feature_names_out(self, input_features=None):
prefix = self._get_column_prefix()
return np.array([f"{prefix}_{i}" for i in range(1, self.fpSize + 1)])

@abstractmethod
def _mol2fp(self, mol):
"""Generate fingerprint from mol
MUST BE OVERWRITTEN
"""
raise NotImplementedError("_mol2fp not implemented")

def _fp2array(self, fp):
if fp:
arr = np.zeros((self.fpSize,), dtype=self.dtype)
DataStructs.ConvertToNumpyArray(fp, arr)
return arr
else:
return np.ma.masked_all((self.fpSize,), dtype=self.dtype)

def _transform_mol(self, mol):
def _safe_transform_mol(self, mol):
"""Handle safe inference mode with masked arrays"""
if not mol and self.safe_inference_mode:
return self._fp2array(False)
return np.ma.masked_all(self.fpSize)

try:
fp = self._mol2fp(mol)
return self._fp2array(fp)
result = self._transform_mol(mol)
return result
except Exception as e:
if self.safe_inference_mode:
return self._fp2array(False)
return np.ma.masked_all(self.fpSize)
else:
raise e

@abstractmethod
def _transform_mol(self, mol):
"""Transform a single molecule to numpy array"""
raise NotImplementedError

def fit(self, X, y=None):
"""Included for scikit-learn compatibility
Expand All @@ -137,15 +126,20 @@ def fit(self, X, y=None):
def _transform(self, X):
if self.safe_inference_mode:
# Use the new method with masked arrays if we're in safe inference mode
arrays = [self._transform_mol(mol) for mol in X]
arrays = [self._safe_transform_mol(mol) for mol in X]
print(arrays)
return np.ma.stack(arrays)
else:
elif hasattr(
self, "dtype"
): # TODO, it seems a bit of a code smell that we have to preemptively test a property from the baseclass?
# Use the original, faster method if we're not in safe inference mode
arr = np.zeros((len(X), self.fpSize), dtype=self.dtype)
for i, mol in enumerate(X):
arr[i, :] = self._transform_mol(mol)
return arr
else: # We are unsure of the dtype, so we don't use a preassigned array #TODO test time difference to previous
arrays = [self._transform_mol(mol) for mol in X]
return np.stack(arrays)

def _transform_sparse(self, X):
arr = np.zeros((len(X), self.fpSize), dtype=self.dtype)
Expand Down Expand Up @@ -202,20 +196,49 @@ def transform(self, X, y=None):
return arr


class FpsGeneratorTransformer(FpsTransformer):
_regenerate_on_properties = ()
class FpsTransformer(BaseFpsTransformer):
"""Classic fingerprint transformer using mol2fp pattern"""

def _fp2array(self, fp):
raise DeprecationWarning("Generators can directly return fingerprints")
def __init__(
self,
parallel: Union[bool, int] = False,
safe_inference_mode: bool = False,
dtype: np.dtype = np.int8,
):
super().__init__(parallel=parallel, safe_inference_mode=safe_inference_mode)
self.dtype = dtype

def _transform_mol(self, mol):
"""Implements the mol -> rdkit fingerprint data structure -> numpy array pattern"""
fp = self._mol2fp(mol)
return self._fp2array(fp)

@abstractmethod
def _mol2fp(self, mol):
raise DeprecationWarning("use _mol2array")
"""Generate fingerprint from mol
MUST BE OVERWRITTEN
"""
raise NotImplementedError("_mol2fp not implemented")

def _fp2array(self, fp):
"""Convert RDKit fingerprint data structure to numpy array"""
if fp:
arr = np.zeros((self.fpSize,), dtype=self.dtype)
DataStructs.ConvertToNumpyArray(fp, arr)
return arr
else:
return np.ma.masked_all((self.fpSize,), dtype=self.dtype)


class FpsGeneratorTransformer(BaseFpsTransformer):
_regenerate_on_properties = ()

def __getstate__(self):
# Get the state of the parent class
state = super().__getstate__()
state.update(self.get_params())
# Remove the unpicklable property from the state
# Remove the potentially unpicklable property from the state
state.pop("_fpgen", None) # fpgen is not picklable
return state

Expand All @@ -234,6 +257,8 @@ def __setstate__(self, state):
]
self._generate_fp_generator()

# TODO: override set_params in order to not make multiple calls to _generate_fp_generator

def __setattr__(self, name: str, value):
super().__setattr__(name, value)
if (
Expand Down
Loading

0 comments on commit f092feb

Please sign in to comment.