Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster packing #1172

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Next Next commit
Initial commit for sped-up packing function in finn utils
bwintermann committed Aug 26, 2024
commit a5cb912cd2a3b437d7df59e690f979f645af51f5
37 changes: 37 additions & 0 deletions src/finn/util/data_packing.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import binascii
from math import ceil
import numpy as np
import os
import sys
@@ -35,6 +36,31 @@
from qonnx.util.basic import roundup_to_integer_multiple


# Build (if necessary) and load the C implementation of the packing routines.
# This runs at import time so that the fast path is available to every user
# of this module.
import ctypes as ct
from numpy.ctypeslib import ndpointer
import os
import subprocess

# NOTE(review): "qonnxdtype" is not otherwise referenced in the visible code;
# confirm that this module path is correct.
from qonnxdtype import DataTypeMeta

# Setup: both paths are absolute and located next to this module, so the
# build does not depend on the current working directory of the process.
fastpack_source = os.path.join(os.path.dirname(__file__), "fast_pack.c")
fastpack_lib = os.path.join(os.path.dirname(__file__), "fastpack.so")
if not os.path.isfile(fastpack_source):
    raise FileNotFoundError("Could not find fast_pack.c in the utils/ dir of FINN")

# Compile: only when the shared object is missing or older than its source.
# The argument list is passed without a shell, and a non-zero compiler exit
# status is reported together with the compiler's own diagnostics.
_fastpack_stale = (
    not os.path.isfile(fastpack_lib)
    or os.path.getmtime(fastpack_lib) < os.path.getmtime(fastpack_source)
)
if _fastpack_stale:
    try:
        subprocess.run(
            ["gcc", "-shared", "-O3", "-fpic", fastpack_source, "-o", fastpack_lib],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as err:
        raise RuntimeError(
            "Compilation of fast_pack.c failed:\n" + (err.stderr or "")
        ) from err
    except OSError as err:
        raise RuntimeError("Could not run gcc to compile fast_pack.c") from err
if not os.path.isfile(fastpack_lib):
    raise RuntimeError("Could not find fastpack.so. Did compilation fail?")

# Load: declare the C signature so ctypes checks and converts the arguments.
# The array argument must be a C-contiguous float32 NumPy array.
fastpack = ct.CDLL(fastpack_lib)
fastpack_floatarray = ndpointer(ct.c_float, flags="C_CONTIGUOUS")
fastpack.array_to_hexstring_binary.argtypes = (
    fastpack_floatarray,
    ct.c_uint,
    ct.c_uint,
    ct.c_char_p,
)
fastpack.array_to_hexstring_binary.restype = ct.c_bool




def array2hexstring(array, dtype, pad_to_nbits, prefix="0x", reverse=False):
"""
Pack given one-dimensional NumPy array with FINN DataType dtype into a hex
@@ -71,6 +97,17 @@ def array2hexstring(array, dtype, pad_to_nbits, prefix="0x", reverse=False):
# reverse prior to packing, if desired
if reverse:
array = np.flip(array, -1)


# Check if the fast way can be taken
# TODO: Expand this to cover more cases
if dtype == DataType["BINARY"] and prefix == "0x":
output_string = ct.create_string_buffer(ceil(pad_to_nbits / 4) + 4)
success = fastpack.array_to_hexstring_binary(array, array.size, pad_to_nbits, output_string)
assert success, f"Could not convert array {array} with datatype {dtype} to hexstring!"
return output_string


lineval = BitArray(length=0)
bw = dtype.bitwidth()
# special handling for fixed point: rescale, then pack as integers
76 changes: 76 additions & 0 deletions src/finn/util/fast_pack.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <stdint.h>


/***
 * Packs an array of BINARY values (each float is exactly 0.0 or 1.0) into a
 * NUL-terminated, "0x"-prefixed, lower-case hex string. The first array
 * element becomes the most significant bit; the result is left-padded with
 * zero bits up to padded_bits.
 *
 * values:      input array with `elements` entries (may be empty)
 * elements:    number of entries in `values`
 * padded_bits: total width of the result in bits; must be a multiple of 4 and
 *              at least `elements` rounded up to the next multiple of 4
 * out:         caller-provided buffer of at least 2 + padded_bits / 4 + 1 bytes
 *
 * Returns true on success. Returns false if an argument is NULL, if
 * padded_bits is too small or not a multiple of 4, or if any value is neither
 * 0 nor 1. When a bad value is found, `out` is set to the empty string.
 */
bool array_to_hexstring_binary(float* values, unsigned int elements, unsigned int padded_bits, char* out) {
    static const char hex_digits[] = "0123456789abcdef";

    if (out == NULL || (values == NULL && elements > 0)) {
        return false;
    }

    // Number of hex digits needed for the payload (elements rounded up to a
    // multiple of 4 bits). size_t arithmetic avoids unsigned int wrap-around.
    size_t data_digits = ((size_t) elements + 3) / 4;
    size_t min_bits = data_digits * 4;

    // Padded bits must be at least min_bits and divisible by 4 for hex repr
    if (min_bits > (size_t) padded_bits || padded_bits % 4 != 0) {
        return false;
    }

    // Prefix followed by the zero padding digits, written in a single pass
    size_t pad_digits = (size_t) padded_bits / 4 - data_digits;
    out[0] = '0';
    out[1] = 'x';
    memset(out + 2, '0', pad_digits);
    char* cursor = out + 2 + pad_digits;

    // The leading digit holds the bits that do not fill a whole nibble
    // (elements % 4 of them), or a full nibble if elements divides evenly.
    size_t leading_bits = (elements % 4 != 0) ? (size_t) (elements % 4) : 4;
    size_t index = 0;
    for (size_t digit = 0; digit < data_digits; digit++) {
        size_t bits_in_digit = (digit == 0) ? leading_bits : 4;
        unsigned int nibble = 0;
        for (size_t bit = 0; bit < bits_in_digit; bit++) {
            float value = values[index++];
            // Anything other than 0 or 1 (including NaN) is not BINARY
            if (value != 0.0f && value != 1.0f) {
                out[0] = '\0';
                return false;
            }
            nibble = (nibble << 1) | (value != 0.0f ? 1u : 0u);
        }
        *cursor++ = hex_digits[nibble];
    }

    // Terminate directly after the last digit
    *cursor = '\0';
    return true;
}