Updated notebooks and some more tests
Roman Joeres authored and committed Mar 18, 2024
1 parent ebc24b0 commit 17b7b19
Showing 12 changed files with 725 additions and 548 deletions.
13 changes: 6 additions & 7 deletions datasail/reader/read_molecules.py
@@ -1,4 +1,5 @@
from typing import List, Tuple, Optional
from pathlib import Path
from typing import List, Tuple, Optional, Dict

import numpy as np
from rdkit import Chem
@@ -8,7 +9,7 @@
except ImportError:
MolFromMrvFile = None

from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_data_input
from datasail.reader.utils import DataSet, read_data, DATA_INPUT, MATRIX_INPUT, read_data_input, read_sdf_file
from datasail.settings import M_TYPE, UNK_LOCATION, FORM_SMILES


@@ -53,15 +54,13 @@ def read_molecule_data(
"""
dataset = DataSet(type=M_TYPE, format=FORM_SMILES, location=UNK_LOCATION)

def read_dir(ds: DataSet):
def read_dir(ds: DataSet, path: Path):
ds.data = {}
for file in data.iterdir():
for file in path.iterdir():
if file.suffix[1:].lower() != "sdf" and mol_reader[file.suffix[1:].lower()] is not None:
ds.data[file.stem] = mol_reader[file.suffix[1:].lower()](file)
else:
suppl = Chem.SDMolSupplier(file)
for i, mol in enumerate(suppl):
ds.data[f"{file.stem}_{i}"] = mol
ds.data = read_sdf_file(file)

read_data_input(data, dataset, read_dir)

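The refactored read_dir above now receives the directory path explicitly and delegates every SDF file to the shared read_sdf_file helper instead of iterating an SDMolSupplier inline. Below is a minimal, self-contained sketch of that dispatch; the mol_reader table and the local read_sdf_file stand-in are assumptions made for illustration (the real ones live in the datasail.reader package), and the sketch merges SDF entries with dict.update() where the committed code assigns ds.data directly.

from pathlib import Path
from typing import Dict

from rdkit import Chem

# Hypothetical per-extension readers; DataSAIL's actual mol_reader covers more formats.
mol_reader = {
    "mol": Chem.MolFromMolFile,
    "mol2": Chem.MolFromMol2File,
    "sdf": None,  # multi-record files are handled by the SDF branch below
}


def read_sdf_file(file: Path) -> Dict[str, str]:
    """Local stand-in mirroring the helper added to datasail/reader/utils.py."""
    data = {}
    for i, mol in enumerate(Chem.SDMolSupplier(str(file))):
        if mol is None:  # skip records RDKit cannot parse
            continue
        name = mol.GetProp("_Name") if mol.HasProp("_Name") else f"{file.stem}_{i}"
        data[name] = Chem.MolToSmiles(mol)
    return data


def read_dir(path: Path) -> Dict[str, object]:
    """One entry per single-molecule file; SDF files may contribute many entries."""
    data = {}
    for file in path.iterdir():
        ext = file.suffix[1:].lower()
        if ext != "sdf" and mol_reader.get(ext) is not None:
            data[file.stem] = mol_reader[ext](str(file))
        else:
            data.update(read_sdf_file(file))
    return data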
24 changes: 24 additions & 0 deletions datasail/reader/utils.py
@@ -7,6 +7,7 @@
import h5py
import numpy as np
import pandas as pd
from rdkit import Chem

from datasail.reader.validate import validate_user_args
from datasail.settings import get_default, SIM_ALGOS, DIST_ALGOS, UNK_LOCATION, format2ending, FASTA_FORMATS
@@ -383,6 +384,8 @@ def read_data_input(data: DATA_INPUT, dataset: DataSet, read_dir: Callable[[Data
elif data.suffix[1:].lower() == "h5":
with h5py.File(data) as file:
dataset.data = {k: np.array(file[k]) for k in file.keys()}
elif data.suffix[1:].lower() == "sdf":
dataset.data = read_sdf_file(data)
else:
raise ValueError("Unknown file format. Supported formats are: .fasta, .fna, .fa, tsv, .csv, .pkl, .h5")
elif data.is_dir():
@@ -402,6 +405,27 @@ def read_data_input(data: DATA_INPUT, dataset: DataSet, read_dir: Callable[[Data
raise ValueError("Unknown data input type.")


def read_sdf_file(file: Path) -> Dict[str, str]:
"""
Read in an SDF file and return its records as a mapping from molecule name to SMILES string.
Args:
file: The file to read in
Returns:
Dictionary mapping molecule names to SMILES strings
"""
data = {}
suppl = Chem.SDMolSupplier(str(file))
for i, mol in enumerate(suppl):
try:
name = mol.GetProp("_Name") if mol.HasProp("_Name") else f"{file.stem}_{i}"
data[name] = Chem.MolToSmiles(mol)
except:
pass
return data


def parse_fasta(path: Path = None) -> Dict[str, str]:
"""
Parse a FASTA file and do some validity checks if requested.
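With the new helper in place, read_data_input also accepts a single .sdf file directly. A short usage sketch, assuming the import path added by this commit and a hypothetical input file:

from pathlib import Path

from datasail.reader.utils import read_sdf_file

# Hypothetical multi-record SDF; names come from each record's _Name property,
# falling back to "<file stem>_<index>", and unparsable records are skipped.
molecules = read_sdf_file(Path("data/ligands.sdf"))
for name, smiles in list(molecules.items())[:3]:
    print(name, smiles)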
34 changes: 17 additions & 17 deletions datasail/sail.py
@@ -39,14 +39,14 @@ def validate_args(**kwargs) -> Dict[str, object]:
"""
# create output directory
output_created = False
if kwargs[KW_OUTDIR] and not kwargs[KW_OUTDIR].is_dir():
if kwargs[KW_OUTDIR] is not None and not kwargs[KW_OUTDIR].is_dir():
output_created = True
kwargs[KW_OUTDIR].mkdir(parents=True, exist_ok=True)

LOGGER.setLevel(VERB_MAP[kwargs[KW_VERBOSE]])
LOGGER.handlers[0].setLevel(level=VERB_MAP[kwargs[KW_VERBOSE]])

if kwargs[KW_OUTDIR]:
if kwargs[KW_OUTDIR] is not None:
kwargs[KW_LOGDIR] = kwargs[KW_OUTDIR] / "logs"
kwargs[KW_LOGDIR].mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(kwargs[KW_LOGDIR] / "general.log")
@@ -65,7 +65,7 @@ def validate_args(**kwargs) -> Dict[str, object]:
if len(kwargs[KW_SPLITS]) < 2:
error("Less then two splits required. This is no useful input, please check the input again.", 1,
kwargs[KW_CLI])
if not kwargs[KW_NAMES]:
if kwargs[KW_NAMES] is None:
kwargs[KW_NAMES] = [f"Split{x:03d}" for x in range(len(kwargs[KW_SPLITS]))]
elif len(kwargs[KW_SPLITS]) != len(kwargs[KW_NAMES]):
error("Different number of splits and names. You have to give the same number of splits and names for them.",
@@ -88,23 +88,23 @@ def validate_args(**kwargs) -> Dict[str, object]:
kwargs[KW_THREADS] = min(kwargs[KW_THREADS], os.cpu_count())

# check the interaction file
if kwargs[KW_INTER] and isinstance(kwargs[KW_INTER], Path) and not kwargs[KW_INTER].is_file():
if kwargs[KW_INTER] is not None and isinstance(kwargs[KW_INTER], Path) and not kwargs[KW_INTER].is_file():
error("The interaction filepath is not valid.", 5, kwargs[KW_CLI])

# check the delta value
if 1 < kwargs[KW_DELTA] < 0:
if 1 < kwargs[KW_DELTA] or kwargs[KW_DELTA] < 0:
error("The delta value has to be a real value between 0 and 1.", 6, kwargs[KW_CLI])

# check the epsilon value
if 1 < kwargs[KW_EPSILON] < 0:
if 1 < kwargs[KW_EPSILON] or kwargs[KW_EPSILON] < 0:
error("The epsilon value has to be a real value between 0 and 1.", 6, kwargs[KW_CLI])

# check number of runs to be a positive integer
if kwargs[KW_RUNS] < 1:
error("The number of runs cannot be lower than 1.", 25, kwargs[KW_CLI])

# check the input regarding the caching
if kwargs[KW_CACHE] and kwargs[KW_CACHE_DIR]:
if kwargs[KW_CACHE] and kwargs[KW_CACHE_DIR] is not None:
kwargs[KW_CACHE_DIR] = Path(kwargs[KW_CACHE_DIR])
if not kwargs[KW_CACHE_DIR].is_dir():
LOGGER.warning("Cache directory does not exist, DataSAIL creates it automatically")
@@ -114,18 +114,18 @@
error("The linkage method has to be one of 'mean', 'single', or 'complete'.", 26, kwargs[KW_CLI])

# syntactically parse the input data for the E-dataset
if kwargs[KW_E_DATA] and isinstance(kwargs[KW_E_DATA], Path) and not kwargs[KW_E_DATA].exists():
if kwargs[KW_E_DATA] is not None and isinstance(kwargs[KW_E_DATA], Path) and not kwargs[KW_E_DATA].exists():
error("The filepath to the E-data is invalid.", 7, kwargs[KW_CLI])
if kwargs[KW_E_WEIGHTS] and isinstance(kwargs[KW_E_WEIGHTS], Path) and not kwargs[KW_E_WEIGHTS].is_file():
if kwargs[KW_E_WEIGHTS] is not None and isinstance(kwargs[KW_E_WEIGHTS], Path) and not kwargs[KW_E_WEIGHTS].is_file():
error("The filepath to the weights of the E-data is invalid.", 8, kwargs[KW_CLI])
if kwargs[KW_E_STRAT] and isinstance(kwargs[KW_E_STRAT], Path) and not kwargs[KW_E_STRAT].is_file():
if kwargs[KW_E_STRAT] is not None and isinstance(kwargs[KW_E_STRAT], Path) and not kwargs[KW_E_STRAT].is_file():
error("The filepath to the stratification of the E-data is invalid.", 11, kwargs[KW_CLI])
if kwargs[KW_E_SIM] and isinstance(kwargs[KW_E_SIM], str) and kwargs[KW_E_SIM].lower() not in SIM_ALGOS:
if kwargs[KW_E_SIM] is not None and isinstance(kwargs[KW_E_SIM], str) and kwargs[KW_E_SIM].lower() not in SIM_ALGOS:
kwargs[KW_E_SIM] = Path(kwargs[KW_E_SIM])
if not kwargs[KW_E_SIM].is_file():
error(f"The similarity metric for the E-data seems to be a file-input but the filepath is invalid.", 9,
kwargs[KW_CLI])
if kwargs[KW_E_DIST] and isinstance(kwargs[KW_E_DIST], str) and kwargs[KW_E_DIST].lower() not in DIST_ALGOS:
if kwargs[KW_E_DIST] is not None and isinstance(kwargs[KW_E_DIST], str) and kwargs[KW_E_DIST].lower() not in DIST_ALGOS:
kwargs[KW_E_DIST] = Path(kwargs[KW_E_DIST])
if not kwargs[KW_E_DIST].is_file():
error(f"The distance metric for the E-data seems to be a file-input but the filepath is invalid.", 10,
@@ -134,18 +134,18 @@
error("The number of clusters to find in the E-data has to be a positive integer.", 12, kwargs[KW_CLI])

# syntactically parse the input data for the F-dataset
if kwargs[KW_F_DATA] and isinstance(kwargs[KW_F_DATA], Path) and not kwargs[KW_F_DATA].exists():
if kwargs[KW_F_DATA] is not None and isinstance(kwargs[KW_F_DATA], Path) and not kwargs[KW_F_DATA].exists():
error("The filepath to the F-data is invalid.", 13, kwargs[KW_CLI])
if kwargs[KW_F_WEIGHTS] and isinstance(kwargs[KW_F_WEIGHTS], Path) and not kwargs[KW_F_WEIGHTS].is_file():
if kwargs[KW_F_WEIGHTS] is not None and isinstance(kwargs[KW_F_WEIGHTS], Path) and not kwargs[KW_F_WEIGHTS].is_file():
error("The filepath to the weights of the F-data is invalid.", 14, kwargs[KW_CLI])
if kwargs[KW_E_STRAT] and isinstance(kwargs[KW_E_STRAT], Path) and not kwargs[KW_E_STRAT].is_file():
if kwargs[KW_E_STRAT] is not None and isinstance(kwargs[KW_E_STRAT], Path) and not kwargs[KW_E_STRAT].is_file():
error("The filepath to the stratification of the E-data is invalid.", 20, kwargs[KW_CLI])
if kwargs[KW_F_SIM] and isinstance(kwargs[KW_F_SIM], str) and kwargs[KW_F_SIM].lower() not in SIM_ALGOS:
if kwargs[KW_F_SIM] is not None and isinstance(kwargs[KW_F_SIM], str) and kwargs[KW_F_SIM].lower() not in SIM_ALGOS:
kwargs[KW_F_SIM] = Path(kwargs[KW_F_SIM])
if not kwargs[KW_F_SIM].is_file():
error(f"The similarity metric for the F-data seems to be a file-input but the filepath is invalid.", 15,
kwargs[KW_CLI])
if kwargs[KW_F_DIST] and isinstance(kwargs[KW_F_DIST], str) and kwargs[KW_F_DIST].lower() not in DIST_ALGOS:
if kwargs[KW_F_DIST] is not None and isinstance(kwargs[KW_F_DIST], str) and kwargs[KW_F_DIST].lower() not in DIST_ALGOS:
if not kwargs[KW_F_DIST].is_file():
error(f"The distance metric for the F-data seems to be a file-input but the filepath is invalid.", 16,
kwargs[KW_CLI])
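Two recurring fixes in validate_args are worth spelling out. First, truthiness tests such as "if kwargs[KW_OUTDIR]" become explicit "is not None" checks, so falsy but valid values are no longer treated as missing. Second, the delta and epsilon range checks were rewritten because "if 1 < x < 0" is a chained comparison meaning "(1 < x) and (x < 0)", which no number satisfies, so out-of-range values were never rejected. A small illustration of the corrected check:

for value in (0.2, 1.5, -0.1):
    buggy_reject = 1 < value < 0            # "(1 < value) and (value < 0)": always False
    fixed_reject = 1 < value or value < 0   # rejects anything outside [0, 1]
    print(value, buggy_reject, fixed_reject)
# 0.2  -> False False
# 1.5  -> False True   (the buggy check let this value through)
# -0.1 -> False True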
4 changes: 0 additions & 4 deletions datasail/solver/cluster_1d.py
@@ -64,11 +64,7 @@ def solve_c1(
loss = cvxpy.sum([t for tmp_list in tmp for t in tmp_list])
if distances is not None:
loss = -loss
start = time.time()
problem = solve(loss, constraints, max_sec, solver, log_file)
ttime = time.time() - start
with open("strat_timing.txt", "a") as out:
print(delta, epsilon, ttime, sep=",", file=out)

return None if problem is None else {
e: names[s] for s in range(len(splits)) for i, e in enumerate(clusters) if x[s, i].value > 0.1
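The lines removed here were ad-hoc instrumentation that timed the solve() call and appended delta, epsilon, and the wall-clock time to strat_timing.txt. If such measurements are needed again, a small context manager keeps them out of the solver code; this is only a sketch, not part of DataSAIL:

import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, path: str = "strat_timing.txt"):
    """Append the wall-clock duration of the wrapped block to a CSV-style log."""
    start = time.time()
    try:
        yield
    finally:
        with open(path, "a") as out:
            print(label, time.time() - start, sep=",", file=out)


# Hypothetical usage around the solver call:
# with timed(f"{delta},{epsilon}"):
#     problem = solve(loss, constraints, max_sec, solver, log_file)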
102 changes: 51 additions & 51 deletions datasail/solver/utils.py
@@ -225,57 +225,57 @@ def gen():
return output


def generate_baseline(
splits: List[float],
weights: Union[np.ndarray, List[float]],
similarities: Optional[np.ndarray],
distances: Optional[np.ndarray],
) -> float:
"""
Generate a baseline solution for the double-cold splitting problem.
Args:
splits: List of relative sizes of the splits
weights: List of weights of the entities
similarities: Pairwise similarity matrix of entities in the order of their names
distances: Pairwise distance matrix of entities in the order of their names
Returns:
The amount of information leakage in a random double-cold splitting
"""
indices = sorted(list(range(len(weights))), key=lambda i: -weights[i])
max_sizes = np.array(splits) * sum(weights)
sizes = [0] * len(splits)
assignments = [-1] * len(weights)
oh_val, oh_idx = float("inf"), -1
for idx in indices:
for s in range(len(splits)):
if sizes[s] + weights[idx] <= max_sizes[s]:
assignments[idx] = s
sizes[s] += weights[idx]
break
elif (sizes[s] + weights[idx]) / max_sizes[s] < oh_val:
oh_val = (sizes[s] + weights[idx]) / max_sizes[s]
oh_idx = s
if assignments[idx] == -1:
assignments[idx] = oh_idx
sizes[oh_idx] += weights[idx]
x = np.zeros((len(assignments), max(assignments) + 1))
x[np.arange(len(assignments)), assignments] = 1
ones = np.ones((1, len(weights)))

if distances is not None:
hit_matrix = np.sum([np.maximum(
(np.expand_dims(x[:, s], axis=1) @ ones) + (np.expand_dims(x[:, s], axis=1) @ ones).T - (ones.T @ ones), 0)
for s in range(len(splits))], axis=0)
leak_matrix = np.multiply(hit_matrix, distances)
else:
hit_matrix = np.sum(
[((np.expand_dims(x[:, s], axis=1) @ ones) - (np.expand_dims(x[:, s], axis=1) @ ones).T) ** 2 for s in
range(len(splits))], axis=0) / (len(splits) - 1)
leak_matrix = np.multiply(hit_matrix, similarities)

return float(np.sum(leak_matrix))
# def generate_baseline(
# splits: List[float],
# weights: Union[np.ndarray, List[float]],
# similarities: Optional[np.ndarray],
# distances: Optional[np.ndarray],
# ) -> float:
# """
# Generate a baseline solution for the double-cold splitting problem.
#
# Args:
# splits: List of relative sizes of the splits
# weights: List of weights of the entities
# similarities: Pairwise similarity matrix of entities in the order of their names
# distances: Pairwise distance matrix of entities in the order of their names
#
# Returns:
# The amount of information leakage in a random double-cold splitting
# """
# indices = sorted(list(range(len(weights))), key=lambda i: -weights[i])
# max_sizes = np.array(splits) * sum(weights)
# sizes = [0] * len(splits)
# assignments = [-1] * len(weights)
# oh_val, oh_idx = float("inf"), -1
# for idx in indices:
# for s in range(len(splits)):
# if sizes[s] + weights[idx] <= max_sizes[s]:
# assignments[idx] = s
# sizes[s] += weights[idx]
# break
# elif (sizes[s] + weights[idx]) / max_sizes[s] < oh_val:
# oh_val = (sizes[s] + weights[idx]) / max_sizes[s]
# oh_idx = s
# if assignments[idx] == -1:
# assignments[idx] = oh_idx
# sizes[oh_idx] += weights[idx]
# x = np.zeros((len(assignments), max(assignments) + 1))
# x[np.arange(len(assignments)), assignments] = 1
# ones = np.ones((1, len(weights)))
#
# if distances is not None:
# hit_matrix = np.sum([np.maximum((np.expand_dims(x[:, s], axis=1) @ ones) +
# (np.expand_dims(x[:, s], axis=1) @ ones).T -
# (ones.T @ ones), 0) for s in range(len(splits))], axis=0)
# leak_matrix = np.multiply(hit_matrix, distances)
# else:
# hit_matrix = np.sum(
# [((np.expand_dims(x[:, s], axis=1) @ ones) - (np.expand_dims(x[:, s], axis=1) @ ones).T) ** 2 for s in
# range(len(splits))], axis=0) / (len(splits) - 1)
# leak_matrix = np.multiply(hit_matrix, similarities)
#
# return float(np.sum(leak_matrix))


def interaction_contraints(
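The commented-out generate_baseline assigned entities greedily to the first split with room (heaviest first) and then summed the similarity or distance mass across split boundaries as a leakage baseline. The core quantity is the pairwise metric summed over cross-split pairs; the compact numpy sketch below illustrates only that measurement, not DataSAIL's implementation, and ignores the normalisation factors in the original:

import numpy as np


def cross_split_leakage(assignments: np.ndarray, similarities: np.ndarray) -> float:
    """Sum of pairwise similarity over all (ordered) pairs placed in different splits."""
    different = assignments[:, None] != assignments[None, :]
    return float(similarities[different].sum())


# Toy example: four entities, two splits, pairs (0, 1) and (2, 3) are highly similar.
sim = np.array([[1.0, 0.9, 0.1, 0.2],
                [0.9, 1.0, 0.2, 0.1],
                [0.1, 0.2, 1.0, 0.8],
                [0.2, 0.1, 0.8, 1.0]])
print(cross_split_leakage(np.array([0, 0, 1, 1]), sim))  # keeps similar pairs together, ~1.2
print(cross_split_leakage(np.array([0, 1, 0, 1]), sim))  # splits similar pairs apart, ~4.2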
23 changes: 8 additions & 15 deletions examples/bace.ipynb
@@ -11,12 +11,8 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true
}
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
@@ -82,12 +78,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"scrolled": true,
"pycharm": {
"is_executing": true
}
"scrolled": true
},
"outputs": [],
"source": [
@@ -142,10 +135,10 @@
"\tRun 3 - Type: <class 'dict'> - 1513 assignments\n",
"\n",
"ID: Comp000001 - Split: train\n",
"ID: Comp000002 - Split: test\n",
"ID: Comp000003 - Split: val\n",
"ID: Comp000004 - Split: val\n",
"ID: Comp000005 - Split: val\n"
"ID: Comp000002 - Split: train\n",
"ID: Comp000003 - Split: train\n",
"ID: Comp000004 - Split: train\n",
"ID: Comp000005 - Split: train\n"
]
}
],

