Skip to content

Commit

Permalink
ggml : add GGML_PAD_REFLECT_1D operation (ggml/1034)
Browse files Browse the repository at this point in the history
* ggml_pad_reflect_1d defined in header

* implemented on CPU

* called the forward pass

* impl Metal kernel

* added Metal kernel

* added OP_PAD_REFLECT_1D in test-backend-ops.cpp

* add test-pad-reflect-1d test case

* test case support multiple backend
  • Loading branch information
PABannier authored and ggerganov committed Dec 5, 2024
1 parent d405804 commit c2082d9
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 2 deletions.
8 changes: 8 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ extern "C" {
GGML_OP_POOL_2D_BACK,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_PAD_REFLECT_1D,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
Expand Down Expand Up @@ -1695,6 +1696,13 @@ extern "C" {
int p2,
int p3);

// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1);

// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
// return: [N, dim]
Expand Down
39 changes: 39 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -10439,6 +10439,40 @@ static void ggml_compute_forward_pad(
}
}

// ggml_compute_forward_pad_reflect_1d

static void ggml_compute_forward_pad_reflect_1d(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const int ith = params->ith;
const int nth = params->nth;

const int32_t * opts = (const int32_t *) dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];

GGML_TENSOR_UNARY_OP_LOCALS

for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);

ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));

for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
}
}
}
}

// ggml_compute_forward_arange

Expand Down Expand Up @@ -12535,6 +12569,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_pad(params, tensor);
} break;
case GGML_OP_PAD_REFLECT_1D:
{
ggml_compute_forward_pad_reflect_1d(params, tensor);
} break;
case GGML_OP_ARANGE:
{
ggml_compute_forward_arange(params, tensor);
Expand Down Expand Up @@ -12877,6 +12915,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
Expand Down
35 changes: 35 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
Expand Down Expand Up @@ -877,6 +878,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
Expand Down Expand Up @@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_POOL_2D:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
Expand Down Expand Up @@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(

const int nth = MIN(1024, ne0);

[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_PAD_REFLECT_1D:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);

const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
const int32_t p1 = ((const int32_t *)(dst->op_params))[1];

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];

const int nth = MIN(1024, ne0);

[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ARANGE:
Expand Down
47 changes: 47 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2897,6 +2897,53 @@ kernel void kernel_pad_f32(
}
}

kernel void kernel_pad_reflect_1d_f32(
device const char * src0,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant int64_t & ne0,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int32_t & p0,
constant int32_t & p1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {

const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;

const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;

device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);

if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (i0 < p0) {
dst_ptr[i0] = src0_ptr[p0 - i0];
} else if (i0 < ne0 - p1) {
dst_ptr[i0] = src0_ptr[i0 - p0];
} else {
dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
}
}
}
}

kernel void kernel_arange_f32(
device char * dst,
constant int64_t & ne0,
Expand Down
37 changes: 35 additions & 2 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"POOL_2D_BACK",
"UPSCALE",
"PAD",
"PAD_REFLECT_1D",
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
Expand Down Expand Up @@ -983,7 +984,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};

static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand Down Expand Up @@ -1045,6 +1046,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"pool_2d_back(x)",
"upscale(x)",
"pad(x)",
"pad_reflect_1d(x)",
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
Expand Down Expand Up @@ -1078,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};

static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -4097,6 +4099,37 @@ struct ggml_tensor * ggml_pad(
return result;
}

// ggml_pad_reflect_1d

struct ggml_tensor * ggml_pad_reflect_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1) {
GGML_ASSERT(p0 >= 0);
GGML_ASSERT(p1 >= 0);

GGML_ASSERT(p0 < a->ne[0]); // padding length on each size must be less than the
GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded

GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(a->type == GGML_TYPE_F32);

struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
a->ne[0] + p0 + p1,
a->ne[1],
a->ne[2],
a->ne[3]);

int32_t params[] = { p0, p1 };
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_PAD_REFLECT_1D;
result->src[0] = a;

return result;
}

// ggml_arange

struct ggml_tensor * ggml_arange(
Expand Down
28 changes: 28 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2697,6 +2697,33 @@ struct test_pad : public test_case {
}
};

// GGML_OP_PAD_REFLECT_1D
struct test_pad_reflect_1d : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const int pad_0;
const int pad_1;

std::string vars() override {
return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
}

test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {512, 34, 2, 1},
int pad_0 = 10, int pad_1 = 9)
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
ggml_set_name(a, "a");

ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
ggml_set_name(out, "out");

return out;
}
};

// GGML_OP_ARANGE
struct test_arange : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -3816,6 +3843,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_pad_reflect_1d());
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
Expand Down

0 comments on commit c2082d9

Please sign in to comment.