I'm trying to run GraphCast on an H100 GPU. I've run it with -v and --debug, and all I get is a segmentation fault. I've seen segfaults before when trying to run on A40s and on my RTX 4060 Ti at home, due to lack of memory, but that should not be an issue with the H100. There are no errors in the 9000-line debug file, so I'm at a loss. Any help or tips would be most welcome. The machine has 128 GB of RAM and never gets above 50% utilization.
The last lines of the debug runs are:
2025-01-21 19:08:16,999 INFO Converting GRIB to xarray: 3 seconds.
2025-01-21 19:08:22,475 INFO Reindexing: 5 seconds.
2025-01-21 19:08:22,501 INFO Creating training data: 18 seconds.
2025-01-21 19:09:00,855 INFO Extracting input targets: 25 seconds.
2025-01-21 19:09:03,232 INFO Creating input data (total): 59 seconds.
2025-01-21 19:09:03,233 DEBUG Finished tracing + transforming convert_element_type for pjit in 0.0003840923309326172 sec
2025-01-21 19:09:03,234 DEBUG Discovered path based JAX plugin: jax_plugins.xla_cuda12
2025-01-21 19:09:03,238 DEBUG Discovered entry-point based JAX plugin: jax_plugins.xla_cuda12
2025-01-21 19:09:03,238 DEBUG Loading plugin module jax_plugins.xla_cuda12
2025-01-21 19:09:03,240 DEBUG registering PJRT plugin cuda from /home/rm11714/anaconda3/envs/GCast/lib/python3.10/site-packages/jax_plugins/xla_cuda12/xla_cuda_plugin.so
Segmentation fault (core dumped)
Followed by :( from me.
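One way to narrow this down, offered as a sketch rather than a known fix: the log ends right after JAX registers the `xla_cuda12` PJRT plugin, so it may be worth checking whether plain JAX crashes on its own, outside GraphCast. The script below runs a one-line JAX probe (`jax.devices()`, which initialises the backend and therefore loads the plugin) in a fresh interpreter started with `-X faulthandler`, so a segfault at least produces a Python traceback on stderr. The helper `run_isolated` is a hypothetical name, not part of GraphCast or JAX.

```python
import signal
import subprocess
import sys

# Minimal JAX probe: listing devices forces backend initialisation, which is
# where JAX discovers and registers the CUDA PJRT plugin.
JAX_PROBE = "import jax; print(jax.devices())"


def run_isolated(snippet: str, timeout: float = 300.0) -> dict:
    """Run `snippet` in a fresh interpreter with faulthandler enabled.

    `-X faulthandler` makes CPython dump the Python traceback of every thread
    to stderr on SIGSEGV, so a crash inside a native extension still shows
    which Python line triggered it.
    """
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    # On POSIX, a child killed by a signal reports a negative return code.
    return {
        "returncode": proc.returncode,
        "segfault": proc.returncode == -signal.SIGSEGV,
        "stdout": proc.stdout,
        "stderr": proc.stderr,
    }


if __name__ == "__main__":
    result = run_isolated(JAX_PROBE)
    if result["segfault"]:
        print("Plain JAX segfaults while initialising the backend:")
        print(result["stderr"])
    elif result["returncode"] != 0:
        print("JAX failed without a segfault:")
        print(result["stderr"])
    else:
        print("JAX initialised fine; devices:", result["stdout"].strip())
```

If the bare probe also segfaults, that points towards the JAX/CUDA setup in the environment rather than GraphCast itself; if it succeeds, the crash is more likely triggered later in the GraphCast run.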
I suspect H100s are not sufficient. The main GraphCast/GenCast repo obviously focuses its testing on TPUs and Google Cloud. You're probably better off writing your own input layer that doesn't depend on ECMWF/ONNX and fetching the initial conditions ("inits") manually with your ECMWF API key.
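To test the memory theory rather than guess, one could ask JAX how much memory it sees on each device. This is a sketch assuming a recent JAX release in which `Device.memory_stats()` exists on GPU devices (it may return `None` on other backends); `probe_env` and `gpu_memory_report` are hypothetical helpers. It also turns off JAX's default preallocation of 75% of GPU memory via `XLA_PYTHON_CLIENT_PREALLOCATE=false`, with `XLA_PYTHON_CLIENT_MEM_FRACTION` as an optional override.

```python
import json
import os
import subprocess
import sys
from typing import Optional

# Runs in a child process so a crash cannot take down the caller.
MEMORY_PROBE = """
import json, jax
out = []
for d in jax.devices():
    stats = d.memory_stats() or {}  # may be None on non-GPU backends
    out.append({"device": str(d), "bytes_limit": stats.get("bytes_limit")})
print(json.dumps(out))
"""


def probe_env(preallocate: bool = False, mem_fraction: Optional[float] = None) -> dict:
    """Copy of os.environ with JAX's GPU allocator settings applied."""
    env = dict(os.environ)
    # "false" stops JAX from reserving a large block of GPU memory at startup.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    return env


def gpu_memory_report(**env_kwargs) -> Optional[list]:
    """Return per-device memory info, or None if the probe fails."""
    proc = subprocess.run(
        [sys.executable, "-c", MEMORY_PROBE],
        env=probe_env(**env_kwargs),
        capture_output=True,
        text=True,
        timeout=300,
    )
    if proc.returncode != 0:
        sys.stderr.write(proc.stderr)  # surface the child's error output
        return None
    # The JSON line is the last thing the probe prints.
    return json.loads(proc.stdout.strip().splitlines()[-1])


if __name__ == "__main__":
    print(gpu_memory_report())
```

Running the probe in a subprocess keeps the check usable even if JAX initialisation is what segfaults: the caller just gets `None` plus the child's stderr.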