diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 0c8bc5c..2e065ef 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -42,7 +42,7 @@ jobs: # Check formatting - name: Check pyink formatting - run: uv run pyink penzai --check + run: uv run pyink penzai tests --check - name: Run pylint run: uv run pylint penzai diff --git a/tests/deprecated/v1/toolshed/model_rewiring_test.py b/tests/deprecated/v1/toolshed/model_rewiring_test.py index 94593ec..2707333 100644 --- a/tests/deprecated/v1/toolshed/model_rewiring_test.py +++ b/tests/deprecated/v1/toolshed/model_rewiring_test.py @@ -141,23 +141,26 @@ def target(stuff): result["b"].canonicalize(), ( pz.nx.wrap( - jnp.array([ - [ - 1**3 + 3 * 1**2 * (4 - 1), - 1**3 + 3 * 1**2 * (5 - 1), - 1**3 + 3 * 1**2 * (6 - 1), - ], - [ - 2**3 + 3 * 2**2 * (4 - 2), - 2**3 + 3 * 2**2 * (5 - 2), - 2**3 + 3 * 2**2 * (6 - 2), - ], + jnp.array( [ - 3**3 + 3 * 3**2 * (4 - 3), - 3**3 + 3 * 3**2 * (5 - 3), - 3**3 + 3 * 3**2 * (6 - 3), + [ + 1**3 + 3 * 1**2 * (4 - 1), + 1**3 + 3 * 1**2 * (5 - 1), + 1**3 + 3 * 1**2 * (6 - 1), + ], + [ + 2**3 + 3 * 2**2 * (4 - 2), + 2**3 + 3 * 2**2 * (5 - 2), + 2**3 + 3 * 2**2 * (6 - 2), + ], + [ + 3**3 + 3 * 3**2 * (4 - 3), + 3**3 + 3 * 3**2 * (5 - 3), + 3**3 + 3 * 3**2 * (6 - 3), + ], ], - ], jnp.float32) + jnp.float32, + ) ) .tag("foo", "bar") .canonicalize() diff --git a/tests/toolshed/model_rewiring_test.py b/tests/toolshed/model_rewiring_test.py index 35a0346..45ef9cd 100644 --- a/tests/toolshed/model_rewiring_test.py +++ b/tests/toolshed/model_rewiring_test.py @@ -141,23 +141,26 @@ def target(stuff): result["b"].canonicalize(), ( pz.nx.wrap( - jnp.array([ - [ - 1**3 + 3 * 1**2 * (4 - 1), - 1**3 + 3 * 1**2 * (5 - 1), - 1**3 + 3 * 1**2 * (6 - 1), - ], - [ - 2**3 + 3 * 2**2 * (4 - 2), - 2**3 + 3 * 2**2 * (5 - 2), - 2**3 + 3 * 2**2 * (6 - 2), - ], + jnp.array( [ - 3**3 + 3 * 3**2 * (4 - 3), - 3**3 + 3 * 3**2 * (5 - 3), - 3**3 + 3 * 3**2 * (6 - 3), + [ + 1**3 + 3 * 1**2 * (4 - 1), + 1**3 + 3 * 1**2 * (5 - 1), + 1**3 + 3 * 1**2 * (6 - 1), + ], + [ + 2**3 + 3 * 2**2 * (4 - 2), + 2**3 + 3 * 2**2 * (5 - 2), + 2**3 + 3 * 2**2 * (6 - 2), + ], + [ + 3**3 + 3 * 3**2 * (4 - 3), + 3**3 + 3 * 3**2 * (5 - 3), + 3**3 + 3 * 3**2 * (6 - 3), + ], ], - ], jnp.float32) + jnp.float32, + ) ) .tag("foo", "bar") .canonicalize()