Skip to content

Commit

Permalink
#sdy fix bug due to tensor dialect being introduced
Browse files Browse the repository at this point in the history
When investigating a bug, I discovered this fails in JAX:
```py
NS = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec

mesh = jax.sharding.Mesh(
        np.reshape(np.array(jax.devices()), (4,2)), ('data', 'model'))

in_avals = (jax.ShapeDtypeStruct((4, 8), jnp.float32),)
shardings = (NS(mesh, P('data',)),)
@partial(jax.jit, out_shardings=shardings)
def gen_dummy_inputs():
  return tuple(
      jax.random.normal(
          jax.random.key(42), shape=in_aval.shape
      ).astype(in_aval.dtype)
      for in_aval in in_avals
  )
gen_dummy_inputs()
```

with the error

```
LLVM ERROR: Building op `tensor.cast` but it isn't known in this MLIRContext: the dialect may not be loaded or this operation hasn't been added by the dialect. See also https://mlir.llvm.org/getting_started/Faq/#registered-loaded-dependent-whats-up-with-dialects-management
```

This was because the sdy-round-trip-import introduces the tensor dialect. I'm unsure which pass adds it, but overall what I see is it is actually undone. The details shouldn't matter as long as the pass doesn't crash and the dialect doesn't show up during propagation.

PiperOrigin-RevId: 714046691
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Jan 10, 2025
1 parent 9b7c4f6 commit 29d3ca1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
4 changes: 2 additions & 2 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ cc_library(
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@shardy//shardy/dialect/sdy/ir:register",
"@stablehlo//:stablehlo_ops",
Expand Down Expand Up @@ -143,12 +144,11 @@ xla_cc_binary(
"//xla/service/spmd/shardy/sdy_round_trip/test_utils:mhlo_to_hlo_to_mhlo",
"//xla/service/spmd/shardy/sdy_round_trip/test_utils:testing_pipeline",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@shardy//shardy/dialect/sdy/ir:register",
"@shardy//shardy/dialect/sdy/transforms:passes",
"@stablehlo//:stablehlo_ops",
],
)
7 changes: 3 additions & 4 deletions xla/service/spmd/shardy/sdy_opt_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ limitations under the License.
==============================================================================*/

#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/register.h"
#include "shardy/dialect/sdy/transforms/passes.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_callback_custom_calls.h"
Expand Down Expand Up @@ -51,8 +50,8 @@ int main(int argc, char** argv) {
mlir::mhlo::registerAllMhloPasses();

mlir::DialectRegistry dialects;
dialects.insert<mlir::func::FuncDialect, mlir::mhlo::MhloDialect,
mlir::sdy::SdyDialect, mlir::stablehlo::StablehloDialect>();
mlir::sdy::registerAllDialects(dialects);
dialects.insert<mlir::mhlo::MhloDialect>();
mlir::func::registerAllExtensions(dialects);

// Register all SDY passes and pipelines.
Expand Down

0 comments on commit 29d3ca1

Please sign in to comment.