From 8219965be21770caf07443bf8979706f6227901e Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Sat, 6 Aug 2022 13:10:31 +0200 Subject: [PATCH] fix: str and int in arg parsing (#143) --- discoart/config.py | 2 +- tests/test_config.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/discoart/config.py b/discoart/config.py index 3083087..cb37bf8 100644 --- a/discoart/config.py +++ b/discoart/config.py @@ -56,7 +56,7 @@ def load_config( int_keys.add('seed') for k, v in cfg.items(): - if k in int_keys and v is not None and not isinstance(v, int): + if k in int_keys and v is not None and not isinstance(v, (int, str)): cfg[k] = int(v) if k == 'width_height': cfg[k] = [int(vv) for vv in v] diff --git a/tests/test_config.py b/tests/test_config.py index 32c102d..b632c6a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -61,3 +61,42 @@ def test_eval_schedule_string(): ) def test_chec_schedule_str(val, expected): assert _is_valid_schedule_str(val) == expected + + +@pytest.mark.parametrize( + 'field', + [ + 'cut_overview', + 'cut_innercut', + 'cut_icgray_p', + 'cut_ic_pow', + 'use_secondary_model', + 'cutn_batches', + 'clip_guidance_scale', + 'tv_scale', + 'range_scale', + 'sat_scale', + 'init_scale', + 'clamp_grad', + 'clamp_max', + ], +) +@pytest.mark.parametrize( + 'val', + [ + True, + False, + 1, + 0.5, + 'True', + 'False', + '1', + '0.5', + '[100]*600+[200]*400', + '[True, False]*1000', + ], +) +def test_eval_config(field, val): + cfg = load_config(default_args) + cfg[field] = val + assert load_config(cfg)[field] is not None