Compilation error #1077

Open
benoitsteiner opened this issue Jan 3, 2025 · 1 comment

Comments

@benoitsteiner

I'm trying to compile the following code:

def compute_signature(self) -> jax.Array:
    flat = self.payload.ravel().view(jnp.uint32)
    seed = flat[0]
    max_sz = min(10, len(flat))
    hash = jax.lax.fori_loop(1, max_sz, lambda i, h: h + flat[i], seed)
    return hash

but I'm getting the following error:

025-01-02 19:08:31.000150: 15030 ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/argocd/neuroncc_compile_workdir/a54c3791-722b-4e7e-a4ae-b362677db784/model.MODULE_5177726690754820286+2ba4dce1.hlo_module.pb', '--output', '/tmp/argocd/neuroncc_compile_workdir/a54c3791-722b-4e7e-a4ae-b362677db784/model.MODULE_5177726690754820286+2ba4dce1.neff', '-O1', '--internal-hlo2tensorizer-options= --modular-flow-mac-threshold-for-default=1000000 --modular-flow-mac-threshold=1000000 ', '--disable-internal-io-dge', '--model-type=transformer', '--tensorizer-options=--disable-dma-cast --skip-pass=PartialLoopFusion --skip-pass=SimplifyNeuronTensor --enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2', '--verbose=35', '--layer-unroll-factor=5', '--pipeline', 'compile', 'SaveTemps']: 2025-01-03T03:08:31Z [NLA001] Unhandled exception with message: === BIR verification failed ===
Reason: Output dtype must match input
Instruction: I-95
Opcode: CollectiveCompute
Instruction Source: ({'no_delinear': '0'}non_local uint32 [163840] %'bitcast_convert.1':95)0:
Output index: 0
Argument AP:
Access Pattern: [[163840,1],[163840,1],[1,163840]]
SymbolicAP
Memory Location: {bitcast_convert.1_set}@dram#Internal DebugInfo: <bitcast_convert.1||NHC||[163840, 1, 1]>

self.payload is a JAX array with dtype bfloat16 and shape [1, 128, 2560], sharded across several devices.
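
For anyone matching the error message to the code: the uint32 [163840] operand on bitcast_convert.1 has the size you would expect from ravel() followed by view(jnp.uint32) on this payload. The sketch below only redoes that element-count arithmetic in plain Python. The helper name viewed_length and the byte-width constants are illustrative and are not part of JAX or the Neuron compiler.

```python
import math

# Shape and dtype widths taken from the report above; names are illustrative.
payload_shape = (1, 128, 2560)
BF16_BYTES = 2    # bfloat16 is 16 bits wide
UINT32_BYTES = 4  # uint32 is 32 bits wide


def viewed_length(shape, src_bytes, dst_bytes):
    """Element count after ravel() followed by a view to a dtype of another width."""
    total_bytes = math.prod(shape) * src_bytes
    if total_bytes % dst_bytes:
        raise ValueError("byte size is not a multiple of the target dtype width")
    return total_bytes // dst_bytes


n_uint32 = viewed_length(payload_shape, BF16_BYTES, UINT32_BYTES)
print(n_uint32)
```

This yields 163840, the same length as the operand in the failing CollectiveCompute instruction, so the failure appears to be on the full reinterpreted array rather than on the ten elements the loop reads.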

@jluntamazon
Contributor

I was able to reproduce the problem with this script:

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding, Mesh


@jax.jit
def func(payload) -> jax.Array:
    flat = payload.ravel().view(jnp.uint32)
    seed = flat[0]
    max_sz = min(10, len(flat))
    hash = jax.lax.fori_loop(1, max_sz, lambda i, h: h + flat[i], seed)
    return hash


def main():
    # Create sharded data
    devices = jax.devices()
    mesh = Mesh(devices, ('tp',))
    sharding = NamedSharding(mesh, P(None, 'tp', None))
    data = jnp.ones((1, 128, 2560), device=sharding, dtype=jnp.bfloat16)

    # Execute the function
    func(data)


if __name__ == '__main__':
    main()

We will look into the problem and update here once we have a root cause or resolution.
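
To check a future fix against a known value, here is a rough pure-Python reference for what the reproduction script computes on an all-ones input. It is a sketch under stated assumptions, not output from any backend: it assumes the view packs two consecutive bfloat16 values into one uint32 with the lower-index element in the low 16 bits, and that uint32 addition wraps modulo 2**32. For all-ones data both halves are identical, so the packing order does not affect the result. The names bf16_bits and reference_signature are made up for this example.

```python
import struct

UINT32_MASK = 0xFFFFFFFF


def bf16_bits(x: float) -> int:
    """Upper 16 bits of the float32 encoding (exact for bfloat16-representable values)."""
    (f32_bits,) = struct.unpack("<I", struct.pack("<f", x))
    return f32_bits >> 16


def reference_signature(values, limit=10):
    """Pack bfloat16 pairs into uint32 words, then sum the first `limit` words."""
    halves = [bf16_bits(v) for v in values]
    # Assumed packing: the lower-index element fills the low 16 bits.
    words = [halves[i] | (halves[i + 1] << 16) for i in range(0, len(halves) - 1, 2)]
    max_sz = min(limit, len(words))
    h = words[0]  # seed = flat[0]
    for i in range(1, max_sz):  # same bounds as fori_loop(1, max_sz, ...)
        h = (h + words[i]) & UINT32_MASK  # assumed uint32 wraparound
    return h


# Only the first 20 bfloat16 values (10 uint32 words) are ever read.
expected = reference_signature([1.0] * 20)
print(hex(expected))
```

Running the same func on a CPU backend and comparing against this value would be one way to confirm the hashing logic itself is fine and that the failure is specific to compilation for the sharded input.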
