Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize unary tanh on cpu, generalize ADD to allow more shapes #580

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 107 additions & 48 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -5558,8 +5558,8 @@ static struct ggml_tensor * ggml_add_impl(
struct ggml_tensor * b,
bool inplace) {
// TODO: support less-strict constraint
// GGML_ASSERT(ggml_can_repeat(b, a));
GGML_ASSERT(ggml_can_repeat_rows(b, a));
GGML_ASSERT(ggml_can_repeat(b, a));
// GGML_ASSERT(ggml_can_repeat_rows(b, a));

bool is_node = false;

Expand Down Expand Up @@ -9288,7 +9288,7 @@ static void ggml_compute_forward_add_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
Expand All @@ -9297,65 +9297,114 @@ static void ggml_compute_forward_add_f32(
const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(src0);
const int nr = ggml_nrows(src0);
const int ne = ggml_nelements(src0);

GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
// GGML_ASSERT( nb0 == sizeof(float));
// GGML_ASSERT(nb00 == sizeof(float));

// rows per thread
// rows and elements per thread
const int dr = (nr + nth - 1)/nth;
const int de = (ne + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

if (nb10 == sizeof(float)) {
for (int ir = ir0; ir < ir1; ++ir) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;

float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

#ifdef GGML_USE_ACCELERATE
vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
#else
ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
#endif
}
} else {
// src1 is not contiguous
for (int ir = ir0; ir < ir1; ++ir) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
// element range for this thread
const int ie0 = de*ith;
const int ie1 = MIN(ie0 + de, ne);

if (nb00 == sizeof(float) && nb0 == sizeof(float) && ggml_can_repeat_rows(src1, src0)) {
if (nb10 == sizeof(float)) {
for (int ir = ir0; ir < ir1; ++ir) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;

float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

#ifdef GGML_USE_ACCELERATE
vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
#else
ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
#endif
}
} else {
// src1 is not contiguous
for (int ir = ir0; ir < ir1; ++ir) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;

float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

for (int i0 = 0; i0 < ne0; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
for (int i0 = 0; i0 < ne0; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);

dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
}
}
}
}
else {
if (nb00 == sizeof(float) && nb0 == sizeof(float) && nb10 == sizeof(float) && ne10 == 1) {
for (int ir = ir0; ir < ir1; ++ir) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;

float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

ggml_vec_add1_f32(ne00, dst_ptr, src0_ptr, *src1_ptr);
}
} else {
// general fallback: no contiguity is assumed for src0, src1 or dst
for (int ie = ie0; ie < ie1; ++ie) {
// src1 is broadcastable across src0 and dst in i0, i1, i2, i3
const int64_t i03 = ie/(ne02*ne01*ne00);
const int64_t i02 = (ie - i03*ne02*ne01*ne00)/(ne01*ne00);
const int64_t i01 = (ie - i03*ne02*ne01*ne00 - i02*ne01*ne00);
const int64_t i00 = (ie - i03*ne02*ne01*ne00 - i02*ne01*ne00 - i01*ne00);

const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
const int64_t i10 = i00 % ne10;

float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 + i00*nb0 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);

*dst_ptr = *src0_ptr + *src1_ptr;
}
Comment on lines +9383 to 9402
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch will probably be very slow — most of the time will go into computing the indices.

}
}
}


static void ggml_compute_forward_add_f16_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
Expand Down Expand Up @@ -11018,16 +11067,26 @@ static void ggml_compute_forward_tanh_f32(
return;
}

const int n = ggml_nrows(src0);
const int ith = params->ith;
const int nth = params->nth;

const int nr = ggml_nrows(src0);
const int nc = src0->ne[0];

assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));

for (int i = 0; i < n; i++) {
// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

for (int ir = ir0; ir < ir1; ir++) {
ggml_vec_tanh_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
(float *) ((char *) dst->data + ir*( dst->nb[1])),
(float *) ((char *) src0->data + ir*(src0->nb[1])));
}
}

Expand Down Expand Up @@ -18685,13 +18744,13 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
{
n_tasks = 1;
} break;

case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
Expand Down