Forked from pytorch/xla
Support PyTorch CUDACachingAllocator (#12)
* Set `export PJRT_USE_TORCH_ALLOCATOR=1` to use the PyTorch CUDACachingAllocator
Showing 6 changed files with 227 additions and 1 deletion. The two new allocator files are reproduced below, followed by a sketch of how the environment-variable gating could look.
torch_xla/csrc/runtime/torch_allocator.cc (new file, 48 lines):
#include "torch_xla/csrc/runtime/torch_allocator.h" | ||
|
||
#include <c10/cuda/CUDACachingAllocator.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
|
||
#include "torch_xla/csrc/runtime/tf_logging.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
TorchCUDACachingAllocator::TorchCUDACachingAllocator(int device_ordinal) { | ||
VLOG(3) << "Creating TorchCUDACachingAllocator for device " << device_ordinal; | ||
name_ = c10::cuda::CUDACachingAllocator::name(); | ||
cuda_stream_ = nullptr; | ||
device_index_ = static_cast<c10::DeviceIndex>(device_ordinal); | ||
} | ||
|
||
void* TorchCUDACachingAllocator::AllocateRaw(size_t alignment, | ||
size_t num_bytes) { | ||
CHECK(cuda_stream_ != nullptr) | ||
<< "A stream must be added to the TorchCUDACachingAllocator allocator"; | ||
if (num_bytes == 0) { | ||
return nullptr; | ||
} | ||
at::cuda::CUDAGuard device_guard{device_index_}; | ||
auto ptr = c10::cuda::CUDACachingAllocator::raw_alloc_with_stream( | ||
num_bytes, cuda_stream_); | ||
VLOG(3) << "Alloc num_bytes " << num_bytes << " with ptr " << ptr | ||
<< " for device " << static_cast<int>(device_index_); | ||
return ptr; | ||
} | ||
|
||
void TorchCUDACachingAllocator::DeallocateRaw(void* ptr) { | ||
VLOG(3) << "Dealloc ptr " << ptr << " for device " | ||
<< static_cast<int>(device_index_); | ||
c10::cuda::CUDACachingAllocator::raw_delete(ptr); | ||
} | ||
|
||
void TorchCUDACachingAllocator::SetStreamAndPreallocateMemory(void* stream) { | ||
auto new_cuda_stream = static_cast<cudaStream_t>(stream); | ||
VLOG(3) << "Setting cuda stream " << stream | ||
<< " for TorchCUDACachingAllocator on device " | ||
<< static_cast<int>(device_index_); | ||
cuda_stream_ = new_cuda_stream; | ||
} | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla |
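The remaining four changed files, which this page does not reproduce, wire the allocator into the PJRT GPU client behind the environment variable. A minimal sketch of what such gating could look like, under the assumption that selection happens in a per-device factory; `MakeGpuAllocator` is a hypothetical helper, not code from this commit:

// Hypothetical gating sketch, not part of this commit: choose the
// PyTorch-backed allocator when PJRT_USE_TORCH_ALLOCATOR=1 is set.
#include <cstdlib>
#include <cstring>
#include <memory>

#include "torch_xla/csrc/runtime/torch_allocator.h"

namespace torch_xla {
namespace runtime {

std::unique_ptr<tsl::Allocator> MakeGpuAllocator(int device_ordinal) {
  const char* flag = std::getenv("PJRT_USE_TORCH_ALLOCATOR");
  if (flag != nullptr && std::strcmp(flag, "1") == 0) {
    // Route PJRT device allocations through PyTorch's caching allocator,
    // so XLA and eager-mode PyTorch draw from a single memory pool.
    return std::make_unique<TorchCUDACachingAllocator>(device_ordinal);
  }
  // Otherwise signal the caller to keep the client's default allocator.
  return nullptr;
}

}  // namespace runtime
}  // namespace torch_xla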
torch_xla/csrc/runtime/torch_allocator.h (new file, 37 lines):
#ifndef XLA_CLIENT_TORCH_ALLOCATOR_H_
#define XLA_CLIENT_TORCH_ALLOCATOR_H_

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>

#include "tsl/framework/allocator.h"

namespace torch_xla {
namespace runtime {

// A tsl::Allocator that serves device memory from PyTorch's
// CUDACachingAllocator, so PJRT and PyTorch share one memory pool.
class TorchCUDACachingAllocator : public tsl::Allocator {
 public:
  explicit TorchCUDACachingAllocator(int device_ordinal);
  ~TorchCUDACachingAllocator() override = default;

  std::string Name() override { return name_; }

  void* AllocateRaw(size_t alignment, size_t num_bytes) override;
  void DeallocateRaw(void* ptr) override;

  // Records the CUDA stream that subsequent allocations are made on.
  void SetStreamAndPreallocateMemory(void* stream) override;

  tsl::AllocatorMemoryType GetMemoryType() const override {
    return tsl::AllocatorMemoryType::kDevice;
  }

 private:
  std::string name_;
  cudaStream_t cuda_stream_;
  c10::DeviceIndex device_index_;
};

}  // namespace runtime
}  // namespace torch_xla

#endif  // XLA_CLIENT_TORCH_ALLOCATOR_H_
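For reference, a usage sketch (hypothetical driver code, not part of the commit) showing the call order the implementation requires: a stream must be installed via SetStreamAndPreallocateMemory before the first AllocateRaw, otherwise the CHECK in AllocateRaw fires.

// Hypothetical usage sketch, not part of this commit.
#include <cuda_runtime_api.h>

#include "torch_xla/csrc/runtime/torch_allocator.h"

int main() {
  torch_xla::runtime::TorchCUDACachingAllocator allocator(
      /*device_ordinal=*/0);

  // Install a stream first; AllocateRaw CHECK-fails without one.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  allocator.SetStreamAndPreallocateMemory(stream);

  // The allocation is served by c10's caching allocator on that stream,
  // i.e. from the same pool PyTorch eager mode uses.
  void* buf = allocator.AllocateRaw(/*alignment=*/256, /*num_bytes=*/1 << 20);
  allocator.DeallocateRaw(buf);

  cudaStreamDestroy(stream);
  return 0;
}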