Skip to content

Commit

Permalink
Simplify sweep flags
Browse files Browse the repository at this point in the history
Sweeps are now passed directly as `--cfg.model="{'__qualname__': 'kd.nn:Transformer', 'num_layers': 16}"`, so an additional special flag is no longer needed.

PiperOrigin-RevId: 697997893
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 19, 2024
1 parent 1f99873 commit e40aa7a
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 115 deletions.
16 changes: 8 additions & 8 deletions kauldron/utils/colab.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
"""Colab utils."""

import enum
import json
import types

from etils import epy
from kauldron import konfig
from kauldron.utils import sweep_utils
from kauldron import kontext


with epy.lazy_imports():
Expand Down Expand Up @@ -70,15 +69,16 @@ def iter_sweep_configs(
cfg = module.get_config(config_args)
else:
cfg = module.get_config()
# TODO(epot): Display the sweep short name (workdir) and config.
sweep_json = sweep_item.job_kwargs[kauldron_utils.SWEEP_FLAG_NAME]
cfg = sweep_utils.update_with_sweep(
config=cfg,
sweep_kwargs=sweep_json,

sweep_kwargs = kauldron_utils.deserialize_job_kwargs(
sweep_item.job_kwargs
)
# TODO(epot): Display the sweep short name (workdir) and config.
for k, v in sweep_kwargs.items():
kontext.set_by_path(cfg, k, v)

# Only for visualization.
sweep_cfg_overwrites = konfig.ConfigDict(json.loads(sweep_json))
sweep_cfg_overwrites = konfig.ConfigDict(sweep_kwargs)
print(f'Work-unit {i+1}:', flush=True)
ecolab.disp(sweep_cfg_overwrites)

Expand Down
77 changes: 0 additions & 77 deletions kauldron/utils/sweep_utils.py

This file was deleted.

54 changes: 38 additions & 16 deletions kauldron/utils/sweep_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test."""

import json
import contextlib
import shlex
import sys

from absl import flags
from kauldron import kd
from examples import mnist_autoencoder
from kauldron.utils import sweep_utils
from kauldron.xm._src import kauldron_utils
from kauldron.xm._src import sweep_cfg_utils

Expand Down Expand Up @@ -53,27 +53,49 @@ def test_sweep():
assert len(all_sweep_info) == 4 # Cross product

sweep0 = kauldron_utils._encode_sweep_item(all_sweep_info[0])
assert json.loads(sweep0.job_kwargs[sweep_utils._FLAG_NAME]) == {
sweep0 = kauldron_utils.deserialize_job_kwargs(sweep0.job_kwargs)
assert sweep0 == {
'eval_ds.batch_size': 16,
'train_ds.batch_size': 16,
'model': {'__qualname__': 'flax.linen:Dense', '0': 12},
}


def test_sweep_overwrite():
assert sweep_utils._FLAG_NAME == kauldron_utils.SWEEP_FLAG_NAME

cfg = mnist_autoencoder.get_config()
cfg = sweep_utils.update_with_sweep( # pytype: disable=wrong-arg-types
config=cfg,
sweep_kwargs=json.dumps({
'seed': 12,
'train_ds.name': 'imagenet',
'train_ds.transforms[0].keep[0]': 'other_image',
'model': {'__qualname__': 'flax.linen:Dense', '0': 12},
}),
argv = shlex.split(
# fmt: off
'my_app'
f' --cfg={mnist_autoencoder.__file__}'
' --cfg.seed=12'
' --cfg.train_ds.name=imagenet'
' --cfg.train_ds.transforms[0].keep[0]=other_image'
' --cfg.model="{\\"__qualname__\\": \\"flax.linen:Dense\\", \\"0\\": 12}"'
# fmt: on
)

flag_values = flags.FlagValues()
with _replace_sys_argv(argv):
sweep_flag = kd.konfig.DEFINE_config_file(
'cfg',
mnist_autoencoder.__file__,
'Config file to use for the sweep.',
flag_values=flag_values,
)
flag_values(argv)

cfg = sweep_flag.value
assert cfg.seed == 12
assert cfg.train_ds.transforms[0].keep == ['other_image']
assert cfg.train_ds.name == 'imagenet'
assert cfg.model == nn.Dense(12)
assert isinstance(cfg.model, kd.konfig.ConfigDict)


@contextlib.contextmanager
def _replace_sys_argv(argv):
old_argv = sys.argv
sys.argv = argv
try:
yield
finally:
sys.argv = old_argv
14 changes: 10 additions & 4 deletions kauldron/xm/_src/kauldron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@
if typing.TYPE_CHECKING:
from kauldron import kd # pylint: disable=g-bad-import-order # pytype: disable=import-error

# Flag to send the json-serialized sweep overwrites kwargs
# If modifying this, also modify the value in `kauldron/utils/sweep_utils.py`
SWEEP_FLAG_NAME = "sweep_config"
_Json = epy.typing.Json

# TODO(epot): Support sweep on platform,...

Expand Down Expand Up @@ -269,11 +267,19 @@ def _encode_sweep_item(
return dataclasses.replace(
sweep_item,
# Use custom encoder to support ConfigDict objects
job_kwargs={SWEEP_FLAG_NAME: _JsonEncoder().encode(job_kwargs)},
job_kwargs=_serialize_job_kwargs(job_kwargs),
xm_ui_kwargs={k: _ui_repr(v) for k, v in job_kwargs.items()},
)


def _serialize_job_kwargs(job_kwargs: dict[str, _Json]) -> dict[str, _Json]:
  """Encodes every value with `_JsonEncoder` under a `cfg.`-prefixed key.

  Args:
    job_kwargs: Mapping from key to the value to encode.

  Returns:
    A new dict mapping `cfg.<key>` to the output of `_JsonEncoder().encode`
    for the corresponding value.
  """
  serialized = {}
  for name, value in job_kwargs.items():
    flag_name = f"cfg.{name}"
    serialized[flag_name] = _JsonEncoder().encode(value)
  return serialized


def deserialize_job_kwargs(job_kwargs: dict[str, _Json]) -> dict[str, _Json]:
  """Decodes JSON values and strips a leading `cfg.` from each key.

  Args:
    job_kwargs: Mapping whose values are JSON documents accepted by
      `json.loads`.

  Returns:
    A new dict where each key has one leading `cfg.` removed (keys without
    that prefix are kept unchanged) and each value is the `json.loads` result.
  """
  decoded = {}
  for flag_name, raw_value in job_kwargs.items():
    path = flag_name.removeprefix("cfg.")
    decoded[path] = json.loads(raw_value)
  return decoded


def _ui_repr(v):
"""Parameters displayed on the UI."""
# TODO(epot): In theory, could list exhaustively all accepted types
Expand Down
11 changes: 1 addition & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
"Training configuration.",
lock_config=False,
)
_SWEEP_CONFIG = sweep_utils.define_sweep_flag()
_EVAL_NAMES = flags.DEFINE_list(
"eval_names",
None,
Expand All @@ -58,11 +57,7 @@ def main(_):

with _wu_error_handling(_POST_MORTEM.value):
eval_names = _EVAL_NAMES.value
cfg = sweep_utils.update_with_sweep(
config=_CONFIG.value,
sweep_kwargs=_SWEEP_CONFIG.value,
)
# TODO(b/374268398): add back _update_xm_configuration(cfg) once ACtx good
cfg = _CONFIG.value
if eval_names is None:
trainer: kd.train.Trainer = kd.konfig.resolve(cfg)
trainer.train()
Expand All @@ -82,10 +77,6 @@ def _wu_error_handling(post_mortem: bool = False):
yield # not yet supported externally


def _update_xm_configuration(cfg: kd.konfig.ConfigDict) -> None:
  """No-op placeholder: not supported externally.

  Args:
    cfg: Ignored.
  """
  return None


def _flags_parser(args: list[str]) -> None:
"""Flag parser."""
# Import everything, except kxm (XManager not included in the trainer binary)
Expand Down

0 comments on commit e40aa7a

Please sign in to comment.