// LerpKernel.cpp (forked from pytorch/pytorch)

#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/Lerp.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>
namespace at {
namespace native {
namespace {

// A lerp weight is "small" when |weight| < 0.5. lerp_vec below uses this
// predicate to decide which endpoint the interpolation is anchored at.
template <typename scalar_t>
Vectorized<scalar_t> is_lerp_weight_small(Vectorized<scalar_t> weight) {
  static_assert(!c10::is_complex<scalar_t>::value, "");
  return weight.abs() < Vectorized<scalar_t>(0.5);
}

// is_lerp_weight_small doesn't work for complex because z.abs() returns a
// complex vector, which can't be compared. Either implement it with z.abs_2_()
// or fall back to the scalar function.
#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER))
template <typename value_t>
Vectorized<c10::complex<value_t>> is_lerp_weight_small(Vectorized<c10::complex<value_t>> weight) {
  using vec_reg_t = decltype(weight.abs_2_());
  // abs_2_() yields |z|^2 in the underlying real register, so compare against
  // 0.5^2 = 0.25 to test |weight| < 0.5 without taking a square root.
  vec_reg_t mask = Vectorized<value_t>(weight.abs_2_()) < Vectorized<value_t>(0.25);
  return Vectorized<c10::complex<value_t>>(mask);
}
#else
template <typename scalar_t>
Vectorized<scalar_t> lerp_vec_map(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
  using vec_t = Vectorized<scalar_t>;
  __at_align__ scalar_t start_arr[vec_t::size()];
  __at_align__ scalar_t end_arr[vec_t::size()];
  __at_align__ scalar_t weight_arr[vec_t::size()];
  __at_align__ scalar_t result_arr[vec_t::size()];
  start.store(start_arr);
  end.store(end_arr);
  weight.store(weight_arr);
  for (auto i : c10::irange(vec_t::size())) {
    result_arr[i] = lerp(start_arr[i], end_arr[i], weight_arr[i]);
  }
  return vec_t::loadu(result_arr);
}

template <typename value_t>
Vectorized<c10::complex<value_t>> lerp_vec(Vectorized<c10::complex<value_t>> start, Vectorized<c10::complex<value_t>> end, Vectorized<c10::complex<value_t>> weight) {
  return lerp_vec_map(start, end, weight);
}
#endif
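
// lerp_vec below computes lerp(start, end, weight) with a weight-dependent
// formulation (intended as the vectorized analogue of the scalar lerp()
// called in the kernel lambdas):
//   |weight| < 0.5 : start + weight * (end - start)
//   otherwise      : end + (weight - 1) * (end - start)
// Anchoring at the nearer endpoint keeps the result exact at weight == 0 and
// weight == 1 and avoids the cancellation a single-formula version suffers
// when weight is close to 1.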
template <typename scalar_t>
Vectorized<scalar_t> lerp_vec(Vectorized<scalar_t> start, Vectorized<scalar_t> end, Vectorized<scalar_t> weight) {
  using vec_t = Vectorized<scalar_t>;
  auto mask = is_lerp_weight_small(weight);
  auto coeff = vec_t::blendv(weight - vec_t(1), weight, mask);
  auto base = vec_t::blendv(end, start, mask);
  return vec::fmadd(coeff, end - start, base);
}
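
// In the two kernels below, BFloat16 inputs are widened to float: each
// bfloat16 vector is converted into a pair of float vectors, the lerp is done
// in float, and the two halves are packed back to bfloat16, keeping the
// intermediate arithmetic at float precision.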
void lerp_scalar_kernel(at::TensorIteratorBase& iter, const Scalar& weight) {
  if (iter.common_dtype() == kBFloat16) {
    using bVec = Vectorized<BFloat16>;
    using fVec = Vectorized<float>;
    float weight_val = weight.to<float>();
    auto weight_vec = fVec(weight_val);
    at::native::cpu_kernel_vec(
        iter,
        [weight_val](BFloat16 self_val, BFloat16 end_val) -> BFloat16 {
          return lerp(self_val, end_val, weight_val);
        },
        [=](bVec self_vec, bVec end_vec) -> bVec {
          fVec self_vec0, self_vec1, end_vec0, end_vec1;
          std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
          std::tie(end_vec0, end_vec1) = convert_bfloat16_float(end_vec);
          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
          return convert_float_bfloat16(result0, result1);
        });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
      auto weight_val = weight.to<scalar_t>();
      at::native::cpu_kernel_vec(
          iter,
          [weight_val](scalar_t self_val, scalar_t end_val) {
            return lerp(self_val, end_val, weight_val);
          },
          [weight_val](Vectorized<scalar_t> self, Vectorized<scalar_t> end) {
            const Vectorized<scalar_t> weight(weight_val);
            return lerp_vec(self, end, weight);
          });
    });
  }
}

void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
  if (iter.common_dtype() == kBFloat16) {
    using bVec = Vectorized<BFloat16>;
    using fVec = Vectorized<float>;
    at::native::cpu_kernel_vec(
        iter,
        [=](BFloat16 self_val, BFloat16 end_val, BFloat16 weight_val) -> BFloat16 {
          return lerp(self_val, end_val, weight_val);
        },
        [=](bVec self_vec, bVec end_vec, bVec weight_vec) -> bVec {
          fVec self_vec0, self_vec1, end_vec0, end_vec1, weight_vec0, weight_vec1;
          std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
          std::tie(end_vec0, end_vec1) = convert_bfloat16_float(end_vec);
          std::tie(weight_vec0, weight_vec1) = convert_bfloat16_float(weight_vec);
          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
          return convert_float_bfloat16(result0, result1);
        });
  } else {
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
      at::native::cpu_kernel_vec(
          iter,
          [](scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
            return lerp(self_val, end_val, weight_val);
          },
          [](Vectorized<scalar_t> self_val, Vectorized<scalar_t> end_val, Vectorized<scalar_t> weight_val) {
            return lerp_vec(self_val, end_val, weight_val);
          });
    });
  }
}

} // anonymous namespace

REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_kernel);
REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_kernel);

} // namespace native
} // namespace at
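
// Usage sketch (not part of this file): the kernels registered above are
// reached through the regular ATen frontend; calls like the following are
// assumed to dispatch to them on CPU.
//   at::Tensor start = at::rand({8});
//   at::Tensor end = at::rand({8});
//   at::Tensor a = at::lerp(start, end, 0.25);           // scalar weight -> lerp_kernel_scalar_weight
//   at::Tensor b = at::lerp(start, end, at::rand({8}));  // tensor weight -> lerp_kernel_tensor_weight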