Parallelize some small computations for invstd calculation
Raragyay committed Nov 11, 2023
1 parent 3f2348e commit 7933cf8
Showing 4 changed files with 49 additions and 34 deletions.
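
The recurring change here swaps "core 0 computes all C channels" loops for a round-robin split across the compute cores. A minimal standalone sketch of that pattern, lifted out of the snitch-specific code (the helper name and signature are illustrative, not part of the diff):

    #include <math.h>
    #include <stdint.h>

    // Each core takes channels compute_id, compute_id + num_compute_cores, ...
    // instead of core 0 computing every channel serially.
    static void compute_invstd(double *invstd, const double *running_var,
                               uint32_t C, double eps, uint32_t compute_id,
                               uint32_t num_compute_cores) {
        for (uint32_t c = compute_id; c < C; c += num_compute_cores) {
            invstd[c] = 1.0 / sqrt(running_var[c] + eps);
        }
    }
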
1 change: 1 addition & 0 deletions sw/dnn/batchnorm/data/datagen.py
@@ -215,6 +215,7 @@ def emit_header(**kwargs):
data_str += [format_array_declaration(ctype, grad_ofmap_uid, grad_ofmap.shape)]
data_str += [format_array_declaration(ctype, grad_weight_uid, grad_weight.shape)]
data_str += [format_array_declaration(ctype, grad_bias_uid, grad_bias.shape)]
# data_str += [format_array_declaration(ctype, "temp", (8,ci))]
# Layer struct
data_str += [format_struct_definition("batchnorm_layer_t", "layer", layer_cfg)]
data_str += [
2 changes: 1 addition & 1 deletion sw/dnn/batchnorm/data/params.hjson
@@ -5,7 +5,7 @@
{
input_dim: {
channels: 8
height: 4,
height: 4
width: 4
}
eps: 1e-5
71 changes: 39 additions & 32 deletions sw/dnn/batchnorm/src/batchnorm.h
@@ -2,8 +2,9 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#include <math.h>
#include "printf.h"
#include "snrt.h"

typedef struct {
uint32_t CI;
uint32_t IH;
@@ -386,39 +387,45 @@ static inline void batchnorm_backwards(batchnorm_backward_layer_t *l) {
ptr += grad_ifmap_len;

if (snrt_is_dm_core()) {
// might need to tile this later
snrt_dma_start_1d(grad_ofmap_scratch, l->grad_ofmap,
grad_ofmap_len * sizeof(double));
snrt_dma_wait_all();
snrt_cluster_hw_barrier();
snrt_cluster_hw_barrier();
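        // The first barrier releases the compute cores once grad_ofmap is in
        // scratch; the second appears to pair with the compute cores' barrier
        // after the first reduction stage, keeping the DM core in lockstep
        // while it has no work to do.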
} else if (snrt_is_compute_core()) {
if (compute_id == 0) {
for (uint32_t channel = 0; channel < C; ++channel) {
invstd_scratch[channel] =
1 / sqrt(l->running_var[channel] + eps);
}
uint32_t start_invstd_comp = snrt_mcycle();
for (uint32_t channel = compute_id; channel < C;
channel += num_compute_cores) {
invstd_scratch[channel] = 1 / sqrt(l->running_var[channel] + eps);
}
uint32_t end_invstd_comp = snrt_mcycle();
// wait for grad_ofmap to be loaded in
snrt_fpu_fence();
snrt_cluster_hw_barrier();
// reduce over [num_points, C] to [num_threads, C] by splitting over
// num_points
// assumes divisibility for now
// reduce from [num_points, C] to [num_threads, C] by splitting over
// num_points
// Read from ofmap.
uint32_t start_grad_bias_weight_reduction_1 = snrt_mcycle();
for (uint32_t i = compute_id; i < num_points; i += num_compute_cores) {
for (uint32_t channel = 0; channel < C; ++channel) {
double dy = grad_ofmap_scratch[i * C + channel];
double x = l->ifmap[i * C + channel];
double mean = l->running_mean[channel];
double dot_res = dy * (x - mean);
// currently accessing channel 1, 2, 3, 4, 1, 2, 3, 4
grad_bias_scratch[compute_id * C + channel] +=
grad_ofmap_scratch[i * C + channel];
grad_weight_scratch[compute_id * C + channel] +=
dot_res * invstd_scratch[channel];
}
}
uint32_t end_grad_bias_weight_reduction_1 = snrt_mcycle();
snrt_fpu_fence();
snrt_cluster_hw_barrier();
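        // After this barrier, row compute_id of each scratch buffer holds this
        // core's partial sums over the points it visited:
        //   grad_bias_scratch[compute_id * C + c]   = sum_i dy[i][c]
        //   grad_weight_scratch[compute_id * C + c] =
        //       sum_i dy[i][c] * (x[i][c] - mean[c]) * invstd[c]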

// reduce over [num_threads, C] to [C] by splitting over C
// reduce from [num_threads, C] to [C] by splitting over C
// just reduce back into the first buffer.
uint32_t start_grad_bias_weight_reduction_2 = snrt_mcycle();
for (uint32_t channel = compute_id; channel < C;
channel += num_compute_cores) {
register volatile double grad_bias_sum = 0;
@@ -437,19 +444,18 @@ static inline void batchnorm_backwards(batchnorm_backward_layer_t *l) {
"frep.o %[n_frep], 2, 0, 0 \n"
"fadd.d %[bias_sum], ft0, %[bias_sum] \n"
"fadd.d %[weight_sum], ft1, %[weight_sum] \n"
: [bias_sum] "+fr"(grad_bias_sum),[weight_sum] "+fr"(grad_weight_sum)
: [bias_sum] "+fr"(grad_bias_sum), [weight_sum] "+fr"(
grad_weight_sum)
: [n_frep] "r"(num_compute_cores -
1) // we repeat n_frep+1 times
: "ft0", "ft1", "ft2");
snrt_fpu_fence();
snrt_ssr_disable();
// for (uint32_t core_id = 0; core_id < num_compute_cores;
// ++core_id) {
// grad_bias_sum += grad_bias_scratch[core_id * C + channel];
// }
grad_bias_scratch[0 * C + channel] = grad_bias_sum;
grad_weight_scratch[0 * C + channel] = grad_weight_sum;
}

uint32_t end_grad_bias_weight_reduction_2 = snrt_mcycle();
}
snrt_cluster_hw_barrier();

@@ -465,33 +471,32 @@ static inline void batchnorm_backwards(batchnorm_backward_layer_t *l) {
grad_ifmap_len * sizeof(double));
snrt_dma_wait_all();
} else if (snrt_is_compute_core()) {
if (compute_id == 0) { // can try parallelizing afterwards
for (uint32_t channel = 0; channel < C; ++channel) {
invstd_scratch[channel] *=
l->weight[channel]; // can dma weight, can also frep
}
uint32_t start_invstd_scratch_inplace_augmentation = snrt_mcycle();
for (uint32_t channel = compute_id; channel < C;
channel += num_compute_cores) {
invstd_scratch[channel] *=
l->weight[channel]; // can dma weight, can also frep
}
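        // Folding weight[c] into invstd_scratch[c] lets the streamed kernel
        // below produce grad_ifmap[i][c] = dy[i][c] * invstd[c] * weight[c]
        // with what appears to be a single per-element multiply against the
        // DM2 stream.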
uint32_t end_invstd_scratch_inplace_augmentation = snrt_mcycle();
snrt_cluster_hw_barrier();
snrt_fpu_fence();
uint32_t start_ifmap_grad_computation = snrt_mcycle();
uint32_t num_points_work_for_core =
get_core_num_work_items(num_points, num_compute_cores, compute_id);

snrt_ssr_loop_2d(
SNRT_SSR_DM0,
num_points_work_for_core /*check modulo for non-multiples?*/, C,
num_compute_cores * C * sizeof(double), sizeof(double));
snrt_ssr_loop_2d(
SNRT_SSR_DM1,
num_points_work_for_core /*check modulo for non-multiples?*/, C,
num_compute_cores * C * sizeof(double), sizeof(double));
snrt_ssr_loop_2d(SNRT_SSR_DM0, num_points_work_for_core, C,
num_compute_cores * C * sizeof(double),
sizeof(double));
snrt_ssr_loop_2d(SNRT_SSR_DM1, num_points_work_for_core, C,
num_compute_cores * C * sizeof(double),
sizeof(double));
// repeatedly read from invstd_scratch, treating it as a 2d array with
// row-stride 0
snrt_ssr_loop_2d(SNRT_SSR_DM2, num_points_work_for_core, C, 0,
sizeof(double));
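        // Stream layout: one row per point assigned to this core, C contiguous
        // doubles wide; the row stride of num_compute_cores * C doubles skips
        // over the points handled by the other cores. DM2 keeps a row stride
        // of 0 so the same C invstd factors are re-read for every point.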

snrt_ssr_read(SNRT_SSR_DM0, SNRT_SSR_2D,
&grad_ofmap_scratch[compute_id * C]);
// printf("Core %d, grad_ofmap first value is %f\n", compute_id,
// grad_ofmap_scratch[compute_id * C]);
snrt_ssr_write(SNRT_SSR_DM1, SNRT_SSR_2D,
&grad_ifmap_scratch[compute_id * C]);
snrt_ssr_read(SNRT_SSR_DM2, SNRT_SSR_2D, invstd_scratch);
Expand All @@ -505,9 +510,10 @@ static inline void batchnorm_backwards(batchnorm_backward_layer_t *l) {
: "ft0", "ft1", "ft2");

snrt_fpu_fence();
// wait for the SSR writes to grad_ifmap (DM1) to finish
__builtin_ssr_barrier(SNRT_SSR_DM1);
snrt_ssr_disable();

uint32_t end_ifmap_grad_computation = snrt_mcycle();
snrt_cluster_hw_barrier();
}

@@ -522,5 +528,6 @@ static inline void batchnorm_backwards(batchnorm_backward_layer_t *l) {
// in training mode: big equation
// in eval mode: grad_ifmap[i][c] =
// dy[i][c]*1/sqrt(running_var[c]+eps)*weight[c]

snrt_cluster_hw_barrier();
}
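
For reference, the eval-mode equation in the comment above comes down to a few lines of plain C. A hypothetical scalar sketch (the function name and the flattened [num_points][C] layout are assumptions that mirror the kernel's indexing), not part of the commit:

    #include <math.h>
    #include <stdint.h>

    // Scalar reference for the eval-mode backward pass:
    // grad_ifmap[i][c] = dy[i][c] * 1/sqrt(running_var[c] + eps) * weight[c]
    static void batchnorm_backward_eval_reference(
        const double *grad_ofmap, const double *running_var,
        const double *weight, double *grad_ifmap, uint32_t num_points,
        uint32_t C, double eps) {
        for (uint32_t i = 0; i < num_points; ++i) {
            for (uint32_t c = 0; c < C; ++c) {
                double invstd = 1.0 / sqrt(running_var[c] + eps);
                grad_ifmap[i * C + c] =
                    grad_ofmap[i * C + c] * invstd * weight[c];
            }
        }
    }
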
9 changes: 8 additions & 1 deletion sw/dnn/batchnorm/verify_backwards.py
@@ -24,7 +24,7 @@
)
import pickle

ERR_THRESHOLD = 0.001
ERR_THRESHOLD = 1e-7

PRECISION_T = {8: "64", 4: "32", 2: "16", 1: "8"}

@@ -125,6 +125,13 @@ def main():
bytes_to_float(raw_results["grad_bias"], prec), dtype=NUMPY_T[prec]
).reshape((CI,))
)
# temp = torch.from_numpy(
# np.array(
# bytes_to_float(raw_results["temp"], prec), dtype=NUMPY_T[prec]
# ).reshape((8,CI))
# )
# print(temp)


# convert from NHWC to NCHW format
ifmap = ifmap.permute(0, 3, 1, 2)
