forked from cad-audio/executorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
custom_ops_1_out.cpp
55 lines (49 loc) · 1.67 KB
/
custom_ops_1_out.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/runtime/kernel/kernel_includes.h>
namespace custom {
namespace native {
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
namespace {
void check_preconditions(const Tensor& in, Tensor& out) {
ET_CHECK_MSG(
out.scalar_type() == ScalarType::Float,
"Expected out tensor to have dtype Float, but got %hhd instead",
static_cast<int8_t>(out.scalar_type()));
ET_CHECK_MSG(
in.scalar_type() == ScalarType::Float,
"Expected in tensor to have dtype Float, but got %hhd instead",
static_cast<int8_t>(in.scalar_type()));
ET_CHECK_MSG(
out.dim() == in.dim(),
"Number of dims of out tensor is not compatible with inputs");
ET_CHECK_MSG(
out.numel() == in.numel(),
"Number of elements of out tensor %zd is not compatible with inputs %zd",
ssize_t(out.numel()),
ssize_t(in.numel()));
}
} // namespace
// mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
// ExecuTorch-compatible function signature, with a KernelRuntimeContext.
//
// Multiplies every element of `in` by 3 and stores the result at the same
// index of `out`.
//
// ctx: kernel runtime context; marked ET_UNUSED and not used by the body.
// in:  input tensor, read as float data.
// out: output tensor, written as float data.
//
// Returns a reference to `out`.
//
// check_preconditions(in, out) runs first, so the element loop is only
// reached once the dtype, dim-count and element-count checks have been made.
Tensor& mul3_out_impl(
    ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& in,
    Tensor& out) {
  check_preconditions(in, out);

  float* out_data = out.mutable_data_ptr<float>();
  const float* in_data = in.const_data_ptr<float>();

  // Read the element count once and convert it explicitly: this keeps the
  // loop bound from being re-evaluated on every iteration and avoids an
  // implicit signed/unsigned comparison against the size_t index.
  // NOTE(review): assumes numel() is a non-negative signed count, as the
  // ssize_t/%zd formatting in check_preconditions suggests - confirm.
  const size_t element_count = static_cast<size_t>(out.numel());
  for (size_t idx = 0; idx < element_count; ++idx) {
    out_data[idx] = in_data[idx] * 3;
  }
  return out;
}
} // namespace native
} // namespace custom