From 7ea646db7fd55c2e016345dc4da0aea76bd82a30 Mon Sep 17 00:00:00 2001 From: Milad LEYLI ABADI Date: Tue, 21 May 2024 16:51:11 +0200 Subject: [PATCH] Physics metrics updated for power grid use case to return 0 when no violation (#40) --- lips/metrics/power_grid/physics_compliances.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lips/metrics/power_grid/physics_compliances.py b/lips/metrics/power_grid/physics_compliances.py index 3c089cd..02eb427 100644 --- a/lips/metrics/power_grid/physics_compliances.py +++ b/lips/metrics/power_grid/physics_compliances.py @@ -55,8 +55,8 @@ def verify_current_pos(predictions: dict, except KeyError: logger.error("%s does not exists in predictions dict", key_) raise + verifications[key_] = {} if np.any(a_arr < 0): - verifications[key_] = {} a_or_errors = np.array(np.where(a_arr < 0)).T a_or_violation_proportion = (1.0 * len(a_or_errors)) / a_arr.size error_a_or = -np.sum(np.minimum(a_arr.flatten(), 0.)) @@ -67,6 +67,7 @@ def verify_current_pos(predictions: dict, verifications[key_]["Violation_proportion"] = float(a_or_violation_proportion) else: logger.info("Current positivity check passed for %s", key_) + verifications[key_]["Violation_proportion"] = 0. return verifications def verify_voltage_pos(predictions:dict, @@ -101,8 +102,8 @@ def verify_voltage_pos(predictions:dict, except KeyError: logger.error("%s does not exists in predictions dict", key_) raise + verifications[key_] = {} if np.any(v_arr < 0): - verifications[key_] = {} v_or_errors = np.array(np.where(v_arr < 0)).T v_or_violation_proportion = len(v_or_errors) / v_arr.size error_v_or = -np.sum(np.minimum(v_arr.flatten(), 0.)) @@ -113,6 +114,7 @@ def verify_voltage_pos(predictions:dict, verifications[key_]["Violation_proportion"] = float(v_or_violation_proportion) else: logger.info("Voltage positivity check passed for %s", key_) + verifications[key_]["Violation_proportion"] = 0. return verifications def verify_loss_pos(predictions: dict, @@ -159,6 +161,7 @@ def verify_loss_pos(predictions: dict, verifications["violation_proportion"] = float(loss_violation_proportion) else: logger.info("Loss positivity check passed") + verifications["violation_proportion"] = 0. return verifications def verify_disc_lines(predictions: dict,