Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCM] fix 'Invalid plugin kind specified: DNN' error #5

Closed
wants to merge 1 commit into from

Conversation

Ruturaj4
Copy link

No description provided.

case PluginKind::kFft:
return iter->second.fft.has_value();
default:
break;
}

LOG(ERROR) << "Invalid plugin kind specified: "
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this was being hit when the iterator was returning as ended/stopped, and now we are handling that case specifically at line 78, therefore stopping the erroneous error message.

Is that correct?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, the current code doesn't check whether factories_ is empty and falls back to "invalid plugin error" anyways.

@Ruturaj4 Ruturaj4 marked this pull request as draft March 26, 2024 14:18
@Ruturaj4 Ruturaj4 closed this Apr 4, 2024
pemeliya pushed a commit that referenced this pull request May 6, 2024
ekuznetsov139 pushed a commit that referenced this pull request May 13, 2024
…d phase to Initialize()

Imported from GitHub PR openxla#12228

The first time that a NormThunk is executed, it will build a cudnn execution plan. This build step can hang if a NCCL collective is running at the same time. To fix this, I've moved the build step to take place during thunk initialization. We only observe this hang when using cudnn 9.

Here's a backtrace from the hang that will be fixed:
```
Thread 585 (Thread 0x7fb9391ff640 (LWP 41364) "main.py"):
#0  0x00007fd3d17cffd9 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
#1  0x00007fd3d17da24f in pthread_rwlock_wrlock () from /lib/x86_64-linux-gnu/libc.so.6
#2  0x00007fd070967dfe in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007fd0709c928a in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f1970d76102 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#5  0x00007f1970f2c999 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#6  0x00007f1970a7d4ab in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#7  0x00007f1970d0a9cb in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#8  0x00007fce60b2a98c in cudnn::backend::ExecutionPlan::finalize_internal() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#9  0x00007fce60aefbb1 in cudnn::backend::Descriptor::finalize() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#10 0x00007fce60b15bec in cudnnBackendFinalize () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#11 0x00007fd2521b8f39 in cudnn_frontend::ExecutionPlanBuilder_v8::build() () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#12 0x00007fd2521734ba in stream_executor::gpu::(anonymous namespace)::GetExecPlanFromHeuristics(cudnn_frontend::OperationGraph_v8&&, stream_executor::gpu::(anonymous namespace)::CudnnHandle const&, bool) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#13 0x00007fd25216ff9b in stream_executor::gpu::CudnnSupport::NormRunnerFromDesc(stream_executor::Stream*, stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormKind, double, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#14 0x00007fd24e36b88b in stream_executor::dnn::NormOp::RunnerFromAlgorithmDesc(stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#15 0x00007fd24e36ae37 in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}::operator()() const () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#16 0x00007fd24e36adbc in void absl::lts_20230802::base_internal::CallOnceImpl<stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}>(std::atomic<unsigned int>*, absl::lts_20230802::base_internal::SchedulingMode, stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}&&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#17 0x00007fd24e36a9bd in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#18 0x00007fd24e369d29 in xla::gpu::RunGpuNorm(xla::gpu::GpuNormConfig const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, stream_executor::DeviceMemoryBase const&, stream_executor::Stream*, xla::gpu::RunNormOptions) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#19 0x00007fd24e368be6 in xla::gpu::NormThunk::ExecuteOnStream(xla::gpu::Thunk::ExecuteParams const&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
```
Copybara import of the project:

--
f535330 by Trevor Morris <[email protected]>:

Fix hang with cudnn layer norm by moving cudnn init to Initialize()

Merging this change closes openxla#12228

COPYBARA_INTEGRATE_REVIEW=openxla#12228 from trevor-m:tmorris-norm-init f535330
PiperOrigin-RevId: 633220207
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants