Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#sdy fix bug due to tensor dialect being introduced
When investigating a bug, I discovered this fails in JAX: ```py NS = jax.sharding.NamedSharding P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh( np.reshape(np.array(jax.devices()), (4,2)), ('data', 'model')) in_avals = (jax.ShapeDtypeStruct((4, 8), jnp.float32),) shardings = (NS(mesh, P('data',)),) @partial(jax.jit, out_shardings=shardings) def gen_dummy_inputs(): return tuple( jax.random.normal( jax.random.key(42), shape=in_aval.shape ).astype(in_aval.dtype) for in_aval in in_avals ) gen_dummy_inputs() ``` with the error ``` LLVM ERROR: Building op `tensor.cast` but it isn't known in this MLIRContext: the dialect may not be loaded or this operation hasn't been added by the dialect. See also https://mlir.llvm.org/getting_started/Faq/#registered-loaded-dependent-whats-up-with-dialects-management ``` This was because the sdy-round-trip-import introduces the tensor dialect. I'm unsure which pass adds it, but overall what I see is it is actually undone. The details shouldn't matter as long as the pass doesn't crash and the dialect doesn't show up during propagation. PiperOrigin-RevId: 714046691
- Loading branch information