From 72456130842b128e9e13656dc75de463f4fe8517 Mon Sep 17 00:00:00 2001 From: Max Halford Date: Mon, 11 Sep 2023 12:19:27 +0200 Subject: [PATCH] tidy up lof --- river/anomaly/lof.py | 589 +++++++++++++++++-------------------- river/anomaly/test_ilof.py | 4 +- river/test_estimators.py | 1 + 3 files changed, 274 insertions(+), 320 deletions(-) diff --git a/river/anomaly/lof.py b/river/anomaly/lof.py index e66104c4e5..e523a23121 100644 --- a/river/anomaly/lof.py +++ b/river/anomaly/lof.py @@ -8,6 +8,140 @@ from river.neighbors.base import DistanceFunc +def check_equal(x_list: list, y_list: list): + """ + Check if new list of observations (x_list) has any data sample that is equal to any previous data recorded (y_list). + """ + result = [x for x in x_list if not any(x == y for y in y_list)] + return result, len(x_list) - len(result) + + +def expand_objects( + new_particles: list, + x_list: list, + neighborhoods: dict, + rev_neighborhoods: dict, + k_dist: dict, + reach_dist: dict, + dist_dict: dict, + local_reach: dict, + lof: dict, +): + """ + Expand size of dictionaries and lists to take into account new data points. + """ + n = len(x_list) + m = len(new_particles) + x_list.extend(new_particles) + neighborhoods.update({i: [] for i in range(n + m)}) + rev_neighborhoods.update({i: [] for i in range(n + m)}) + k_dist.update({i: float("inf") for i in range(n + m)}) + reach_dist.update({i + n: {} for i in range(m)}) + dist_dict.update({i + n: {} for i in range(m)}) + local_reach.update({i + n: [] for i in range(m)}) + lof.update({i + n: [] for i in range(m)}) + return ( + (n, m), + x_list, + neighborhoods, + rev_neighborhoods, + k_dist, + reach_dist, + dist_dict, + local_reach, + lof, + ) + + +def define_sets(nm, neighborhoods: dict, rev_neighborhoods: dict): + """ + Define sets of points for the incremental LOF algorithm. + """ + # Define set of new points from batch + set_new_points = set(range(nm[0], nm[0] + nm[1])) + set_neighbors: set = set() + set_rev_neighbors: set = set() + + # Define neighbors and reverse neighbors of new data points + for i in set_new_points: + set_neighbors = set(set_neighbors) | set(neighborhoods[i]) + set_rev_neighbors = set(set_rev_neighbors) | set(rev_neighborhoods[i]) + + # Define points that need to update their local reachability distance because of new data points + set_upd_lrd = set_rev_neighbors + for j in set_rev_neighbors: + set_upd_lrd = set_upd_lrd | set(rev_neighborhoods[j]) + set_upd_lrd = set_upd_lrd | set_new_points + + # Define points that need to update their lof because of new data points + set_upd_lof = set_upd_lrd + for m in set_upd_lrd: + set_upd_lof = set_upd_lof | set(rev_neighborhoods[m]) + set_upd_lof = set_upd_lof + + return set_new_points, set_neighbors, set_rev_neighbors, set_upd_lrd, set_upd_lof + + +def calc_reach_dist_new_points( + set_index: set, + neighborhoods: dict, + rev_neighborhoods: dict, + reach_dist: dict, + dist_dict: dict, + k_dist: dict, +): + """ + Calculate reachability distance from new points to neighbors and from neighbors to new points. + """ + for c in set_index: + for j in set(neighborhoods[c]): + reach_dist[c][j] = max(dist_dict[c][j], k_dist[j]) + for j in set(rev_neighborhoods[c]): + reach_dist[j][c] = max(dist_dict[j][c], k_dist[c]) + return reach_dist + + +def calc_reach_dist_other_points( + set_index: set, + rev_neighborhoods: dict, + reach_dist: dict, + dist_dict: dict, + k_dist: dict, +): + """ + Calculate reachability distance from reverse neighbors of reverse neighbors ( RkNN(RkNN(NewPoints)) ) + to reverse neighbors ( RkNN(NewPoints) ). These values change due to the insertion of new points. + """ + for j in set_index: + for i in set(rev_neighborhoods[j]): + reach_dist[i][j] = max(dist_dict[i][j], k_dist[j]) + return reach_dist + + +def calc_local_reach_dist( + set_index: set, neighborhoods: dict, reach_dist: dict, local_reach_dist: dict +): + """ + Calculate local reachability distance of affected points. + """ + for i in set_index: + local_reach_dist[i] = len(neighborhoods[i]) / sum( + [reach_dist[i][j] for j in neighborhoods[i]] + ) + return local_reach_dist + + +def calc_lof(set_index: set, neighborhoods: dict, local_reach: dict, lof: dict): + """ + Calculate local outlier factor (LOF) of affected points. + """ + for i in set_index: + lof[i] = sum([local_reach[j] for j in neighborhoods[i]]) / ( + len(neighborhoods[i]) * local_reach[i] + ) + return lof + + class LocalOutlierFactor(anomaly.base.AnomalyDetector): """Incremental Local Outlier Factor (Incremental LOF). @@ -43,8 +177,6 @@ class LocalOutlierFactor(anomaly.base.AnomalyDetector): The number of nearest neighbors to use for density estimation. distance_func Distance function to be used. By default, the Euclidean distance is used. - verbose - Whether to print warning/messages Attributes ---------- @@ -69,28 +201,27 @@ class LocalOutlierFactor(anomaly.base.AnomalyDetector): local_reach A dictionary to hold local reachability distances for each observation. - Example - ---------- + Examples + -------- + >>> import pandas as pd >>> from river import anomaly >>> from river import datasets - >>> import pandas as pd >>> cc_df = pd.DataFrame(datasets.CreditCard()) - >>> k = 20 # Define number of nearest neighbors - >>> incremental_lof = anomaly.LocalOutlierFactor(k, verbose=False) + >>> lof = anomaly.LocalOutlierFactor(n_neighbors=20) >>> for x, _ in datasets.CreditCard().take(200): - ... incremental_lof.learn_one(x) + ... lof.learn_one(x) - >>> incremental_lof.learn_many(cc_df[201:401]) + >>> lof.learn_many(cc_df[201:401]) - >>> ilof_scores = [] + >>> scores = [] >>> for x in cc_df[0][401:406]: - ... ilof_scores.append(incremental_lof.score_one(x)) + ... scores.append(lof.score_one(x)) - >>> [round(ilof_score, 3) for ilof_score in ilof_scores] + >>> [round(score, 3) for score in scores] [1.802, 1.937, 1.567, 1.181, 1.28] References @@ -98,12 +229,12 @@ class LocalOutlierFactor(anomaly.base.AnomalyDetector): David Pokrajac, Aleksandar Lazarevic, and Longin Jan Latecki (2007). Incremental Local Outlier Detection for Data Streams. In: Proceedings of the 2007 IEEE Symposium on Computational Intelligence and Data Mining (CIDM 2007). 504-515. DOI: 10.1109/CIDM.2007.368917. + """ def __init__( self, n_neighbors: int = 10, - verbose=True, distance_func: DistanceFunc = None, ): self.n_neighbors = n_neighbors @@ -117,7 +248,7 @@ def __init__( self.reach_dist: dict = {} self.lof: dict = {} self.local_reach: dict = {} - self.verbose = verbose + self.distance_func = distance_func self.distance = ( distance_func if distance_func is not None @@ -125,193 +256,148 @@ def __init__( ) def learn_many(self, x: pd.DataFrame): - """ - Update the model with multiple incoming observations simultaneously. - This function assumes that the observations are stored in the first column of the dataset. - - Parameters - ---------- - x - A Pandas DataFrame including multiple instances to be learned at the same time - """ x = x[0].tolist() self.learn(x) def learn_one(self, x: dict): - """ - Update the model with one incoming observation - - Parameters - ---------- - x - A dictionary of feature values. - """ self.x_batch.append(x) if len(self.x_list) or len(self.x_batch) > 1: self.learn(self.x_batch) self.x_batch = [] def learn(self, x_batch: list): - x_batch, equal = self.check_equal(x_batch, self.x_list) - if equal != 0 and self.verbose: - print("At least one sample is equal to previously observed instances.") - - if len(x_batch) == 0: - if self.verbose: - print("No new data was added.") - else: - # Increase size of objects to accommodate new data - ( - nm, - self.x_list, - self.neighborhoods, - self.rev_neighborhoods, - self.k_dist, - self.reach_dist, - self.dist_dict, - self.local_reach, - self.lof, - ) = self.expand_objects( - x_batch, - self.x_list, - self.neighborhoods, - self.rev_neighborhoods, - self.k_dist, - self.reach_dist, - self.dist_dict, - self.local_reach, - self.lof, - ) - - # Calculate neighborhoods, reverse neighborhoods, k-distances and distances between neighbors - ( - self.neighborhoods, - self.rev_neighborhoods, - self.k_dist, - self.dist_dict, - ) = self.initial_calculations( - self.x_list, - nm, - self.neighborhoods, - self.rev_neighborhoods, - self.k_dist, - self.dist_dict, - ) - - # Define sets of particles - ( - set_new_points, - set_neighbors, - set_rev_neighbors, - set_upd_lrd, - set_upd_lof, - ) = self.define_sets(nm, self.neighborhoods, self.rev_neighborhoods) - - # Calculate new reachability distance of all affected points - self.reach_dist = self.calc_reach_dist_new_points( - set_new_points, - self.neighborhoods, - self.rev_neighborhoods, - self.reach_dist, - self.dist_dict, - self.k_dist, - ) - self.reach_dist = self.calc_reach_dist_other_points( - set_rev_neighbors, - self.rev_neighborhoods, - self.reach_dist, - self.dist_dict, - self.k_dist, - ) - - # Calculate new local reachability distance of all affected points - self.local_reach = self.calc_local_reach_dist( - set_upd_lrd, self.neighborhoods, self.reach_dist, self.local_reach - ) - - # Calculate new Local Outlier Factor of all affected points - self.lof = self.calc_lof(set_upd_lof, self.neighborhoods, self.local_reach, self.lof) + x_batch, equal = check_equal(x_batch, self.x_list) + + # Increase size of objects to accommodate new data + ( + nm, + self.x_list, + self.neighborhoods, + self.rev_neighborhoods, + self.k_dist, + self.reach_dist, + self.dist_dict, + self.local_reach, + self.lof, + ) = expand_objects( + x_batch, + self.x_list, + self.neighborhoods, + self.rev_neighborhoods, + self.k_dist, + self.reach_dist, + self.dist_dict, + self.local_reach, + self.lof, + ) - def score_one(self, x: dict): - """ - Score a new incoming observation based on model constructed previously. - Perform same calculations as 'learn_one' function but doesn't add the new calculations to the attributes - Data samples that are equal to samples stored by the model are not considered. + # Calculate neighborhoods, reverse neighborhoods, k-distances and distances between neighbors + ( + self.neighborhoods, + self.rev_neighborhoods, + self.k_dist, + self.dist_dict, + ) = self._initial_calculations( + self.x_list, + nm, + self.neighborhoods, + self.rev_neighborhoods, + self.k_dist, + self.dist_dict, + ) - Parameters - ---------- - x - A dictionary of feature values. + # Define sets of particles + ( + set_new_points, + set_neighbors, + set_rev_neighbors, + set_upd_lrd, + set_upd_lof, + ) = define_sets(nm, self.neighborhoods, self.rev_neighborhoods) + + # Calculate new reachability distance of all affected points + self.reach_dist = calc_reach_dist_new_points( + set_new_points, + self.neighborhoods, + self.rev_neighborhoods, + self.reach_dist, + self.dist_dict, + self.k_dist, + ) + self.reach_dist = calc_reach_dist_other_points( + set_rev_neighbors, + self.rev_neighborhoods, + self.reach_dist, + self.dist_dict, + self.k_dist, + ) - Returns - ------- - lof : list - List of LOF calculated for incoming data - """ + # Calculate new local reachability distance of all affected points + self.local_reach = calc_local_reach_dist( + set_upd_lrd, self.neighborhoods, self.reach_dist, self.local_reach + ) - self.x_scores.append(x) + # Calculate new Local Outlier Factor of all affected points + self.lof = calc_lof(set_upd_lof, self.neighborhoods, self.local_reach, self.lof) - self.x_scores, equal = self.check_equal(self.x_scores, self.x_list) - if equal != 0 and self.verbose: - print("The new observation is the same to one of the previously observed instances.") + def score_one(self, x: dict): + self.x_scores.append(x) + self.x_scores, equal = check_equal(self.x_scores, self.x_list) if len(self.x_scores) == 0: - if self.verbose: - print("No new data was added.") - else: - x_list_copy = self.x_list.copy() - ( - nm, - x_list_copy, - neighborhoods, - rev_neighborhoods, - k_dist, - reach_dist, - dist_dict, - local_reach, - lof, - ) = self.expand_objects( - self.x_scores, - x_list_copy, - self.neighborhoods, - self.rev_neighborhoods, - self.k_dist, - self.reach_dist, - self.dist_dict, - self.local_reach, - self.lof, - ) - - neighborhoods, rev_neighborhoods, k_dist, dist_dict = self.initial_calculations( - x_list_copy, nm, neighborhoods, rev_neighborhoods, k_dist, dist_dict - ) - ( - set_new_points, - set_neighbors, - set_rev_neighbors, - set_upd_lrd, - set_upd_lof, - ) = self.define_sets(nm, neighborhoods, rev_neighborhoods) - reach_dist = self.calc_reach_dist_new_points( - set_new_points, neighborhoods, rev_neighborhoods, reach_dist, dist_dict, k_dist - ) - reach_dist = self.calc_reach_dist_other_points( - set_rev_neighbors, - rev_neighborhoods, - reach_dist, - dist_dict, - k_dist, - ) - local_reach = self.calc_local_reach_dist( - set_upd_lrd, neighborhoods, reach_dist, local_reach - ) - lof = self.calc_lof(set_upd_lof, neighborhoods, local_reach, lof) - self.x_scores = [] - - # Use nm[0] as index since upon this configuration nm[1] is expected to be 1. - return lof[nm[0]] - - def initial_calculations( + return None + + x_list_copy = self.x_list.copy() + ( + nm, + x_list_copy, + neighborhoods, + rev_neighborhoods, + k_dist, + reach_dist, + dist_dict, + local_reach, + lof, + ) = expand_objects( + self.x_scores, + x_list_copy, + self.neighborhoods, + self.rev_neighborhoods, + self.k_dist, + self.reach_dist, + self.dist_dict, + self.local_reach, + self.lof, + ) + + neighborhoods, rev_neighborhoods, k_dist, dist_dict = self._initial_calculations( + x_list_copy, nm, neighborhoods, rev_neighborhoods, k_dist, dist_dict + ) + ( + set_new_points, + set_neighbors, + set_rev_neighbors, + set_upd_lrd, + set_upd_lof, + ) = define_sets(nm, neighborhoods, rev_neighborhoods) + reach_dist = calc_reach_dist_new_points( + set_new_points, neighborhoods, rev_neighborhoods, reach_dist, dist_dict, k_dist + ) + reach_dist = calc_reach_dist_other_points( + set_rev_neighbors, + rev_neighborhoods, + reach_dist, + dist_dict, + k_dist, + ) + local_reach = calc_local_reach_dist(set_upd_lrd, neighborhoods, reach_dist, local_reach) + lof = calc_lof(set_upd_lof, neighborhoods, local_reach, lof) + self.x_scores = [] + + # Use nm[0] as index since upon this configuration nm[1] is expected to be 1. + return lof[nm[0]] + + def _initial_calculations( self, x_list: list, nm: tuple, @@ -349,6 +435,7 @@ def initial_calculations( Updated dictionary to hold k-distances for each observation dist_dict Updated dictionary of dictionaries storing distances between particles + """ n = nm[0] @@ -387,137 +474,3 @@ def initial_calculations( rev_neighborhoods[neighbor_id].append(particle_id) return neighborhoods, rev_neighborhoods, k_distances, dist_dict - - @staticmethod - def check_equal(x_list: list, y_list: list): - """ - Check if new list of observations (x_list) has any data sample that is equal to any previous data recorded (y_list). - """ - result = [x for x in x_list if not any(x == y for y in y_list)] - return result, len(x_list) - len(result) - - @staticmethod - def expand_objects( - new_particles: list, - x_list: list, - neighborhoods: dict, - rev_neighborhoods: dict, - k_dist: dict, - reach_dist: dict, - dist_dict: dict, - local_reach: dict, - lof: dict, - ): - """ - Expand size of dictionaries and lists to take into account new data points. - """ - n = len(x_list) - m = len(new_particles) - x_list.extend(new_particles) - neighborhoods.update({i: [] for i in range(n + m)}) - rev_neighborhoods.update({i: [] for i in range(n + m)}) - k_dist.update({i: float("inf") for i in range(n + m)}) - reach_dist.update({i + n: {} for i in range(m)}) - dist_dict.update({i + n: {} for i in range(m)}) - local_reach.update({i + n: [] for i in range(m)}) - lof.update({i + n: [] for i in range(m)}) - return ( - (n, m), - x_list, - neighborhoods, - rev_neighborhoods, - k_dist, - reach_dist, - dist_dict, - local_reach, - lof, - ) - - @staticmethod - def define_sets(nm, neighborhoods: dict, rev_neighborhoods: dict): - """ - Define sets of points for the incremental LOF algorithm. - """ - # Define set of new points from batch - set_new_points = set(range(nm[0], nm[0] + nm[1])) - set_neighbors: set = set() - set_rev_neighbors: set = set() - - # Define neighbors and reverse neighbors of new data points - for i in set_new_points: - set_neighbors = set(set_neighbors) | set(neighborhoods[i]) - set_rev_neighbors = set(set_rev_neighbors) | set(rev_neighborhoods[i]) - - # Define points that need to update their local reachability distance because of new data points - set_upd_lrd = set_rev_neighbors - for j in set_rev_neighbors: - set_upd_lrd = set_upd_lrd | set(rev_neighborhoods[j]) - set_upd_lrd = set_upd_lrd | set_new_points - - # Define points that need to update their lof because of new data points - set_upd_lof = set_upd_lrd - for m in set_upd_lrd: - set_upd_lof = set_upd_lof | set(rev_neighborhoods[m]) - set_upd_lof = set_upd_lof - - return set_new_points, set_neighbors, set_rev_neighbors, set_upd_lrd, set_upd_lof - - @staticmethod - def calc_reach_dist_new_points( - set_index: set, - neighborhoods: dict, - rev_neighborhoods: dict, - reach_dist: dict, - dist_dict: dict, - k_dist: dict, - ): - """ - Calculate reachability distance from new points to neighbors and from neighbors to new points. - """ - for c in set_index: - for j in set(neighborhoods[c]): - reach_dist[c][j] = max(dist_dict[c][j], k_dist[j]) - for j in set(rev_neighborhoods[c]): - reach_dist[j][c] = max(dist_dict[j][c], k_dist[c]) - return reach_dist - - @staticmethod - def calc_reach_dist_other_points( - set_index: set, - rev_neighborhoods: dict, - reach_dist: dict, - dist_dict: dict, - k_dist: dict, - ): - """ - Calculate reachability distance from reverse neighbors of reverse neighbors ( RkNN(RkNN(NewPoints)) ) - to reverse neighbors ( RkNN(NewPoints) ). These values change due to the insertion of new points. - """ - for j in set_index: - for i in set(rev_neighborhoods[j]): - reach_dist[i][j] = max(dist_dict[i][j], k_dist[j]) - return reach_dist - - @staticmethod - def calc_local_reach_dist( - set_index: set, neighborhoods: dict, reach_dist: dict, local_reach_dist: dict - ): - """ - Calculate local reachability distance of affected points. - """ - for i in set_index: - local_reach_dist[i] = len(neighborhoods[i]) / sum( - [reach_dist[i][j] for j in neighborhoods[i]] - ) - return local_reach_dist - - @staticmethod - def calc_lof(set_index: set, neighborhoods: dict, local_reach: dict, lof: dict): - """ - Calculate local outlier factor (LOF) of affected points. - """ - for i in set_index: - lof[i] = sum([local_reach[j] for j in neighborhoods[i]]) / ( - len(neighborhoods[i]) * local_reach[i] - ) - return lof diff --git a/river/anomaly/test_ilof.py b/river/anomaly/test_ilof.py index b3855fba79..ac584173c1 100644 --- a/river/anomaly/test_ilof.py +++ b/river/anomaly/test_ilof.py @@ -32,7 +32,7 @@ def test_incremental_lof_scores(): df_train = pd.DataFrame({"observations": x_train_dict, "ground_truth": ground_truth}) x_pred = np.random.uniform(low=-5, high=5, size=(30, 2)) x_pred_dict = [{f"feature_{i + 1}": elem[i] for i in range(2)} for elem in x_pred] - incremental_lof = anomaly.LocalOutlierFactor(n_neighbors=20, verbose=False) + incremental_lof = anomaly.LocalOutlierFactor(n_neighbors=20) for x in df_train["observations"]: incremental_lof.learn_one(x) @@ -62,7 +62,7 @@ def test_batch_lof_scores(): batch_sizes = [20, 50, 100] for batch_size in batch_sizes: - ilof_river_batch = anomaly.LocalOutlierFactor(n_neighbors=20, verbose=False) + ilof_river_batch = anomaly.LocalOutlierFactor(n_neighbors=20) ilof_river_batch.learn_many(cc_df[0:batch_size]) ilof_scores_river_batch = np.array([v for v in ilof_river_batch.lof.values()]) diff --git a/river/test_estimators.py b/river/test_estimators.py index e5c3ded948..18aacf32f2 100644 --- a/river/test_estimators.py +++ b/river/test_estimators.py @@ -53,6 +53,7 @@ def iter_estimators_which_can_be_tested(): ignored = ( River2SKLBase, SKL2RiverBase, + anomaly.LocalOutlierFactor, # needs warm-start to work correctly compose.FuncTransformer, compose.Grouper, compose.Pipeline,