From e35047eb19a74d829b7623a40df298c7715b2d60 Mon Sep 17 00:00:00 2001 From: 1andrin <115493865+1andrin@users.noreply.github.com> Date: Fri, 15 Nov 2024 11:32:01 +0100 Subject: [PATCH] linter Signed-off-by: 1andrin <115493865+1andrin@users.noreply.github.com> --- causaltune/erupt.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/causaltune/erupt.py b/causaltune/erupt.py index c2ecc544..25571923 100644 --- a/causaltune/erupt.py +++ b/causaltune/erupt.py @@ -73,12 +73,12 @@ def score( return (w * outcome).mean() def weights( - self, - df: pd.DataFrame, - policy: Union[Callable, np.ndarray, pd.Series] + self, df: pd.DataFrame, policy: Union[Callable, np.ndarray, pd.Series] ) -> pd.Series: W = df[self.treatment_name].astype(int) - assert all([x >= 0 for x in W.unique()]), "Treatment values must be non-negative integers" + assert all( + [x >= 0 for x in W.unique()] + ), "Treatment values must be non-negative integers" # Handle policy input if callable(policy): @@ -87,7 +87,9 @@ def weights( policy = policy.values policy = np.array(policy) d = pd.Series(index=df.index, data=policy) - assert all([x >= 0 for x in d.unique()]), "Policy values must be non-negative integers" + assert all( + [x >= 0 for x in d.unique()] + ), "Policy values must be non-negative integers" # Get propensity scores with better handling of edge cases if isinstance(self.propensity_model, DummyPropensity): @@ -98,25 +100,25 @@ def weights( except Exception: # Fallback to safe defaults if prediction fails p = np.full((len(df), 2), 0.5) - + # Clip propensity scores to avoid division by zero or extreme weights min_clip = max(1e-6, self.clip) # Ensure minimum clip is not too small p = np.clip(p, min_clip, 1 - min_clip) - # Initialize weights + # Initialize weights weight = np.zeros(len(df)) - + try: # Calculate weights with safer operations for i in W.unique(): - mask = (W == i) + mask = W == i p_i = p[:, i][mask] # Add small constant to denominator to prevent division by zero weight[mask] = 1 / (p_i + 1e-10) except Exception: # If something goes wrong, return safe weights weight = np.ones(len(df)) - + # Zero out weights where policy disagrees with actual treatment weight[d != W] = 0.0 @@ -133,12 +135,12 @@ def weights( else: # If all weights are zero, use uniform weights weight = np.ones(len(df)) / len(df) - + # Final check for NaNs if np.any(np.isnan(weight)): # Replace any remaining NaNs with uniform weights weight = np.ones(len(df)) / len(df) - + return pd.Series(index=df.index, data=weight) def probabilistic_erupt_score(