Skip to content

Commit

Permalink
Fix open-source tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698002092
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 19, 2024
1 parent e40aa7a commit 307dbb4
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 31 deletions.
21 changes: 20 additions & 1 deletion .github/workflows/pytest_and_autopublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,31 @@ jobs:
# cache-dependency-path: '**/pyproject.toml'

- run: pip --version

# TODO(epot): Delete once all projects have been released
- run: pip install git+https://github.com/google/etils
- run: pip install git+https://github.com/google/flax

- run: pip install -e .[dev]
- run: pip freeze

# Run tests (in parallel)
# Filter out:
# * Projects: Not part of core Kauldron (could be tested separately)
# * TF Data pipeline (not supported due to TFGrain not open-sourced)
# * XManager tests (not yet supported)
# * sweep_utils_test: Depends on kxm
# * lpips_test: Missing VGG weights
# * partial_loader_test: Orbax partial checkpoint loader not yet open-sourced (TODO(epot): Restore)
- name: Run core tests
run: pytest -vv -n auto
run: |
pytest -vv -n auto \
--ignore=projects/ \
--ignore=kauldron/data/tf/ \
--ignore=kauldron/xm/ \
--ignore=kauldron/metrics/lpips_test.py \
--ignore=kauldron/checkpoints/partial_loader_test.py \
--ignore=kauldron/utils/sweep_utils_test.py
# Auto-publish when version is increased
publish-job:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ poetry.lock
.pytype/

# Other
.hypothesis/
*.DS_Store

# PyCharm
Expand Down
13 changes: 2 additions & 11 deletions kauldron/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@
from kauldron.data.in_memory import InMemoryPipeline

# PyGrain based data pipeline.
# TODO(epot): Somehow importing here create infinite recursion when the
# import is resolved, likely because there's some special handling of the
# suffix `py` to support `third_party.py`. I don't have time to investigate
# so instead the module is imported below in `lazy_imports` rather than
# `lazy_api_imports`.
# from kauldron.data import py

# TODO(epot): Migrate all existing symbols to `kd.data.tf.`
from kauldron.data import py

# tf.data based data pipeline.
from kauldron.data import tf

Expand All @@ -57,6 +51,3 @@
from kauldron.data.transforms.map_transforms import Gather
from kauldron.data.transforms.map_transforms import Rearrange
from kauldron.data.transforms.map_transforms import ValueRange

with _epy.lazy_imports():
from kauldron.data import py
2 changes: 1 addition & 1 deletion kauldron/data/py/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _root_ds(self) -> grain.IterDataset:
num_workers,
enable_profiling=self.enable_profiling,
)
ds = ds.mp_prefetch(multiprocessing_options)
ds = ds.prefetch(multiprocessing_options)
return ds

def __iter__(self) -> iterators.Iterator:
Expand Down
2 changes: 1 addition & 1 deletion kauldron/data/transforms/map_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def graph_mode():

@enp.testing.parametrize_xnp(restrict=["np", "tnp"])
def test_value_range(xnp: enp.NpModule):
vr = kd.data.tf.ValueRange(
vr = kd.data.ValueRange(
key="values",
in_vrange=(0.0, 255.0),
vrange=(0.0, 1.0),
Expand Down
2 changes: 0 additions & 2 deletions kauldron/evals/eval_impl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def test_eval_impl(tmp_path: epath.Path):
# Load config and reduce size
cfg = mnist_autoencoder.get_config()

# TODO(klausg): remove this once data mocking works correctly with grain
cfg.train_ds.__qualname__ = 'kauldron.kd:data.Tfds'
cfg.train_ds.batch_size = 1
cfg.evals.eval.ds.batch_size = 1 # pytype: disable=attribute-error
cfg.model.encoder.features = 3
Expand Down
2 changes: 1 addition & 1 deletion kauldron/modules/adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_external():
)

inputs = jnp.ones((5, 5))
input_kwargs = kd.kontext.get_from_keys_obj(
input_kwargs = kd.kontext.resolve_from_keyed_obj(
{'a': inputs, 'b': jnp.zeros(())}, model
)
out_train = model.apply(
Expand Down
8 changes: 4 additions & 4 deletions kauldron/utils/config_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test."""

import dataclasses

from kauldron import kd
Expand All @@ -22,8 +20,10 @@

def test_trainer_replace():
trainer = kd.train.Trainer(
eval_ds=kd.data.Tfds(name='mnist', split='train'),
train_ds=kd.data.Tfds(name='mnist', split='train', seed=60), # pytype: disable=wrong-keyword-args
eval_ds=kd.data.py.Tfds(name='mnist', split='train', shuffle=False),
train_ds=kd.data.py.Tfds(
name='mnist', split='train', shuffle=True, seed=60
),
init_transforms={
'base': kd.ckpts.PartialKauldronLoader(workdir='/some/workdir')
},
Expand Down
3 changes: 0 additions & 3 deletions kauldron/utils/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ def test_import_kauldron():

end = time.time()

# Enforce lazy deps
assert "sklearn" not in sys.modules

# TODO(epot): Reduce this value
assert end - start < 40, "Kauldron import took too long."

Expand Down
2 changes: 0 additions & 2 deletions kauldron/utils/sharding_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def test_sharding(tmp_path: epath.Path):
# Load config and reduce size
cfg = mnist_autoencoder.get_config()

# TODO(klausg): remove this once data mocking works correctly with grain
cfg.train_ds.__qualname__ = 'kauldron.kd:data.Tfds'
cfg.train_ds.batch_size = 1
cfg.model.encoder.features = 3
cfg.workdir = os.fspath(tmp_path)
Expand Down
2 changes: 1 addition & 1 deletion kauldron/xm/_src/job_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class JobParams:
platform: None | str = None
cell: None | str = None

citc_source: None | str | tuple[str, ...]
citc_source: None | g3_utils.Source = None
use_interpreter: bool = False
interpreter_info: InterpreterInfo = dataclasses.field(
default_factory=InterpreterInfo
Expand Down
4 changes: 0 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@

import contextlib

import sys
from unittest import mock
# TODO(klausg): remove once we have a better solution
sys.modules['etils.exm'] = mock.MagicMock()
from absl import app
from absl import flags
from etils import epy
Expand Down

0 comments on commit 307dbb4

Please sign in to comment.