-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy path_get_batch_descriptors.py
95 lines (74 loc) · 2.99 KB
/
_get_batch_descriptors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import warnings
from dataclasses import dataclass
import pandas as pd
from scipy.stats import wasserstein_distance
from sklearn.preprocessing import minmax_scale
from ._large_molecule_descriptors import LargeMoleculeDescriptors
class MetricWarning(RuntimeWarning):
    """Warning category for problems detected while computing batch metrics."""


# Figure shape attached to the large-molecule (fv_heavy / fv_light) column info.
LARGE_MOL_FIGSIZE: tuple[int, int] = (3, 5)

# Message emitted as a MetricWarning when a batch has no valid designs.
NO_VALID_DESIGNS_WARNING: str = "There were no valid designs."

# Report every MetricWarning occurrence rather than only the first one per location.
warnings.simplefilter("always", category=MetricWarning)
@dataclass
class MetricColumnInfo:
    """Column layout used when computing metrics for one chain type.

    On construction ``feature_columns`` is de-duplicated and sorted, so the
    order in which features are evaluated is deterministic.
    """

    # Feature columns compared between samples and the reference.
    feature_columns: list[str]
    # Column whose non-null entries mark a sample as valid.
    sample_column: str
    # Figure shape associated with this chain type.
    figshape: tuple[int, int]

    def __post_init__(self):
        # Drop duplicates first, then sort for a stable column order.
        deduplicated = {*self.feature_columns}
        self.feature_columns = sorted(deduplicated)
def get_column_info(chain) -> MetricColumnInfo:
match chain:
case "fv_heavy":
return MetricColumnInfo(
[f"fv_heavy_{feature}" for feature in LargeMoleculeDescriptors.descriptor_names()],
"fv_heavy_aho",
LARGE_MOL_FIGSIZE,
)
case "fv_light":
return MetricColumnInfo(
[f"fv_light_{feature}" for feature in LargeMoleculeDescriptors.descriptor_names()],
"fv_light_aho",
LARGE_MOL_FIGSIZE,
)
def get_batch_descriptors(
    sample_df: pd.DataFrame, ref_feats: pd.DataFrame, chain
) -> tuple[dict[str, float], float, float, float]:
    """
    Compute aggregate statistics for a collection of samples compared to reference.

    For every feature column of ``chain`` (from ``get_column_info``), the
    non-NaN sample values and the non-NaN reference values are each min/max
    scaled and the 1-D Wasserstein distance between them is recorded.

    Parameters
    ----------
    sample_df: pd.DataFrame
        Collection of samples, generally this would be a return value of
        `walkjump.callbacks.sample_and_compute_metrics()`
    ref_feats: pd.DataFrame
        Pre-computed reference distributions
    chain: ReferenceChainType
        Type of input molecule. Behavior switches based on molecule type.

    Returns
    -------
    Tuple[Dict[str, float], float, float, float]
        Wasserstein distances per statistic column (keyed ``"<column>_wd"``,
        ``inf`` when scaling/distance raised ``ValueError``) and the
        (average wass. dist., total wass. dist., proportion not NaN).
        The average is NaN if there are no feature columns.

    Warns
    -----
    MetricWarning
        When no sample has a non-NaN value in the sample column, including
        when ``sample_df`` is empty.
    """
    info = get_column_info(chain)

    n_samples = len(sample_df)
    n_valid = int(sample_df[info.sample_column].notna().sum())
    # Warn whenever there are zero valid designs, not only for an empty frame.
    if n_valid == 0:
        warnings.warn(NO_VALID_DESIGNS_WARNING, category=MetricWarning, stacklevel=2)
    prop_valid = n_valid / n_samples if n_samples else 0.0

    wasserstein_distances = {}
    for column in info.feature_columns:
        # filter out NaN rows for this column (independently for samples and reference)
        valid = sample_df.loc[sample_df[column].notna(), column]
        valid_ref = ref_feats.loc[ref_feats[column].notna(), column]
        key = f"{column}_wd"
        try:
            # min/max norm the validated rows.
            # NOTE(review): each side is scaled with its own min/max, so a pure
            # shift or rescaling between samples and reference does not show up
            # in the distance — confirm this is the intended metric.
            normed = minmax_scale(valid)
            normed_ref = minmax_scale(valid_ref)
            wasserstein_distances[key] = wasserstein_distance(normed, normed_ref)
        except ValueError:
            # presumably hit when a column has no non-NaN values (nothing to
            # scale/compare); the column is then reported as infinitely distant
            wasserstein_distances[key] = float("inf")

    total_wd = sum(wasserstein_distances.values())
    # Guard against an empty feature set instead of raising ZeroDivisionError.
    avg_wd = total_wd / len(info.feature_columns) if info.feature_columns else float("nan")
    return wasserstein_distances, avg_wd, total_wd, prop_valid