Skip to content

Commit

Permalink
Post DASUM and notebook.
Browse files Browse the repository at this point in the history
  • Loading branch information
sigilante committed Feb 23, 2024
1 parent 8807a97 commit 35f6333
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 31 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ MUNIT_OBJ = $(MUNIT_SRC:.c=.o)
BLAS_SRC_DIR = ./src/blas/level1
BLAS_SRCS = \
$(BLAS_SRC_DIR)/sasum.c \
$(BLAS_SRC_DIR)/dasum.c \
$(BLAS_SRC_DIR)/qasum.c \
$(BLAS_SRC_DIR)/saxpy.c \
$(BLAS_SRC_DIR)/daxpy.c \
Expand Down
231 changes: 231 additions & 0 deletions adapt-type.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 47,
"id": "998239f8",
"metadata": {},
"outputs": [],
"source": [
"src = './tests/blas/level1/test_sasum.c'\n",
"# src = './src/blas/level1/sasum.c'\n",
"with open(src, 'r') as srcfile:\n",
" srccode = srcfile.read()"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "f1f8f7d3",
"metadata": {},
"outputs": [],
"source": [
"kind = 'd'\n",
"fn = 'sasum'"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "e6b51acf",
"metadata": {},
"outputs": [],
"source": [
"types = {'h': 16, 's': 32, 'd': 64, 'q': 128}\n",
"types_ = {'d': 'double', 'h':'uint16_t'}\n",
"values = {\n",
" 'h': {\n",
" '-5.0f' : '0xc500 ',\n",
" '-4.0f' : '0xc400 ',\n",
" '-3.0f' : '0xc200 ',\n",
" '-2.0f' : '0xc000 ',\n",
" '-1.0f' : '0xbc00 ',\n",
" '0.0f' : '0x0 ',\n",
" '1.0f' : '0x3c00 ',\n",
" '2.0f' : '0x4000 ',\n",
" '3.0f' : '0x4200 ',\n",
" '4.0f' : '0x4400 ',\n",
" '5.0f' : '0x4500 ',\n",
" '10.0f' : '0x4900 ',\n",
" '12.0f' : '0x4a00 ',\n",
" '16.0f' : '0x4c00 ',\n",
" '20.0f' : '0x4d00 ',\n",
" '24.0f' : '0x4e00 ',\n",
" '30.0f' : '0x4f80 ',\n",
" '32.0f' : '0x5000 ',\n",
" '36.0f' : '0x5080 ',\n",
" '40.0f' : '0x5100 ',\n",
" '48.0f' : '0x5200 ',\n",
" '50.0f' : '0x5240 ',\n",
" '60.0f' : '0x5380',\n",
" '64.0f' : '0x5400'\n",
" },\n",
" 'd': {\n",
" '-5.0f' : '-5.0',\n",
" '-4.0f' : '-4.0',\n",
" '-3.0f' : '-3.0',\n",
" '-2.0f' : '-2.0',\n",
" '-1.0f' : '-1.0',\n",
" '0.0f' : '0.0',\n",
" '1.0f' : '1.0',\n",
" '2.0f' : '2.0',\n",
" '3.0f' : '3.0',\n",
" '4.0f' : '4.0',\n",
" '5.0f' : '5.0',\n",
" '10.0f' : '10.0',\n",
" '12.0f' : '12.0',\n",
" '16.0f' : '16.0',\n",
" '20.0f' : '20.0',\n",
" '24.0f' : '24.0',\n",
" '30.0f' : '30.0',\n",
" '32.0f' : '32.0',\n",
" '36.0f' : '36.0',\n",
" '40.0f' : '40.0',\n",
" '48.0f' : '48.0',\n",
" '50.0f' : '50.0',\n",
" '60.0f' : '60.0',\n",
" '64.0f' : '64.0' \n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "3ea54287",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-5.0f\n",
"-4.0f\n",
"-3.0f\n",
"-2.0f\n",
"-1.0f\n",
"0.0f\n",
"1.0f\n",
"2.0f\n",
"3.0f\n",
"4.0f\n",
"5.0f\n",
"10.0f\n",
"12.0f\n",
"16.0f\n",
"20.0f\n",
"24.0f\n",
"30.0f\n",
"32.0f\n",
"36.0f\n",
"40.0f\n",
"48.0f\n",
"50.0f\n",
"60.0f\n",
"64.0f\n"
]
}
],
"source": [
"trgcode = srccode.replace('32', repr(types[kind]))\n",
"trgcode = trgcode.replace(fn, kind+fn[1:])\n",
"trgcode = trgcode.replace('svec', kind+'vec')\n",
"trgcode = trgcode.replace('stemp', kind+'temp')\n",
"trgcode = trgcode.replace('S', kind.upper())\n",
"trgcode = trgcode.replace(kind.upper()+'B', 'SB')\n",
"trgcode = trgcode.replace('SX', kind.upper()+'X')\n",
"trgcode = trgcode.replace('SX', kind.upper()+'Y')\n",
"trgcode = trgcode.replace('(float)', '('+types_[kind]+')')\n",
"trgcode = trgcode.replace('float[]', types_[kind]+'[]')\n",
"trgcode = trgcode.replace('nan_unify_s', 'nan_unify_'+kind)\n",
"for key in values[kind]:\n",
" print(key)\n",
" trgcode = trgcode.replace(key, values[kind][key])"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "bfd6abee",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"#include \"test.h\"\n",
"\n",
"MunitResult test_dasum_0(const MunitParameter params[],\n",
" void* user_data_or_fixture) {\n",
" const uint64_t N = 2;\n",
" float64_t* DX = dvec((double[]){0.0, 0.0}, N);\n",
"\n",
" float64_t D = (float64_t) dasum(N, (float64_t*)DX, 1);\n",
" float64_t R = (float64_t){0.0};\n",
"\n",
" assert_int(D.v, ==, R.v);\n",
" \n",
" return MUNIT_OK;\n",
"}\n",
"\n",
"MunitResult test_dasum_12345(const MunitParameter params[],\n",
" void* user_data_or_fixture) {\n",
" const uint64_t N = 5;\n",
" float64_t* DX = dvec((double[]){1.0, -2.0, 3.0, -4.0, 5.0}, N);\n",
"\n",
" float64_t D = (float64_t) dasum(N, (float64_t*)DX, 1);\n",
" float64_t R = {*(uint64_t*)&(double){15.0}};\n",
" \n",
" assert_int(D.v, ==, R.v);\n",
" \n",
" return MUNIT_OK;\n",
"}\n",
"\n",
"// TODO test stride\n",
"\n"
]
}
],
"source": [
"print(trgcode)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41761065",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0dac712a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
39 changes: 8 additions & 31 deletions src/blas/level1/dasum.c
Original file line number Diff line number Diff line change
@@ -1,34 +1,11 @@
#include "softblas.h"

float64_t dasum(uint64_t N, const float64_t *DX, uint64_t incX) {
float64_t dasum = SB_REAL64_ZERO;
float64_t htemp = SB_REAL64_ZERO;

if (N <= 0 || incX <= 0) {
return dasum;
}

if (incX == 1) {
uint64_t m = N % 6;
if (m != 0) {
for (uint64_t i = 0; i < M; i++) {
htemp = f64_add(htemp, f64_abs(DX[i]));
}
if (N < 6) {
dasum = htemp;
return dasum;
}
}
uint64_t mp1 = m + 1;
for (uint64_t i = mp1; i < N; i += 6) {
htemp = f64_add(htemp, f64_add(f64_add(f64_add(f64_add(f64_add(f64_abs(DX[i]), f64_abs(DX[i + 1])), f64_abs(DX[i + 2])), f64_abs(DX[i + 3])), f64_abs(DX[i + 4])), f64_abs(DX[i + 5])));
}
} else {
uint64_t NincX = N * incX;
for (uint64_t i = 0; i < NincX; i += incX) {
htemp = f64_add(htemp, f64_abs(DX[i]));
}
}
dasum = htemp;
return dasum;
}
float64_t dtemp = { SB_REAL64_ZERO };

for (uint64_t i = 0; i < N; i++) {
dtemp = f64_add(dtemp, f64_abs(DX[i*incX]));
}

return nan_unify_d(dtemp);
}
4 changes: 4 additions & 0 deletions tests/blas/include/test.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ MunitResult test_saxpy_stride(const MunitParameter params[],
MunitResult test_saxpy_neg_stride(const MunitParameter params[],
void* user_data_or_fixture);

MunitResult test_dasum_0(const MunitParameter params[],
void* user_data_or_fixture);
MunitResult test_dasum_12345(const MunitParameter params[],
void* user_data_or_fixture);
MunitResult test_daxpy_0(const MunitParameter params[],
void* user_data_or_fixture);
MunitResult test_daxpy_sum(const MunitParameter params[],
Expand Down
29 changes: 29 additions & 0 deletions tests/blas/level1/test_dasum.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "test.h"

MunitResult test_dasum_0(const MunitParameter params[],
void* user_data_or_fixture) {
const uint64_t N = 2;
float64_t* DX = dvec((double[]){0.0, 0.0}, N);

float64_t D = (float64_t) dasum(N, (float64_t*)DX, 1);
float64_t R = (float64_t){0.0};

assert_int(D.v, ==, R.v);

return MUNIT_OK;
}

MunitResult test_dasum_12345(const MunitParameter params[],
void* user_data_or_fixture) {
const uint64_t N = 5;
float64_t* DX = dvec((double[]){1.0, -2.0, 3.0, -4.0, 5.0}, N);

float64_t D = (float64_t) dasum(N, (float64_t*)DX, 1);
float64_t R = {*(uint64_t*)&(double){15.0}};

assert_int(D.v, ==, R.v);

return MUNIT_OK;
}

// TODO test stride
2 changes: 2 additions & 0 deletions tests/test_all.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ int main(int argc, char* argv[MUNIT_ARRAY_PARAM(argc + 1)]) {
{"/test_saxpy_sum", test_saxpy_sum, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/test_saxpy_stride", test_saxpy_stride, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/test_saxpy_neg_stride", test_saxpy_neg_stride, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/test_dasum_0", test_dasum_0, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/test_dasum_12345", test_dasum_12345, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/test_daxpy_0", test_daxpy_0, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/test_daxpy_sum", test_daxpy_sum, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/test_daxpy_stride", test_daxpy_stride, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
Expand Down

0 comments on commit 35f6333

Please sign in to comment.