forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDeviceAccelerator.cpp
37 lines (33 loc) · 1.25 KB
/
DeviceAccelerator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <ATen/Context.h>
#include <ATen/DeviceAccelerator.h>
namespace at {
C10_API std::optional<DeviceType> getAccelerator(bool checked) {
#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
if (at::has##device_name()) { \
device_type = k##device_name; \
TORCH_CHECK( \
!is_accelerator_detected, \
"Cannot have ", \
device_type.value(), \
" with other accelerators."); \
is_accelerator_detected = true; \
}
if (is_privateuse1_backend_registered()) {
// We explicitly allow PrivateUse1 and another device at the same time as we
// use this for testing. Whenever a PrivateUse1 device is registered, use it
// first.
return kPrivateUse1;
}
std::optional<DeviceType> device_type = std::nullopt;
bool is_accelerator_detected = false;
DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
DETECT_AND_ASSIGN_ACCELERATOR(XPU)
if (checked) {
TORCH_CHECK(
device_type, "Cannot access accelerator device when none is available.")
}
return device_type;
#undef DETECT_AND_ASSIGN_ACCELERATOR
}
} // namespace at