I'm trying to compile the following code:
```python
def compute_signature(self) -> jax.Array:
    flat = self.payload.ravel().view(jnp.uint32)
    seed = flat[0]
    max_sz = min(10, len(flat))
    hash = jax.lax.fori_loop(1, max_sz, lambda i, h: h + flat[i], seed)
    return hash
```
but I'm getting the following error:
```
2025-01-02 19:08:31.000150: 15030 ERROR ||NEURON_CC_WRAPPER||: Failed compilation with ['neuronx-cc', 'compile', '--target=trn1', '--framework=XLA', '/tmp/argocd/neuroncc_compile_workdir/a54c3791-722b-4e7e-a4ae-b362677db784/model.MODULE_5177726690754820286+2ba4dce1.hlo_module.pb', '--output', '/tmp/argocd/neuroncc_compile_workdir/a54c3791-722b-4e7e-a4ae-b362677db784/model.MODULE_5177726690754820286+2ba4dce1.neff', '-O1', '--internal-hlo2tensorizer-options= --modular-flow-mac-threshold-for-default=1000000 --modular-flow-mac-threshold=1000000 ', '--disable-internal-io-dge', '--model-type=transformer', '--tensorizer-options=--disable-dma-cast --skip-pass=PartialLoopFusion --skip-pass=SimplifyNeuronTensor --enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2', '--verbose=35', '--layer-unroll-factor=5', '--pipeline', 'compile', 'SaveTemps']:
2025-01-03T03:08:31Z [NLA001] Unhandled exception with message:
=== BIR verification failed ===
Reason: Output dtype must match input
Instruction: I-95
Opcode: CollectiveCompute
Instruction Source: ({'no_delinear': '0'}non_local uint32 [163840] %'bitcast_convert.1':95)0:
Output index: 0
Argument AP: Access Pattern: [[163840,1],[163840,1],[1,163840]] SymbolicAP
Memory Location: {bitcast_convert.1_set}@dram#Internal
DebugInfo: <bitcast_convert.1||NHC||[163840, 1, 1]>
Please open a support ticket at https://github.com/aws-neuron/aws-neuron-sdk/issues/new. You may also be able to obtain more information using the 'XLA_IR_DEBUG' and 'XLA_HLO_DEBUG' environment variables.
```
`self.payload` is a JAX array of dtype bfloat16 and shape `[1, 128, 2560]`, sharded across several devices.
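For context, the uint32 view halves the element count: 1 * 128 * 2560 = 327680 bfloat16 values become 163840 uint32 words, which matches the `[163840]` shape in the BIR error. A quick unsharded check on CPU confirms the shapes:

```python
import jax.numpy as jnp

# Stand-in for self.payload (unsharded, for shape-checking only)
payload = jnp.ones((1, 128, 2560), dtype=jnp.bfloat16)

flat = payload.ravel().view(jnp.uint32)  # two bfloat16 values per uint32 word
print(flat.shape)  # (163840,) -- matches the shape in the error message
```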
I was able to reproduce the problem with this script:
```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding, Mesh

@jax.jit
def func(payload) -> jax.Array:
    flat = payload.ravel().view(jnp.uint32)
    seed = flat[0]
    max_sz = min(10, len(flat))
    hash = jax.lax.fori_loop(1, max_sz, lambda i, h: h + flat[i], seed)
    return hash

def main():
    # Create sharded data
    devices = jax.devices()
    mesh = Mesh(devices, ('tp',))
    sharding = NamedSharding(mesh, P(None, 'tp', None))
    data = jnp.ones((1, 128, 2560), device=sharding, dtype=jnp.bfloat16)

    # Execute the function
    func(data)

if __name__ == '__main__':
    main()
```
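In case it helps anyone hitting this in the meantime: one possible workaround is to avoid the width-changing bfloat16-to-uint32 view that lowers to the failing `bitcast_convert`, and instead do a same-width bitcast to uint16 and reassemble the 32-bit words by hand. This is only a sketch I have verified on CPU, not on trn1, and it assumes little-endian byte layout; `compute_signature_workaround` is a hypothetical name:

```python
import jax
import jax.numpy as jnp

@jax.jit
def compute_signature_workaround(payload) -> jax.Array:
    # Same-width bitcast (bfloat16 -> uint16), then widen each half to uint32.
    half = payload.ravel().view(jnp.uint16).astype(jnp.uint32)
    # Reassemble 32-bit words: on little-endian layouts the even-indexed
    # halves are the low 16 bits of each word the uint32 view would produce.
    flat = half[0::2] | (half[1::2] << 16)
    seed = flat[0]
    max_sz = min(10, len(flat))
    return jax.lax.fori_loop(1, max_sz, lambda i, h: h + flat[i], seed)
```

On CPU this produces the same value as the original `compute_signature`; whether the same-width bitcast also avoids the BIR verification failure on trn1 would need to be confirmed.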
We will look into the problem and update here once we have a root cause or resolution.