-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Provide a raft::copy overload for mdspan-to-mdspan copies #1818
Merged
Merged
Changes from 91 commits
Commits
Show all changes
104 commits
Select commit
Hold shift + click to select a range
e24fd2e
Initial commit
tarang-jain b8cda77
Merge branch 'branch-23.04' of https://github.com/rapidsai/raft into …
tarang-jain 07dabfe
New commit
tarang-jain 64eb461
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain 21c2641
Update
tarang-jain c84daa6
Merge
tarang-jain 4ad421b
Merge
tarang-jain ea11b07
Merge
tarang-jain ab19410
build
tarang-jain 9870e9d
Test start
tarang-jain 51a2581
Test start
tarang-jain 552b21e
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain d0e7b2c
style changes
tarang-jain f72f7f8
merge
tarang-jain 05f9daa
merge dependencies.yaml
tarang-jain 0250931
Updates
tarang-jain 057743d
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain 20042b0
Debugging
tarang-jain 2d189c3
Update gtest
tarang-jain 53c4557
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain de753ae
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain 2f8b294
Some updates after reviews
tarang-jain 6539ef4
Use raft::resources
tarang-jain 1709521
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain 008bb5b
move exception
tarang-jain 5b97273
Updates after PR Reviews
tarang-jain 5be6ec2
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain 838bfef
Add container policy
tarang-jain e035e2e
further changes with container policy
tarang-jain cd91a88
Merge branch 'branch-23.06' of https://github.com/rapidsai/raft into …
tarang-jain 338c1a6
Some updates
tarang-jain 6468c24
update container_policy
tarang-jain 1bd5455
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain 81c6a81
Working build
tarang-jain 77ae593
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain 451815e
Update buffer accessor policy
tarang-jain b553369
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain b410f36
Style changes
tarang-jain 4731620
minor changes
tarang-jain 238d010
combine owning buffer cpu/gpu
tarang-jain 75cfcf1
update tests
tarang-jain 7b1909f
Updates
tarang-jain 5c041c4
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain 0bf6f87
Merge branch 'branch-23.08' into tarbuf
wphicks 1a1143f
Temporarily remove new files to bring back necessary ones
wphicks acceb61
Begin refactoring buffer container policies
wphicks fdefc34
Add placeholder resource for stream view in CUDA-free builds
wphicks 24223ed
Add infrastructure for CUDA-free build
wphicks c6f6354
Merge branch 'branch-23.08' into fea-mdbuffer
wphicks 4689052
Add initial set of CUDA-free tests
wphicks 1b7e1e5
Add variant types to mdbuffer
wphicks 5416ceb
Provide all mdarray/mdspan to mdbuffer conversions
wphicks 355b3d4
Begin creating buffer copy utilities
wphicks 601f65d
Merge branch 'branch-23.10' into fea-mdbuffer
wphicks 4770a83
Correct computation of dest indices
wphicks 28e8627
Merge branch 'branch-23.10' into fea-mdbuffer
wphicks 8237a74
Temporarily remove simd-accelerated copy
wphicks 022cf6e
Add initial mdspan copy utility implementation
wphicks a1776f4
Refactor copy properties detection
wphicks a970dad
Correct detection of mdspan copy paths
wphicks 9a2fa9e
Correct build errors
wphicks eac9de6
Provide passing 3D host transpose tests
wphicks 39cf094
Add working tests for cuBlas based transpose
wphicks 760b656
Add incomplete kernel tests
wphicks f8d435f
Remove old mdspan copy header
wphicks 4c4fbaf
Revert "Remove old mdspan copy header"
wphicks ad5c786
Remove correct mdspan copy header
wphicks 2e433ba
Correct std::apply workaround in CUDA
wphicks d669e42
Provide fully working copy kernel
wphicks ed663c8
Begin adding SIMD support
wphicks ab809e8
Revert "Begin adding SIMD support"
wphicks 49d871a
Disable initial SIMD implementation
wphicks cb24abc
Rename mdspan copy headers
wphicks 2a83c1b
Remove mdbuffer work and document mdspan copy
wphicks 4193b74
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks 624e4f3
Remove un-needed changes left over from mdbuffer
wphicks e9ef750
Add testing for CUDA-disabled builds
wphicks 06fe54d
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks 92046e0
Fix style and revert some unnecessary changes
wphicks a0a5b69
Remove changes related to mdbuffer
wphicks 58389ec
Remove change related to mdbuffer
wphicks 0a19ae5
Correctly handle proxy references in mdspan copy kernel
wphicks 0675207
Check for unique destination layout in any parallel copy
wphicks 8ad9434
Use perfect forwarding for copy wrappers
wphicks fdbc9ee
Correct comment for dimension iteration order
wphicks 21618ea
Add warning about copying to non-unique layouts
wphicks 18d462e
Add benchmarks for mdspan copy
wphicks 4700199
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks 2cad1ed
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks 6e91a1c
Correct check for assignability in mdspan copy
wphicks 55e06fe
Add comment explaining intermediate storage
wphicks faa402a
Correct dtype compatibility test
wphicks 2eba34d
Provide cleaner compile error for using copy with unsupported types
wphicks ca77cf0
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks 4389b64
Update stream_view docs
wphicks 7416b73
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks 7f407ed
Merge branch 'branch-23.10' into fea-mdspan_copy
wphicks 62ac60a
Update stream view docs
wphicks 5bddcc8
Merge remote-tracking branch 'origin/fea-mdspan_copy' into fea-mdspan…
wphicks bd5a8f8
Merge branch 'branch-23.12' into fea-mdspan_copy
wphicks a8b17a8
Add static asserts for mdspan_copyable
wphicks 722425c
Correct iteration in host-to-host copies
wphicks 0863db0
Fix double-defined target from branch merge
wphicks 5c4349e
Merge branch 'branch-23.12' into fea-mdspan_copy
cjnolet File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <raft/core/detail/copy.hpp> | ||
namespace raft { | ||
/** | ||
* @brief Copy data from one mdspan to another with the same extents | ||
* | ||
* This function copies data from one mdspan to another, regardless of whether | ||
* or not the mdspans have the same layout, memory type (host/device/managed) | ||
* or data type. So long as it is possible to convert the data type from source | ||
* to destination, and the extents are equal, this function should be able to | ||
* perform the copy. Any necessary device operations will be stream-ordered via the CUDA stream | ||
* provided by the `raft::resources` argument. | ||
* | ||
* This header includes a custom kernel used for copying data between | ||
* completely arbitrary mdspans on device. To compile this function in a | ||
* non-CUDA translation unit, `raft/core/copy.hpp` may be used instead. The | ||
* pure C++ header will correctly compile even without a CUDA compiler. | ||
* Depending on the specialization, this CUDA header may invoke the kernel and | ||
* therefore require a CUDA compiler. | ||
* | ||
* Limitations: Currently this function does not support copying directly | ||
* between two arbitrary mdspans on different CUDA devices. It is assumed that the caller sets the | ||
* correct CUDA device. Furthermore, host-to-host copies that require a transformation of the | ||
* underlying memory layout are currently not performant, although they are supported. | ||
* | ||
* Note that when copying to an mdspan with a non-unique layout (i.e. the same | ||
* underlying memory is addressed by different element indexes), the source | ||
* data must contain non-unique values for every non-unique destination | ||
* element. If this is not the case, the behavior is undefined. Some copies | ||
* to non-unique layouts which are well-defined will nevertheless fail with an | ||
* exception to avoid race conditions in the underlying copy. | ||
* | ||
* @tparam DstType An mdspan type for the destination container. | ||
* @tparam SrcType An mdspan type for the source container | ||
* @param res raft::resources used to provide a stream for copies involving the | ||
* device. | ||
* @param dst The destination mdspan. | ||
* @param src The source mdspan. | ||
*/ | ||
template <typename DstType, typename SrcType> | ||
detail::mdspan_copyable_with_kernel_t<DstType, SrcType> copy(resources const& res, | ||
DstType&& dst, | ||
SrcType&& src) | ||
{ | ||
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src)); | ||
} | ||
|
||
#ifndef RAFT_NON_CUDA_COPY_IMPLEMENTED | ||
#define RAFT_NON_CUDA_COPY_IMPLEMENTED | ||
template <typename DstType, typename SrcType> | ||
detail::mdspan_uncopyable_with_kernel_t<DstType, SrcType> copy(resources const& res, | ||
DstType&& dst, | ||
SrcType&& src) | ||
{ | ||
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src)); | ||
} | ||
#endif | ||
} // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <raft/core/detail/copy.hpp> | ||
namespace raft { | ||
|
||
#ifndef RAFT_NON_CUDA_COPY_IMPLEMENTED | ||
#define RAFT_NON_CUDA_COPY_IMPLEMENTED | ||
/** | ||
* @brief Copy data from one mdspan to another with the same extents | ||
* | ||
* This function copies data from one mdspan to another, regardless of whether | ||
* or not the mdspans have the same layout, memory type (host/device/managed) | ||
* or data type. So long as it is possible to convert the data type from source | ||
* to destination, and the extents are equal, this function should be able to | ||
* perform the copy. | ||
* | ||
* This header does _not_ include the custom kernel used for copying data | ||
* between completely arbitrary mdspans on device. For arbitrary copies of this | ||
* kind, `#include <raft/core/copy.cuh>` instead. Specializations of this | ||
* function that require the custom kernel will be SFINAE-omitted when this | ||
* header is used instead of `copy.cuh`. This header _does_ support | ||
* device-to-device copies that can be performed with cuBLAS or a | ||
* straightforward cudaMemcpy. Any necessary device operations will be stream-ordered via the CUDA | ||
* stream provided by the `raft::resources` argument. | ||
* | ||
* Limitations: Currently this function does not support copying directly | ||
* between two arbitrary mdspans on different CUDA devices. It is assumed that the caller sets the | ||
* correct CUDA device. Furthermore, host-to-host copies that require a transformation of the | ||
* underlying memory layout are currently not performant, although they are supported. | ||
* | ||
* Note that when copying to an mdspan with a non-unique layout (i.e. the same | ||
* underlying memory is addressed by different element indexes), the source | ||
* data must contain non-unique values for every non-unique destination | ||
* element. If this is not the case, the behavior is undefined. Some copies | ||
* to non-unique layouts which are well-defined will nevertheless fail with an | ||
* exception to avoid race conditions in the underlying copy. | ||
* | ||
* @tparam DstType An mdspan type for the destination container. | ||
* @tparam SrcType An mdspan type for the source container | ||
* @param res raft::resources used to provide a stream for copies involving the | ||
* device. | ||
* @param dst The destination mdspan. | ||
* @param src The source mdspan. | ||
*/ | ||
template <typename DstType, typename SrcType> | ||
detail::mdspan_uncopyable_with_kernel_t<DstType, SrcType> copy(resources const& res, | ||
DstType&& dst, | ||
SrcType&& src) | ||
{ | ||
detail::copy(res, std::forward<DstType>(dst), std::forward<SrcType>(src)); | ||
} | ||
#endif | ||
|
||
} // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
namespace raft { | ||
#ifndef RAFT_DISABLE_CUDA | ||
auto constexpr static const CUDA_ENABLED = true; | ||
#else | ||
auto constexpr static const CUDA_ENABLED = false; | ||
#endif | ||
} // namespace raft |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes seem to be enforced by cmake-format. Not sure why they were not present before.