Integrate faster kernels for delay calculation #58

Open
maedoc opened this issue Feb 16, 2024 · 1 comment
maedoc commented Feb 16, 2024

In another repo we prototyped (much) faster delay calculation. The kernels are not particularly complicated, and everything else can be written in Jax without problems. However, Jax relies on pybind11 and a fair amount of machinery to register new primitives, so the binding code would probably end up about as long as the kernels themselves:

// delayed sparse coupling: for each of nv nodes, accumulate weighted, delayed
// neighbour states from a per-node ring buffer of length nh (CSR connectivity)
void delays2(int nv, int nh, int t,
             float *out1, float *out2,
             float *buf, float *weights, int *idelays, int *indices, int *indptr)
{
    // nh is power of two, so x&(nh-1) is faster way to compute x%nh
    int nhm = nh - 1;
    #pragma omp parallel for
    for (int i=0; i<nv; i++)
    {
        // compute coupling terms for both Heun stages
        float acc1 = 0.0f, acc2 = 0.0f;
        #pragma omp simd reduction(+:acc1,acc2)
        for (int j=indptr[i]; j<indptr[i+1]; j++) {
            float *b = buf + indices[j]*nh;
            float w = weights[j];
            int roll_t = nh + t - idelays[j];
            acc1 += w * b[(roll_t+0) & nhm];
            acc2 += w * b[(roll_t+1) & nhm];
        }
        out1[i] = acc1;
        out2[i] = acc2;
    }
}

// variant which updates the buf with current state
void delays2_upbuf(int nv, int nh, int t,
             float *out1, float *out2,
             float *buf, float *weights, int *idelays, int *indices, int *indptr,
             float *x)
{
    // nh is power of two, so x&(nh-1) is faster way to compute x%nh
    int nhm = nh - 1;
    #pragma omp parallel for
    for (int i=0; i<nv; i++)
    {
        // update buffer
        buf[i*nh + ((nh + t) & nhm)] = x[i];
        // compute coupling terms for both Heun stages
        float acc1 = 0.0f, acc2 = 0.0f;
        #pragma omp simd reduction(+:acc1,acc2)
        for (int j=indptr[i]; j<indptr[i+1]; j++) {
            float *b = buf + indices[j]*nh;
            float w = weights[j];
            int roll_t = nh + t - idelays[j];
            acc1 += w * b[(roll_t+0) & nhm];
            acc2 += w * b[(roll_t+1) & nhm];
        }
        out1[i] = acc1;
        out2[i] = acc2;
    }
}
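For context, a minimal sketch of how delays2_upbuf might be driven from a time-stepping loop; the outer function and heun_step are hypothetical, not part of the prototyped code, and assume the CSR arrays and idelays were prepared elsewhere with nh a power of two larger than the maximum delay:

// sketch of a driver loop (run and heun_step are illustrative names only)
void run(int nv, int nh, int nt,
         float *buf, float *weights, int *idelays, int *indices, int *indptr,
         float *x, float *c1, float *c2)
{
    for (int t = 0; t < nt; t++)
    {
        // write the current state into the ring buffer and compute the
        // delayed coupling terms for both Heun stages in one pass
        delays2_upbuf(nv, nh, t, c1, c2, buf, weights, idelays, indices, indptr, x);
        // hypothetical model update using c1 (predictor) and c2 (corrector)
        heun_step(nv, x, c1, c2);
    }
}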
maedoc commented Feb 16, 2024

These kernels don't operate on batches, which would be even better/faster.
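A hedged sketch of what a batched variant could look like, assuming a layout where the nb batch elements of each (node, time) sample are contiguous in buf and out; this is only an illustration, not the prototyped code:

// possible batched layout: buf[(node*nh + time)*nb + batch], out[node*nb + batch];
// the innermost batch loop is contiguous and vectorizes well
void delays2_batched(int nb, int nv, int nh, int t,
                     float *out1, float *out2,
                     float *buf, float *weights, int *idelays, int *indices, int *indptr)
{
    int nhm = nh - 1;
    #pragma omp parallel for
    for (int i = 0; i < nv; i++)
    {
        for (int b = 0; b < nb; b++)
            out1[i*nb + b] = out2[i*nb + b] = 0.0f;
        for (int j = indptr[i]; j < indptr[i+1]; j++)
        {
            float w = weights[j];
            int roll_t = nh + t - idelays[j];
            // batch-contiguous slices at the two delayed time points
            float *b1 = buf + (indices[j]*nh + ((roll_t+0) & nhm)) * nb;
            float *b2 = buf + (indices[j]*nh + ((roll_t+1) & nhm)) * nb;
            #pragma omp simd
            for (int b = 0; b < nb; b++)
            {
                out1[i*nb + b] += w * b1[b];
                out2[i*nb + b] += w * b2[b];
            }
        }
    }
}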
