From 3094ea52326007ad663c6ec34acbf46290f17431 Mon Sep 17 00:00:00 2001 From: "c.winger" Date: Mon, 17 Jun 2024 16:03:35 +0200 Subject: [PATCH] stable hashing, batching for offline opt --- offline_optimization_script.py | 21 ++++++++++++++- ptxboa/api_optimize.py | 4 +-- ptxboa/utils.py | 49 +++++++++++++++++++++++++++++++++- tests/test_utills.py | 24 +++++++++++++++++ 4 files changed, 94 insertions(+), 4 deletions(-) diff --git a/offline_optimization_script.py b/offline_optimization_script.py index f5433f04..11c14d89 100644 --- a/offline_optimization_script.py +++ b/offline_optimization_script.py @@ -98,6 +98,8 @@ def main( cache_dir: Path = DEFAULT_CACHE_DIR, out_dir=None, loglevel: Literal["debug", "info", "warning", "error"] = "info", + index_from: int = None, + index_to: int = None, ): cache_dir = Path(cache_dir) cache_dir.mkdir(exist_ok=True) @@ -121,6 +123,11 @@ def main( param_sets = generate_param_sets(api) + # filter for batch + index_from = index_from or 0 + index_to = index_to or len(param_sets) + param_sets = param_sets[index_from:index_to] + results = [] # save results for params in progress.bar.Bar( suffix=( @@ -186,6 +193,18 @@ def main( choices=["debug", "info", "warning", "error"], help="Log level for the console.", ) + parser.add_argument( + "-f", + "--index_from", + type=int, + help="starting index for prallel runs", + ) + parser.add_argument( + "-t", + "--index_to", + type=int, + help="final index (exlusive) for prallel runs", + ) args = parser.parse_args() - main(cache_dir=args.cache_dir, out_dir=args.out_dir, loglevel=args.loglevel) + main(**vars(args)) diff --git a/ptxboa/api_optimize.py b/ptxboa/api_optimize.py index 0802e608..db1ab5c3 100644 --- a/ptxboa/api_optimize.py +++ b/ptxboa/api_optimize.py @@ -19,7 +19,7 @@ from flh_opt.api_opt import optimize from ptxboa import logger from ptxboa.static._types import CalculateDataType -from ptxboa.utils import SingletonMeta, annuity +from ptxboa.utils import SingletonMeta, annuity, 
serialize_for_hashing def get_data_hash_md5(key: object) -> str: @@ -36,7 +36,7 @@ def get_data_hash_md5(key: object) -> str: md5 hash of a standardized byte representation of the input data """ # serialize to str, make sure to sort keys - sdata = json.dumps(key, sort_keys=True, ensure_ascii=False, indent=0) + sdata = serialize_for_hashing(key) # to bytes (only bytes can be hashed) bdata = sdata.encode() # create hash diff --git a/ptxboa/utils.py b/ptxboa/utils.py index c9975137..3c80a182 100644 --- a/ptxboa/utils.py +++ b/ptxboa/utils.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- """Utilities.""" - +import json import os +from types import NoneType +from typing import Union def annuity(rate: float, periods: int, value: float) -> float: @@ -47,3 +49,48 @@ def is_test(): "PYTEST_CURRENT_TEST" in os.environ or "STREAMLIT_GLOBAL_UNIT_TEST" in os.environ ) + + +def serialize_for_hashing( + obj: Union[NoneType, int, float, str, bool, dict, list], float_sig_digits=6 +) -> str: + """Serialize data for hashing. + + - custom function to ensure same results for different python versions + (json dumps changes sometimes?) 
+ - + + Parameters + ---------- + obj : Union[NoneType, int, float, str, bool, dict, list] + data + float_sig_digits : int, optional + number of significant digits (in scientific notation) + + Returns + ------- + str + string serialization + """ + if isinstance(obj, list): + return "[" + ",".join(serialize_for_hashing(x) for x in obj) + "]" + elif isinstance(obj, dict): + # map keys to sorted + obj_ = { + serialize_for_hashing(k): serialize_for_hashing(v) for k, v in obj.items() + } + return "{" + ",".join(k + ":" + v for k, v in sorted(obj_.items())) + "}" + elif isinstance(obj, bool): + # NOTE: MUST come before test for int (bool is a subclass of int) + return "true" if obj is True else "false" + elif isinstance(obj, str): + # use json to take care of line breaks and other escaping + return json.dumps(obj, ensure_ascii=False) + elif isinstance(obj, int): + return str(obj) + elif isinstance(obj, float): + return f"%.{float_sig_digits}e" % obj + elif obj is None: + return "null" + else: + raise NotImplementedError(type(obj)) diff --git a/tests/test_utills.py b/tests/test_utills.py index 08885f54..49a755b1 100644 --- a/tests/test_utills.py +++ b/tests/test_utills.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- """Tests for utils module.""" +import json import unittest +from ptxboa.utils import serialize_for_hashing + from .utils import assert_deep_equal @@ -25,3 +28,24 @@ def test_assert_deep_equal(self): self.assertRaises(ValueError, assert_deep_equal, [], [1]) self.assertRaises(ValueError, assert_deep_equal, {"a": 1}, {"b": 1}) self.assertRaises(ValueError, assert_deep_equal, {"a": 1}, {"a": 2}) + + def test_serialize_for_hashing(self): + """Test for ptxboa.utils.serialize_for_hashing.""" + for obj, exp_str in [ + ("text", '"text"'), + (123, "123"), + (123.0, "1.230000e+02"), + (-123.0, "-1.230000e+02"), + (-123.4567, "-1.234567e+02"), + (0.0000001234567, "1.234567e-07"), + (True, "true"), + (False, "false"), + ([], "[]"), + ({}, "{}"), + ([1, {"b": 2, "a": [None]}], '[1,{"a":[null],"b":2}]'), + ]: + + res 
= serialize_for_hashing(obj) + # must be json loadable + json.loads(res) + self.assertEqual(res, exp_str)