Commit b6e6b14

add initial tests

shubhambhokare1 committed Feb 1, 2025
1 parent 8db97a6 commit b6e6b14
Showing 6 changed files with 317 additions and 3 deletions.
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -138,6 +138,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization);
@@ -351,6 +354,9 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization)>,
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cpu/layer_norm.cc
@@ -21,6 +21,11 @@ namespace contrib {
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("U", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("V", DataTypeImpl::GetTensorType<T>()), \
LayerNorm<true>); \
ONNX_OPERATOR_TYPED_KERNEL_EX(RMSNormalization, kMSDomain, 1, T, kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("V", DataTypeImpl::GetTensorType<T>()), \
LayerNorm<true>);

REGISTER_CONTRIB_KERNELS(float)
152 changes: 152 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1340,6 +1340,158 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1,
GreedySearchShapeInference(ctx);
}));

static const char* RMSNormalization_ver1_doc = R"DOC(
This is RMS normalization, defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467.
The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions,
where D is the dimension of the normalized shape. For example, if the normalized shape is (3, 5) (a 2-dimensional shape),
the RMS norm is computed over the last 2 dimensions of the input. The computation in the first stage can be
described by the following equations.
```
XSquared = Mul(X, X)
XSquaredMean = ReduceMean<axes=normalized_axes>(XSquared)
RMSEps = Add(XSquaredMean, epsilon)
RMS = Sqrt(RMSEps)
Normalized = Div(X, RMS)
```
where `normalized_axes` is `[axis, ..., rank of X - 1]`. The variable `RMS` stands for the root mean square.
Depending on the `stash_type` attribute, the actual computation
must happen in a different floating-point precision.
For example, if `stash_type` is 1, this operator casts
all input variables to 32-bit float, performs the computation, and
finally casts `Normalized` back to the original type of `X`.
The second stage then scales the outcome of the first stage using:
```
Y = Mul(Normalized, Scale)
```
Let `d[i]` indicate the i-th dimension of `X`.
If `X`'s shape is `[d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]`,
the shape of `RMS` is `[d[0], ..., d[axis-1], 1, ..., 1]`.
`Y` and `X` have the same shape. This operator supports unidirectional broadcasting
(tensor `Scale` should be unidirectionally broadcastable to tensor `X`);
for more details please check [the doc](Broadcasting.md).
)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(
RMSNormalization,
1,
OpSchema()
.SetDoc(RMSNormalization_ver1_doc)
.Attr(
"axis",
"The first normalization dimension: normalization will be performed along dimensions axis : rank(inputs).",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, 1e-5f)
.Attr(
"stash_type",
"The floating-point precision used in stage one of the computation.",
AttributeProto::INT,
static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT))
.AllowUncheckedAttributes()
.Input(
0,
"X",
"The output of the layer for which the skip connection is being created. "
"In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where "
"D1 to Dn are the spatial dimension sizes and N is the batch size, C is the number of channels. "
"The root mean squared norm is taken over the last D dimensions, D is determined by the axis attribute.",
"T")
.Input(
1,
"scale",
"Scale tensor. Scale tensor shape should be broadcastable to the normalized shape ([axis, .., Dn]).",
"V")
.Output(0, "Y", "Output data tensor. Same shape as X", "V")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input X type to float tensors.")
.TypeConstraint(
"V",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain output Y and scale type to float tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateShapeAndTypeFromFirstInput(ctx);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
int64_t input_ndim = input_shape.dim_size();
int64_t axis = -1;
auto axis_proto = ctx.getAttribute("axis");
if (axis_proto) {
axis = axis_proto->i();
}
if (axis < 0) {
// Convert negative axis value to equivalent
// positive value.
axis += input_ndim;
}
if (axis < 0) {
fail_shape_inference(
"Unexpected axis value (",
axis,
") rank of first input is ",
input_ndim,
" in ",
ctx.getDisplayName(),
".");
}
})
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx,
const OpSchema& schema,
FunctionProto& functionProto) {
// RMSNormalization <axis, epsilon, stash_type> (X, scale) => (Y)
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
int64_t T = tp->tensor_type().elem_type();

auto type_attr = ctx.getAttribute("stash_type");
int64_t U = (type_attr != nullptr) ? type_attr->i()
: static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
if ((U != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) &&
(U != ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) &&
(U != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) && (U != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))
return false; // Error

auto* axis_attr = ctx.getAttribute("axis");
int64_t axis = (axis_attr != nullptr) ? axis_attr->i() : -1;
auto* epsilon_attr = ctx.getAttribute("epsilon");
float epsilon = (epsilon_attr != nullptr) ? epsilon_attr->f() : 1e-5f;

auto mktensor = [](int64_t val) -> ONNX_NAMESPACE::TensorProto {
auto tp = ONNX_NAMESPACE::ToTensor(std::vector<int64_t>{val});
tp.add_dims(1);
return tp;
};

FunctionBuilder builder(functionProto);
builder.Const("FloatEpsilon", ToTensor<float>(epsilon))
.Add("Epsilon = Cast (FloatEpsilon)", "to", U)
.Add("XShape = Shape (X)") // shape of input tensor: 1D tensor
.Add("Rank = Size (XShape)") // rank of input tensor: scalar
.Add("Axis1D = Constant()", "value", mktensor(axis)) // [axis] : 1D tensor
.Add(
axis >= 0 // number of axes that are reduced =
? "PosAxis1D = Identity (Axis1D)" // [axis]: 1D tensor
: "PosAxis1D = Add (Rank, Axis1D)") // [rank + axis] : 1D tensor
.Const1D("One1D", (int64_t)1)
.Add("ReduceAxes = Range(PosAxis1D, Rank, One1D)")
.Add("XU = Cast (X)", "to", U);
builder.Add("XSquared = Mul (XU, XU)")
.Add("XSquaredMean = ReduceMean (XSquared, ReduceAxes)")
.Add("RMSPlusEpsilon = Add (XSquaredMean, Epsilon)")
.Add("RMS = Sqrt (RMSPlusEpsilon)")
.Add("Normalized = Div (XU, RMS)")
.Add("NormalizedT = Cast (Normalized)", "to", T);
builder.Add("Y = Mul (NormalizedT, scale)");

schema.BuildFunction(functionProto);
return true;
}));
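
The function body above expresses the reduction-axis bookkeeping with graph ops (`Shape`, `Size`, `Constant`, `Range`). As an aid to reading it, the equivalent host-side arithmetic might look like the sketch below; `reduce_axes_for` is a hypothetical helper written for this example, not code from this commit.

```
// Sketch of the axis arithmetic the builder encodes as graph nodes:
// ReduceAxes = Range(PosAxis1D, Rank, One1D), where PosAxis1D = axis >= 0 ? axis : rank + axis.
#include <cstdint>
#include <vector>

std::vector<int64_t> reduce_axes_for(int64_t axis, int64_t rank) {
  const int64_t pos_axis = axis >= 0 ? axis : rank + axis;  // PosAxis1D
  std::vector<int64_t> axes;
  for (int64_t a = pos_axis; a < rank; ++a) axes.push_back(a);  // Range(pos_axis, rank, 1)
  return axes;
}

// Example: axis = -1, rank = 3 -> {2}; axis = -2, rank = 4 -> {2, 3}.
```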

ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
OpSchema()
.SetDoc("Greedy Sampling for text generation.")
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -104,6 +104,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RMSNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer);
@@ -216,6 +217,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RMSNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer)>());
30 changes: 27 additions & 3 deletions onnxruntime/test/contrib_ops/layer_norm_test.cc
@@ -10,6 +10,7 @@ namespace test {
constexpr auto k_epsilon_default = 1e-5f;
constexpr auto k_random_data_min = -10.0f;
constexpr auto k_random_data_max = 10.0f;
const std::string RMS_NORM_OP = "RMSNormalization";
const std::string SIMPLIFIED_LAYER_NORM_OP = "SimplifiedLayerNormalization";
const std::string LAYER_NORM_OP = "LayerNormalization";

@@ -58,7 +59,7 @@ static void TestLayerNorm(const std::vector<int64_t>& x_dims,

test.AddInput<float>("X", n_x_m_dims, X_data);
test.AddInput<float>("scale", m_dims, scale_data, true);
if (op.compare(SIMPLIFIED_LAYER_NORM_OP) != 0 && no_bias == false) {
if (op.compare(SIMPLIFIED_LAYER_NORM_OP) != 0 && op.compare(RMS_NORM_OP) != 0 && no_bias == false) {
test.AddInput<float>("B", m_dims, B_data, true);
}

@@ -72,10 +73,12 @@ std::vector<float> var_data = FillZeros<float>(stats_dims);
std::vector<float> var_data = FillZeros<float>(stats_dims);

// the Main and InvStdDev outputs are training specific
if (op.compare(SIMPLIFIED_LAYER_NORM_OP) != 0) {
if (op.compare(SIMPLIFIED_LAYER_NORM_OP) != 0 && op.compare(RMS_NORM_OP) != 0) {
test.AddOutput<float>("mean", stats_dims, mean_data);
}
test.AddOutput<float>("var", stats_dims, var_data);
if (op.compare(RMS_NORM_OP) != 0) {
test.AddOutput<float>("var", stats_dims, var_data);
}
#endif

#ifdef USE_CUDA
@@ -154,6 +157,27 @@ TEST(CudaKernelTest, SimplifiedLayerNorm_LargeSizeTensor) {
TestLayerNorm(X_dims, SIMPLIFIED_LAYER_NORM_OP, k_epsilon_default);
}

TEST(CudaKernelTest, RMSNorm_SmallSizeTensor) {
const std::vector<int64_t> X_dims{4, 20, 128};
TestLayerNorm(X_dims, RMS_NORM_OP, k_epsilon_default);
}

TEST(CudaKernelTest, RMSNorm_SmallSizeTensor_IntermediateAxis) {
const std::vector<int64_t> X_dims{4, 20, 8, 16};
constexpr int64_t axis = -2;
TestLayerNorm(X_dims, RMS_NORM_OP, k_epsilon_default, axis);
}

TEST(CudaKernelTest, RMSNorm_MidSizeTensor) {
std::vector<int64_t> X_dims{8, 80, 768};
TestLayerNorm(X_dims, RMS_NORM_OP, k_epsilon_default);
}

TEST(CudaKernelTest, RMSNorm_LargeSizeTensor) {
std::vector<int64_t> X_dims{16, 512, 1024};
TestLayerNorm(X_dims, RMS_NORM_OP, k_epsilon_default);
}

// LayerNormalization is an ONNX operator in opset 17. It uses the same implementation so this is just a sanity check.
TEST(CudaKernelTest, LayerNorm_SmallSizeTensor_Opset17) {
const std::vector<int64_t> X_dims{4, 20, 128};