diff --git a/fetch-repos.sh b/fetch-repos.sh index fbed3f2a5..fda7cb306 100755 --- a/fetch-repos.sh +++ b/fetch-repos.sh @@ -39,6 +39,7 @@ XIL_BDF_COMMIT="8cf4bb674a919ac34e3d99d8d71a9e60af93d14e" RFSOC4x2_BDF_COMMIT="13fb6f6c02c7dfd7e4b336b18b959ad5115db696" KV260_BDF_COMMIT="98e0d3efc901f0b974006bc4370c2a7ad8856c79" EXP_BOARD_FILES_MD5="226ca927a16ea4ce579f1332675e9e9a" +ATTENTION_HLSLIB_COMMIT="afc9720f10e551e1f734e137b21bb6d0a8342177" QONNX_URL="https://github.com/iksnagreb/qonnx.git" FINN_EXP_URL="https://github.com/Xilinx/finn-experimental.git" @@ -51,6 +52,7 @@ AVNET_BDF_URL="https://github.com/Avnet/bdf.git" XIL_BDF_URL="https://github.com/Xilinx/XilinxBoardStore.git" RFSOC4x2_BDF_URL="https://github.com/RealDigitalOrg/RFSoC4x2-BSP.git" KV260_BDF_URL="https://github.com/Xilinx/XilinxBoardStore.git" +ATTENTION_HLSLIB_URL="https://github.com/iksnagreb/attention-hlslib.git" QONNX_DIR="qonnx" FINN_EXP_DIR="finn-experimental" @@ -63,6 +65,7 @@ AVNET_BDF_DIR="avnet-bdf" XIL_BDF_DIR="xil-bdf" RFSOC4x2_BDF_DIR="rfsoc4x2-bdf" KV260_SOM_BDF_DIR="kv260-som-bdf" +ATTENTION_HLSLIB_DIR="attention-hlslib" # absolute path to this script, e.g. /home/user/bin/foo.sh SCRIPT=$(readlink -f "$0") @@ -126,6 +129,7 @@ fetch_repo $AVNET_BDF_URL $AVNET_BDF_COMMIT $AVNET_BDF_DIR fetch_repo $XIL_BDF_URL $XIL_BDF_COMMIT $XIL_BDF_DIR fetch_repo $RFSOC4x2_BDF_URL $RFSOC4x2_BDF_COMMIT $RFSOC4x2_BDF_DIR fetch_repo $KV260_BDF_URL $KV260_BDF_COMMIT $KV260_SOM_BDF_DIR +fetch_repo $ATTENTION_HLSLIB_URL $ATTENTION_HLSLIB_COMMIT $ATTENTION_HLSLIB_DIR # Can skip downloading of board files entirely if desired if [ "$FINN_SKIP_BOARD_FILES" = "1" ]; then diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py index 91886fd3a..c6148425a 100644 --- a/src/finn/custom_op/fpgadataflow/__init__.py +++ b/src/finn/custom_op/fpgadataflow/__init__.py @@ -61,8 +61,9 @@ def register_custom_op(cls): # Import the submodule containing the Unsqueeze operation import finn.custom_op.fpgadataflow.unsqueeze - from finn.custom_op.fpgadataflow.addstreams import AddStreams +from finn.custom_op.fpgadataflow.attention import ScaledDotProductAttention +from finn.custom_op.fpgadataflow.attention_heads import MergeMultiHeads, SplitMultiHeads from finn.custom_op.fpgadataflow.channelwise_op import ChannelwiseOp from finn.custom_op.fpgadataflow.concat import StreamingConcat from finn.custom_op.fpgadataflow.convolutioninputgenerator import ( @@ -77,6 +78,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.lookup import Lookup from finn.custom_op.fpgadataflow.matrixvectoractivation import MVAU from finn.custom_op.fpgadataflow.pool import Pool +from finn.custom_op.fpgadataflow.replicate_stream import ReplicateStream from finn.custom_op.fpgadataflow.split import StreamingSplit from finn.custom_op.fpgadataflow.streamingdataflowpartition import ( StreamingDataflowPartition, @@ -116,3 +118,7 @@ def register_custom_op(cls): custom_op["StreamingEltwise"] = StreamingEltwise custom_op["StreamingMaxPool"] = StreamingMaxPool custom_op["UpsampleNearestNeighbour"] = UpsampleNearestNeighbour +custom_op["ScaledDotProductAttention"] = ScaledDotProductAttention +custom_op["SplitMultiHeads"] = SplitMultiHeads +custom_op["MergeMultiHeads"] = MergeMultiHeads +custom_op["ReplicateStream"] = ReplicateStream diff --git a/src/finn/custom_op/fpgadataflow/attention.py b/src/finn/custom_op/fpgadataflow/attention.py new file mode 100644 index 000000000..7b1c3e7dd --- /dev/null +++ 
b/src/finn/custom_op/fpgadataflow/attention.py @@ -0,0 +1,771 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Python builtin math functions: math.ceil returns int, while np.ceil returns +# float +import math +# Numpy math and arrays +import numpy as np +# Python warning subsystem +import warnings + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType +# Multithreshold activations +from qonnx.custom_op.general.multithreshold import multithreshold +# Some utils for working with tensors in qonnx +from qonnx.util.basic import calculate_matvec_accumulator_range + +# Derive custom operators form the FINN base custom op +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp + + +# Softmax function on numpy arrays with overflow handling matching the HLS +# operator +def softmax(x, axis): + # For overflow handling, find the maximum value along axis and place ones at + # each occurrence + max_ones = (x == np.max(x, axis=axis, keepdims=True)).astype(np.float32) + # Count the occurrences of the maximum along the normalization axis + max_counts = np.sum(max_ones, axis=axis, keepdims=True) + # Exponential of the input + exp = np.exp(x - np.max(x, axis=axis)[:, np.newaxis]) + # Compute the total along axis + total = np.sum(exp, axis=axis, keepdims=True) + # Detect overflow of the summation + overflow = np.isinf(total) + # Replace overflows by equal weight given to all instances of the maximum + # input value. For non overflow just compute normal softmax + return np.where(overflow, max_ones / max_counts, exp / total) + + +# Scaled Dot-Product Attention Custom Operator +# Note: Single head attention +class ScaledDotProductAttention(HWCustomOp): + # Initializes the operator given an onnx graph node + def __init__(self, onnx_node, **kwargs): + # Just forward all arguments to the init method of the CustomOp base + super().__init__(onnx_node, **kwargs) + + # Node attributes matching the HLS operator + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = HWCustomOp.get_nodeattr_types(self) + # Update attributes dictionary for new custom operator + attrs.update({ + # Embedding dimension of queries and keys + "QKDim": ("i", True, 0), + # Length of the query sequence + "QLen": ("i", True, 0), + # Embedding dimension of the values + "VDim": ("i", True, 0), + # Length of the key and value sequence + "KVLen": ("i", True, 0), + + # Folding along the embedding dimensions + "EmbFold": ("i", True, 0), + # Folding along the sequence dimensions + "SeqFold": ("i", True, 0), + + # Datatype of query matrix elements + "QType": ("s", True, ""), + # Datatype of key matrix elements + "KType": ("s", True, ""), + # Datatype of value matrix elements + "VType": ("s", True, ""), + # Datatype of mask matrix elements + "MType": ("s", False, "INT0"), + # Datatype of attention weights elements + "AType": ("s", False, "UINT32"), + # Datatype of output elements + "OType": ("s", True, ""), + + # Datatype of accumulator elements of the Query x Key multiplication + "AccQKMatMul": ("s", False, "UINT32"), + # Datatype of output elements of the Query x Key multiplication + "OutQKMatMul": ("s", False, "UINT32"), + # Activation function type of the Query x Key multiplication + "ActQKMatMul": ("s", False, "none", {"none", "thresholds"}), + # Output bias to be applied to the thresholding activation following + # the Query x Key multiplication + "BiasActQKMatMul": ("f", False, 
0.0), + + # Datatype of accumulator elements of the Attention x Value + # multiplication + "AccAVMatMul": ("s", False, "UINT32"), + # Datatype of output elements of the Attention x Value + # multiplication + "OutAVMatMul": ("s", False, "UINT32"), + # Activation function type of the Attention x Value multiplication + "ActAVMatMul": ("s", False, "none", {"none", "thresholds"}), + # Output bias to be applied to the thresholding activation following + # the Attention x Value multiplication + "BiasActAVMatMul": ("f", False, 0.0), + + # Scale factor preceding the softmax normalization to dequantize the + # input + "DequantSoftmax": ("f", False, 1.0), + # Datatype of softmax normalization before applying activation or + # type cast. This is called Acc to stick to the naming scheme of the + # MatMul operators before. + # Note: Currently this is ALWAYS floats + "AccASoftmax": ("s", False, "FLOAT32"), + # Activation function type of the softmax normalization of the + # attention weights + "ActASoftmax": ("s", False, "none", {"none", "thresholds"}), + # Output bias to be applied to the thresholding activation following + # the softmax normalization of the attention weights + "BiasActASoftmax": ("f", False, 0.0), + + # Mode used for providing the attention mask: There can be no mask, + # a mask sent as the fourth dynamic input, a mask provided as fourth + # constant input or a causal attention mask which is generated by + # the operator itself. + "mask_mode": ( + "s", True, "none", {"none", "input", "const", "causal"} + ), + + # Possible execution modes for simulating this node + # Note: Override to support python mode + "exec_mode": ( + "s", False, "python", {"", "rtlsim", "cppsim", "python"} + ), + + # FPGA resource type for memories/internal buffers of the operator + # Note: Currently only used for StreamTile buffers + "ram_style": ( + "s", False, "auto", {"auto", "block", "distributed", "ultra"} + ), + # FPGA resource type for memories of the thresholds parameters + # Note: Not yet used... 
+ "ram_style_thresholds": ( + "s", False, "auto", {"auto", "block", "distributed"} + ), + # FPGA resource type for memories of the attention mask if the + # mask_mode is "const" + "ram_style_mask": ( + "s", False, "auto", {"auto", "block", "distributed"} + ), + # FPGA resource type to implement the MAC operations of the two + # internal matmul operations + "mac_resource": ("s", False, "auto", {"auto", "lut", "dsp"}), + + # Input and output FIFO depths for multi-I/O nodes + # Note: Need to override here as there are three inputs + "inFIFODepths": ("ints", False, [2, 2, 2]), + "outFIFODepths": ("ints", False, [2]), + }) + # Return updated attribute dictionary + return attrs + + # Shape configuration of the operator + @property + def shapes(self): + # Note: This matches the order of definition above and the order of the + # HLS lib template as well + return (self.get_nodeattr("QKDim"), self.get_nodeattr("QLen"), + self.get_nodeattr("VDim"), self.get_nodeattr("KVLen")) + + # Folding configuration of the operator + @property + def folds(self): + # Note: This matches the order of definition above and the order of the + # HLS lib template as well + return self.get_nodeattr("EmbFold"), self.get_nodeattr("SeqFold") + + # Tests whether the given folding is a valid configuration with respect to + # the shape configuration + @property + def is_valid_folding(self): + # Get and unpack the shape attributes (except the q matrix length, which + # is never folded) + qkdim, _, vdim, kvlen = self.shapes + # Get and unpack the folding attributes + embfold, seqfold = self.folds + # All shapes must be multiples of their corresponding fold + return not ((qkdim % embfold) or (vdim % embfold) or (kvlen % seqfold)) + + # Returns an ONNX node that has the same shape inference behavior + def make_shape_compatible_op(self, model): + # Infer the output shape from the input shapes + o_shape = (self.get_nodeattr("QLen"), self.get_nodeattr("VDim")) + # Get the node wrapped by this custom op + node = self.onnx_node + # Get the shape of the input tensor for inferring the number of + # heads and correctly propagating shapes + shape = model.get_tensor_shape(node.input[0]) + # Determine the rank of the input tensor to support batched and + # non-batched inputs + rank = len(shape) + # Constant operation producing output of given shape + # Note: Rank == 3 allows for propagating yet unrolled multi-attention + # heads. 
+ return super().make_const_shape_op( + (shape[0], *o_shape) if (rank == 3) else o_shape + ) + + # Infers the output data types and updates the input datatypes of the node + def infer_node_datatype(self, model): + # ONNX graph node of the operator + node = self.onnx_node + + # Get input datatypes from model for query, key, value nodes in order + q_dtype = model.get_tensor_datatype(node.input[0]) + k_dtype = model.get_tensor_datatype(node.input[1]) + v_dtype = model.get_tensor_datatype(node.input[2]) + + # Test for changing query input datatype + if q_dtype != self.get_nodeattr("QType"): + # Issue a warning message + warnings.warn("QType changing for %s: %s -> %s " % ( + node.name, + str(self.get_nodeattr("QType")), + str(q_dtype), + )) + # Test for changing key input datatype + if k_dtype != self.get_nodeattr("KType"): + # Issue a warning message + warnings.warn("KType changing for %s: %s -> %s " % ( + node.name, + str(self.get_nodeattr("KType")), + str(k_dtype), + )) + # Test for changing value input datatype + if v_dtype != self.get_nodeattr("VType"): + # Issue a warning message + warnings.warn("VType changing for %s: %s -> %s " % ( + node.name, + str(self.get_nodeattr("VType")), + str(v_dtype), + )) + + # Update the node datatype attributes + self.set_nodeattr("QType", q_dtype.name) + self.set_nodeattr("KType", k_dtype.name) + self.set_nodeattr("VType", v_dtype.name) + + # Attention mask might be provided as an input as well + if self.get_nodeattr("mask_mode") == "input": + # Get the datatype attribute of the attention mask + # Note: Optional mask will be provided as the fourth input + mask_dtype = model.get_tensor_datatype(node.input[3]) + # Test for changing mask input datatype + if mask_dtype != self.get_nodeattr("MType"): + # Issue a warning message + warnings.warn("MType changing for %s: %s -> %s " % ( + node.name, + str(self.get_nodeattr("MType")), + str(mask_dtype), + )) + # Update the node datatype attribute of the attention mask + self.set_nodeattr("MType", mask_dtype.name) + + # Set the model output datatype + model.set_tensor_datatype( + node.output[0], DataType[self.get_nodeattr('OType')] + ) + + # Executes the attention operator in python mode simulation + def _execute_node_python(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + + # Read the input from the execution context and reshape to match the + # expected folding + q = context[node.input[0]].reshape(self.get_normal_input_shape(ind=0)) + k = context[node.input[1]].reshape(self.get_normal_input_shape(ind=1)) + v = context[node.input[2]].reshape(self.get_normal_input_shape(ind=2)) + + # Quantization activation function following the query and key + # multiplication + def act_qk_matmul(x): + # Only applies if this is specified as a thresholding activation + if self.get_nodeattr("ActQKMatMul") == "thresholds": + # Get the thresholds initializer by name from ordered list of + # optional inputs + thresholds = context[ + self.get_input_name_by_name("thresholds_qk_matmul") + ] + # Activation value, i.e., bias applied after thresholding + # activation + bias = self.get_nodeattr("BiasActQKMatMul") + # Applies thresholding activation in python to the input + return multithreshold(x, thresholds) + bias + # If not thresholds, assume identity function + return x + + # Quantization activation function following the softmax normalization + def act_a_softmax(x): + # Only applies if this is specified as a thresholding activation + if self.get_nodeattr("ActASoftmax") == 
"thresholds": + # Get the thresholds initializer by name from ordered list of + # optional inputs + thresholds = context[ + self.get_input_name_by_name("thresholds_a_softmax") + ] + # Activation value, i.e., bias applied after thresholding + # activation + bias = self.get_nodeattr("BiasActASoftmax") + # Applies thresholding activation in python to the input + return multithreshold(x, thresholds) + bias + # If not thresholds, assume identity function + return x + + # Quantization activation function following the attention and values + # multiplication + def act_av_matmul(x): + # Only applies if this is specified as a thresholding activation + if self.get_nodeattr("ActAVMatMul") == "thresholds": + # Get the thresholds initializer by name from ordered list of + # optional inputs + thresholds = context[ + self.get_input_name_by_name("thresholds_av_matmul") + ] + # Activation value, i.e., bias applied after thresholding + # activation + bias = self.get_nodeattr("BiasActAVMatMul") + # Applies thresholding activation in python to the input + return multithreshold(x, thresholds) + bias + # If not thresholds, assume identity function + return x + + # Scale used to dequantize the qk matrix before computing the softmax in + # floating point + dequant = self.get_nodeattr("DequantSoftmax") + + # 1. Queries and keys multiplication followed by quantizing activation + # function + qk = act_qk_matmul(np.matmul(q, k.T)) + + # Load or create the attention mask for mutually exclusive mask modes + + # There might be no attention mask + if self.get_nodeattr("mask_mode") == "none": + # No mask can be realized by adding zero, which does not change + # anything + mask = 0 + # There might eb a causal attention mask + elif self.get_nodeattr("mask_mode") == "causal": + # A causal mask does not need to be stored and can be generated on + # the fly + mask = np.triu(-np.inf * np.ones_like(qk), 1) + # There might be a constant initializer attention mask + elif self.get_nodeattr("mask_mode") == "const": + # Load the mask initializer from the execution context + mask = context[ + self.get_input_name_by_name("M") + ] + # The attention operator represents attention masks as binary masks, + # but the numpy simulation requires floats with 0 and -inf + mask = np.where(mask, -np.inf * np.ones_like(mask), 0) + # The attention mask might be streamed in as the third input + elif self.get_nodeattr("mask_mode") == "input": + # Load the mask input from the execution context + mask = context[ + self.get_input_name_by_name("M") + ] + # The attention operator represents attention masks as binary masks, + # but the numpy simulation requires floats with 0 and -inf + mask = np.where(mask, -np.inf * np.ones_like(mask), 0) + # All other mask modes are not supported + else: + raise NotImplementedError( + f"Mask Mode {self.get_nodeattr('mask_mode')} is not implemented" + ) + # Softmax-normalization of the attention weights followed by quantizing + # activation function + a = act_a_softmax( + # Note: Reshape after masking, as the mask might broadcast messing + # with the shape + softmax((dequant * qk + mask).reshape(qk.shape), axis=1) + ) + # 2. 
Attention weights and values matmul followed by quantization + # activation function + out = act_av_matmul(np.matmul(a, v)) + + # Insert the results into the execution context + context[self.onnx_node.output[0]] = out.reshape( + self.get_normal_output_shape(ind=0) + ) + + # Executes the attention operator in C++ mode simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # C++ Simulation needs to be implemented in HLS backend specialization + raise NotImplementedError( + f"exec_mode cppsim of {self.__class__.__name__} is not implemented!" + ) + + # Executes the attention operator in RTL mode simulation + def _execute_node_rtlsim(self, context, graph): # noqa: graph unused + # RTL Simulation needs to be implemented in backend specialization + raise NotImplementedError( + f"exec_mode rtlsim of {self.__class__.__name__} is not implemented!" + ) + + # Executes the attention operator in simulation (either python, c++ or rtl) + def execute_node(self, context, graph): + # Get the configured execution mode + mode = self.get_nodeattr("exec_mode") + # Lookup table mapping execution modes to implementing methods + exec_fns = { + "python": self._execute_node_python, + "cppsim": self._execute_node_cppsim, + "rtlsim": self._execute_node_python, # TODO: Revert to rtlsim + } + # Select and execute the function by mode string + exec_fns[mode](context, graph) + + # Optional node verification + def verify_node(self): + pass + + # Gets the datatype of input at index ind + def get_input_datatype(self, ind=0): + # Ordered list of names of allowed inputs + inputs = ["QType", "KType", "VType"] + + # If the attention mask is provided as input, it has a type as well + if self.get_nodeattr("mask_mode") == "input": + # The mask type is an attribute itself + inputs += ["MType"] + + # TODO: All the following types are probably never requested, they are + # implemented for the sake of completeness for now. If they are ever + # actually required, check whether the following defaults and dummies + # actually still make sense. 
+ + # If there is a thresholding activation for the first matmul, it will + # have a type as well + if self.get_nodeattr("ActQKMatMul") == "thresholds": + # The thresholds will always be of the accumulator type as the + # activation maps from AccQKMatMul to OutQKMatMul + inputs += ["AccQKMatMul"] + + # If there is a thresholding activation for the softmax normalization, + # it will have a type as well + if self.get_nodeattr("ActASoftmax") == "thresholds": + # While there is a dummy configurable attribute describing the + # threshold type of the softmax, these are currently always floats + inputs += ["AccASoftmax"] + + # If there is a thresholding activation for the second matmul, it will + # have a type as well + if self.get_nodeattr("ActAVMatMul") == "thresholds": + # The thresholds will always be of the accumulator type as the + # activation maps from AccAVMatMul to OutAVMatMul + inputs += ["AccAVMatMul"] + + # Look up datatype name in attributes and convert to DataType + return DataType[self.get_nodeattr(f"{inputs[ind]}")] + + # Gets the datatype of the output (at index ind, but there is just one) + def get_output_datatype(self, ind=0): + # Ordered list of names of allowed outputs + outputs = ["O"] + # Look up datatype name in attributes and convert to DataType + return DataType[self.get_nodeattr(f"{outputs[ind]}Type")] + + # Gets the shape of the input at index ind without folding + def get_normal_input_shape(self, ind=0): + # List shapes of inputs in order + inputs_shapes = [ + # Query input sequence + (self.get_nodeattr("QLen"), self.get_nodeattr("QKDim")), + # Key input sequence + (self.get_nodeattr("KVLen"), self.get_nodeattr("QKDim")), + # Value input sequence + (self.get_nodeattr("KVLen"), self.get_nodeattr("VDim")), + ] + + # If the attention mask is provided as input, it has a shape as well + if self.get_nodeattr("mask_mode") in {"input", "const"}: + # Mask shape is inferred from query and key sequence lengths + inputs_shapes += [ + (self.get_nodeattr("QLen"), self.get_nodeattr("KVLen")) + ] + + # TODO: All the following shapes are probably never requested, they are + # implemented for the sake of completeness for now. If they are ever + # actually required, remember to insert meaningful shapes. 
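For orientation, the unfolded geometry implied by the shape list above, written out with illustrative sizes (the concrete numbers are assumptions for this example, not defaults of the operator):

# Unfolded tensor geometry of the attention operator (assumed example sizes)
qkdim, qlen, kvlen, vdim = 64, 128, 256, 64
shapes = {
    "Q":    (qlen,  qkdim),   # query sequence
    "K":    (kvlen, qkdim),   # key sequence
    "V":    (kvlen, vdim),    # value sequence
    "mask": (qlen,  kvlen),   # optional attention mask
    "out":  (qlen,  vdim),    # attention output
}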
+ + # If there is a thresholding activation for the first matmul, these will + # be the next input index after the (optional) mask + if self.get_nodeattr("ActQKMatMul") == "thresholds": + # TODO: This is just a dummy shape + inputs_shapes += [(0, 0)] + + # If there is a thresholding activation for the softmax normalization, + # these will be the next (and last) input index after the (optional) + # second thresholds + if self.get_nodeattr("ActASoftmax") == "thresholds": + # TODO: This is just a dummy shape + inputs_shapes += [(0, 0)] + + # If there is a thresholding activation for the second matmul, these + # will be the next input index after the (optional) first thresholds + if self.get_nodeattr("ActAVMatMul") == "thresholds": + # TODO: This is just a dummy shape + inputs_shapes += [(0, 0)] + + # Get the shape by indexing into the ordered list of all inputs + return inputs_shapes[ind] + + # Gets the shape of the output at index ind (there is just one) without + # folding + def get_normal_output_shape(self, ind=0): # noqa, there is just one output + # The output shape is inferred from the length of the query sequence and + # the embedding dimension of the values + return tuple((self.get_nodeattr("QLen"), self.get_nodeattr("VDim"))) + + # Gets the shape of the attention weights at index ind (there is just one) + # without folding + def get_normal_attention_shape(self, ind=0): # noqa, there is just one + # The attention weights have shape covering both sequence dimensions + return tuple((self.get_nodeattr("QLen"), self.get_nodeattr("KVLen"))) + + # Gets the shape of the input at index ind with folding + def get_folded_input_shape(self, ind=0): + # Get the unfolded size of the input + ilen, idim = self.get_normal_input_shape(ind) + # Get the folding configuration specifying the amount of parallelism + embfold, seqfold = self.folds + + # Queries, keys and values are all folded similarly along the embedding + # dimension + if ind in (0, 1, 2): + # Note: Embedding dimension is always assumed to be the second + # dimension, any transpose is handled implicitly by the operator + return ilen, embfold, idim // embfold + + # If the mask is provided as input, it is folded along the second + # sequence dimension + if ind == 3 and self.get_nodeattr("mask_mode") in {"input", "const"}: + # Note: Both dimensions are sequence dimension, the second + # corresponds to the KVLen + return ilen, seqfold, idim // seqfold + + # If this point is reached, probably something went wrong + # TODO: Requesting the folded shape of thresholds will reach here. Can + # this actually happen? Probably it is indeed an error, there should be + # no reason to ask for the shape of the thresholds, just ask for the + # initializer and get its shape? Folding of the thresholds behaves + # differently and would require to actually keep track of mapping + # indices to optional inputs to correctly associate the folding + # dimensions. 
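Combining the folded shapes above with the stream-width helpers defined further below, a minimal sketch of the resulting query-stream geometry, assuming example values for QLen, QKDim, EmbFold and an 8-bit query type (all numbers are illustrative, not defaults):

# Illustrative folding arithmetic for the query stream (assumed values)
qlen, qkdim, embfold = 128, 64, 8        # QLen, QKDim, EmbFold -- example only
q_bits = 8                               # assumed bit width of QType

assert qkdim % embfold == 0              # required by is_valid_folding

normal_shape = (qlen, qkdim)                        # unfolded: (QLen, QKDim)
folded_shape = (qlen, embfold, qkdim // embfold)    # -> (128, 8, 8)
instream_width = (qkdim // embfold) * q_bits        # parallel elems * bits -> 64
num_transactions = qlen * embfold                   # stream reads per query matrix -> 1024

The same arithmetic applies to the key and value streams with KVLen in place of QLen, and to the mask stream with SeqFold in place of EmbFold.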
+ # TODO: This is just a dummy shape + return 0, 0, 0 + + # Gets the shape of the output at index ind (there is just one) with folding + def get_folded_output_shape(self, ind=0): # noqa, there is just one output + # Get the unfolded size of the output + olen, odim = self.get_normal_output_shape(ind) + # Get the folding configuration specifying the amount of parallelism + embfold, seqfold = self.folds + # The output is always folded along the embedding dimension, which is + # assumed to be the second dimension + return olen, embfold, odim // embfold + + # Gets the shape of the attention weights at index ind (there is just one) + # with folding + def get_folded_attention_shape(self, ind=0): # noqa, there is just one + # Get the unfolded size of the attention weights + alen, adim = self.get_normal_attention_shape(ind) + # Get the folding configuration specifying the amount of parallelism + embfold, seqfold = self.folds + # The attention weights are always folded along the sequence dimension, + # which is assumed to be the second dimension + return alen, seqfold, adim // seqfold + + # Widths of the input data stream of the input at index ind + def get_instream_width(self, ind=0): + # Get the number of bits used to represent the input + i_bits = self.get_input_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded input + _, _, elems = self.get_folded_input_shape(ind) + # Width of a stream receiving input elements in parallel + return elems * i_bits + + # Widths of the output data stream of the output at index ind + def get_outstream_width(self, ind=0): + # Get the number of bits used to represent the output + o_bits = self.get_output_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded output + _, _, elems = self.get_folded_output_shape(ind) + # Width of a stream producing output elements in parallel + return elems * o_bits + + # Minimize the accumulator bit width + def minimize_accumulator_width(self, model): # noqa: model is unused + # Get the query, key, value and attention weights type + QType = DataType[self.get_nodeattr("QType")] # noqa + KType = DataType[self.get_nodeattr("KType")] # noqa + VType = DataType[self.get_nodeattr("VType")] # noqa + AType = DataType[self.get_nodeattr("AType")] # noqa + + # Compute the worst-case upper and lower bounds of the accumulator range + lower_worst = QType.min() * np.ones(self.get_normal_input_shape(0)) + lower_range = calculate_matvec_accumulator_range(lower_worst, KType) + upper_worst = QType.max() * np.ones(self.get_normal_input_shape(0)) + upper_range = calculate_matvec_accumulator_range( # noqa: Duplicate + upper_worst, KType + ) + # Minimum and maximum values of the range + acc_min = min(min(lower_range), min(upper_range)) + acc_max = max(max(upper_range), max(upper_range)) + # Unsigned accumulator range + if acc_min >= 0: + # Number of bits necessary to represent the maximum value of the + # range. Some values between 0 and acc_min might be unused. + bitwidth = math.ceil(np.log2(acc_max + 1)) + # New unsigned accumulator datatype of this bitwidth + AccQKMatMul = DataType[f"UINT{bitwidth}"] # noqa + # Signed accumulator range + else: + # Maximum absolute value which needs to be represented + acc_max = max(-acc_min, 1 + acc_max) + # Number of bits necessary to represent the maximum value of the + # range. Some values on one of the ends might remain unused. 
+ bitwidth = math.ceil(np.log2(acc_max) + 1) + # New signed accumulator datatype of this bitwidth + AccQKMatMul = DataType[f"INT{bitwidth}"] # noqa + # Update the accumulator datatype attribute + self.set_nodeattr("AccQKMatMul", AccQKMatMul.name) + # If there is no activation function following the accumulator, the + # output type needs to be adjusted as well + if self.get_nodeattr("ActQKMatMul") == "none": + # Update the output datatype attribute to the same type as the + # accumulator + self.set_nodeattr("OutQKMatMul", AccQKMatMul.name) + + # Compute the worst-case upper and lower bounds of the accumulator range + lower_worst = AType.min() * np.ones(self.get_normal_attention_shape(0)) + lower_range = calculate_matvec_accumulator_range(lower_worst, VType) + upper_worst = AType.max() * np.ones(self.get_normal_attention_shape(0)) + upper_range = calculate_matvec_accumulator_range( # noqa: Duplicate + upper_worst, VType + ) + # Minimum and maximum values of the range + acc_min = min(min(lower_range), min(upper_range)) + acc_max = max(max(upper_range), max(upper_range)) + # Unsigned accumulator range + if acc_min >= 0: + # Number of bits necessary to represent the maximum value of the + # range. Some values between 0 and acc_min might be unused. + bitwidth = math.ceil(np.log2(acc_max + 1)) + # New unsigned accumulator datatype of this bitwidth + AccAVMatMul = DataType[f"UINT{bitwidth}"] # noqa + # Signed accumulator range + else: + # Maximum absolute value which needs to be represented + acc_max = max(-acc_min, 1 + acc_max) + # Number of bits necessary to represent the maximum value of the + # range. Some values on one of the ends might remain unused. + bitwidth = math.ceil(np.log2(acc_max) + 1) + # New signed accumulator datatype of this bitwidth + AccAVMatMul = DataType[f"INT{bitwidth}"] # noqa + # Update the accumulator datatype attribute + self.set_nodeattr("AccAVMatMul", AccAVMatMul.name) + # If there is no activation function following the accumulator, the + # output type needs to be adjusted as well + if self.get_nodeattr("ActAVMatMul") == "none": + # Update the output datatype attribute to the same type as the + # accumulator + self.set_nodeattr("OutAVMatMul", AccAVMatMul.name) + # # The output type of the whole operator is the same as the output + # # type of the last MatMul + # TODO: This currently breaks MergeMultiHeads via + # MinimizeAccumulatorWidth, which re-infers datatypes after + # each custom op instead of once after traversing the whole graph. + # self.set_nodeattr("OType", AccQKMatMul.name) + + # Gets the number of expected input values, i.e. how many times read() + # could/should be called on the input stream of this operator + def get_number_input_values(self, ind=0): + # Elements over all but the last dimension of the input folded along + # the embedding dimension + return np.prod(self.get_folded_input_shape(ind=ind)[:-1]) + + # Gets the number of expected output values, i.e. how many times read() + # could/should be called on the output stream of this operator + def get_number_output_values(self): + # Elements over all but the last dimension of the output folded along + # the embedding dimension + return np.prod(self.get_folded_output_shape()[:-1]) + + # Converts names of optional inputs to the node input index and from there + # to the ONNX node input name if the input is present. + # Note: This mapping is required as the ONNX graph/node may provide + # different names (in particular automatically generated unique names) and + # some of these are optional inputs. 
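As a worked example of this mapping: with mask_mode set to "input" and only the query-key matmul using a thresholding activation, the inputs actually present reduce to Q, K, V, M and thresholds_qk_matmul, so the thresholds resolve to the fifth ONNX input. A standalone sketch of that filtering, independent of the FINN class and using assumed flag values:

# Filtering optional inputs down to ONNX input positions (illustrative flags)
names = ["Q", "K", "V", "M",
         "thresholds_qk_matmul", "thresholds_a_softmax", "thresholds_av_matmul"]
present = [
    True, True, True,   # Q, K, V are always present
    True,               # mask_mode in {"input", "const"}
    True,               # ActQKMatMul == "thresholds"
    False,              # ActASoftmax != "thresholds"
    False,              # ActAVMatMul != "thresholds"
]
present_names = [n for n, p in zip(names, present) if p]
# thresholds_qk_matmul ends up as the fifth ONNX input of the node
assert present_names.index("thresholds_qk_matmul") == 4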
+ def get_input_name_by_name(self, name): + # Ordered names of the (optional) threshold inputs + thresholds = [ + "thresholds_qk_matmul", + "thresholds_a_softmax", + "thresholds_av_matmul", + ] + + # Ordered names of primary query, key, value inputs and optional mask + # and threshold inputs. + inputs = ["Q", "K", "V", "M", *thresholds] + + # Specify for each input whether it is present or not + inputs_present = [ + # Note: Primary inputs are always present, the mask is present in + # "input" or "const" mask mode + True, True, True, self.get_nodeattr("mask_mode") in { + "input", "const" + }, + ] + + # Thresholds are present if the activation function is set to + # thresholds + inputs_present.extend([ + self.get_nodeattr("ActQKMatMul") == "thresholds", + self.get_nodeattr("ActASoftmax") == "thresholds", + self.get_nodeattr("ActAVMatMul") == "thresholds" + ]) + + # Filter the ordered list of input names for those which are actually + # present + inputs = [x for x, present in zip(inputs, inputs_present) if present] + + # Find the position of the requested input name and look up the + # corresponding input name of the ONNX node + return self.onnx_node.input[inputs.index(name)] + + # Derives the expected cycles for the attention operation given the folding + # configuration + def get_exp_cycles(self): + # Verify the folding configuration + assert self.is_valid_folding, \ + f"Invalid folding configuration for {self.onnx_node.name}" + # Get the input/output dimensions + qk_dim, q_len, v_dim, kv_len = self.shapes + # Get folding configuration describing how to parallelize along the + # dimensions + emb_fold, seq_fold = self.folds + # Assume perfect overlap of the constituents of the operator, i.e., of + # the buffering, both matmul and the softmax, then the expected cycles + # is the maximum over these operators + # Overall worst case cycles without any parallelization: ~ T x T x d + return max( + # Transposed keys buffer cycles + # Worst case: kv_len * qk_dim, ~ T x d + kv_len * emb_fold, + # Queries - keys matmul cycles + # Worst case: q_len * qk_dim * kv_len, ~ T x T x d + q_len * emb_fold * seq_fold, + # Softmax normalization cycles + # Worst case: q_len * kv_len, ~ T x T + q_len * seq_fold, + # Values buffer cycles + # Worst case: kv_len * v_dim, ~ T x d + kv_len * emb_fold, + # Attention weights - values matmul + # Worst case: q_len * v_dim * kv_len, ~ T x T x d + q_len * emb_fold * seq_fold + ) diff --git a/src/finn/custom_op/fpgadataflow/attention_heads.py b/src/finn/custom_op/fpgadataflow/attention_heads.py new file mode 100644 index 000000000..6d7c53250 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/attention_heads.py @@ -0,0 +1,640 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Numpy math and arrays +import numpy as np +# Operating system stuff, e.g. 
paths +import os +# Python warning subsystem +import warnings + +# Helper for creating ONNX nodes +from onnx import helper as oh + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper +# Derive custom operators from the FINN base custom op +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp +# Converts inputs/outputs to/from RTL simulation format +from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy + + +# Splitting of attention heads (after input projections) custom operator +class SplitMultiHeads(HWCustomOp): + # Initializes the operator given an onnx graph node + def __init__(self, onnx_node, **kwargs): + # Just forward all arguments to the init method of the CustomOp base + super().__init__(onnx_node, **kwargs) + + # Need to override the default depths of outputs FIFOs here as these + # depend on the number of heads, which are not known during calls to + # get_nodeattr_types. + if not self.get_nodeattr("outFIFODepths"): + self.set_nodeattr("outFIFODepths", [2 for _ in range(self.heads)]) + + # Defines attributes which must be present on this node + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = HWCustomOp.get_nodeattr_types(self) + # Update attributes dictionary for new custom operator + attrs.update({ + # Number of attention heads + "heads": ("i", True, 1), + # Specifies whether the output is packed as a single output tensor + # or split as multiple output tensors + "packed": ("i", True, 1), + # Data type of input and output elements + "dtype": ("s", True, ""), + # Number of input elements to be split + "num_elems": ("i", True, 1), + # Number of inputs to be processed sequentially + "num_inputs": ("ints", True, [1]), + # Possible execution modes for simulating this node + # Note: Override to support python mode + "exec_mode": ( + "s", False, "python", {"", "rtlsim", "cppsim", "python"} + ), + + # Input and output FIFO depths for multi-I/O nodes + # Note: Need to override here as there are multiple outputs + "inFIFODepths": ("ints", False, [2]), + "outFIFODepths": ("ints", False, []), # Default will be overridden + }) + # Return updated attribute dictionary + return attrs + + # Number of attention heads attribute as property for convenience + @property + def heads(self): + return self.get_nodeattr("heads") + + # Packed attribute as property for convenience + @property + def packed(self): + # Note: Converts from int to bool + return bool(self.get_nodeattr("packed")) + + # Datatype attribute as property for convenience + @property + def dtype(self): + # Note: Converts from string to QONNX data type + return DataType[self.get_nodeattr("dtype")] + + # Number of elements attribute as property for convenience + @property + def num_elems(self): + return self.get_nodeattr("num_elems") + + # Number of inputs attribute as property for convenience + @property + def num_inputs(self): + return self.get_nodeattr("num_inputs") + + # Makes an operation compatible with the output shape for shape inference + # Note: Propagates shape forward, i.e., never asks for the shape of the + # output, even if it seems easier.
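Before the shape-inference helper that follows, a small numpy sketch of the shape effect being modelled, mirroring the python-mode execution further down in this file; the sizes are illustrative assumptions only:

import numpy as np

seq, dim, heads = 4, 12, 3                    # illustrative sizes only
x = np.arange(seq * dim).reshape(seq, dim)

# Packed: a single output tensor with a leading heads dimension
packed = x.reshape(seq, heads, dim // heads).transpose(1, 0, 2)
assert packed.shape == (heads, seq, dim // heads)

# Unpacked: one (seq, dim // heads) output tensor per head
splits = np.split(x, indices_or_sections=heads, axis=-1)
assert all(s.shape == (seq, dim // heads) for s in splits)

MergeMultiHeads later in this file performs the inverse, transposing back to sequence-first layout and concatenating the heads into the embedding dimension.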
+ def make_shape_compatible_op(self, model: ModelWrapper): # noqa + # Get the node wrapped by this custom op + node = self.onnx_node + # Determine the rank of the input tensor to support batched and + # non-batched inputs + rank = len(self.num_inputs) + 1 + # The input shape determines the sequence length + (seq, *_), dim = self.num_inputs, self.num_elems + # Packed outputs a represented by a reshape operation producing one + # tensor + if self.packed: + # Create a new name for the temporary shape tensor + shape = model.make_new_valueinfo_name() + # Set the target shape of slices heads + model.set_initializer( + shape, np.asarray([self.heads, seq, dim // self.heads]) + ) + # Return a node simulating the shape effect of slicing into + # multi-heads + return oh.make_node( + "Reshape", [node.input[0], shape], [node.output[0]] + ) + # Prepare a dummy input to simulate reordering of batch/head dimension + # to the front + mock_input = model.make_new_valueinfo_name() + # Set the target shape of slices heads + model.set_tensor_shape( + mock_input, [1, seq, dim] if rank == 3 else [seq, dim] + ) + # If the outputs are not packed, the operation is represented as a split + # operation producing number of heads outputs along the last axis + return oh.make_node( + "Split", [mock_input], node.output, num_outputs=self.heads, axis=-1 + ) + + # Infers the datatype of the node output + def infer_node_datatype(self, model: ModelWrapper): # noqa + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node + # Test for changing input datatype + if model.get_tensor_datatype(node.input[0]) != self.dtype: + # Get the new datatype + new_dtype = model.get_tensor_datatype(node.input[0]) + # Issue a warning message + warnings.warn( + f"{node.name}: dtype changing from {self.dtype} to {new_dtype}" + ) + # Set the new datatype attribute + self.set_nodeattr("dtype", new_dtype.name) + # Propagate the type from the input to each output tensor + for o in node.output: + # Slicing simply propagates the dtype to the output + model.set_tensor_datatype(o, self.dtype) + + # Executes multi-head slicing in python + def _execute_node_python(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + # Get the input out of the execution context + # Note: Shape must be either seq x 1 x dim or seq x dim + inp = context[node.input[0]] + # Packed execution boils down to a reshape of the single input to a + # single output + if self.packed: + # Reshape to separate the heads out of the embedding dimensions, + # finally transpose to heads first layout + out = inp.reshape(inp.shape[0], self.heads, -1).transpose(1, 0, 2) + # Write the output into the execution context + context[node.output[0]] = out + # Split is realized as the split operation of numpy + else: + # Produces multiple outputs as a list + splits = np.split(inp, indices_or_sections=self.heads, axis=-1) + # Correspondence between outputs and splits in order + for o, out in zip(node.output, splits): + # Write the output into the execution context + context[o] = out + + # Executes multi-head splitting in C++ simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # C++ Simulation needs to be implemented in HLS backend specialization + raise NotImplementedError( + f"exec_mode cppsim of {self.__class__.__name__} is not implemented!" 
+ ) + + # Executes multi-head slicing in RTL simulation + def _execute_node_rtlsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + # Get the input out of the execution context + # Note: Shape must be either seq x 1 x dim or seq x dim + inp = context[node.input[0]] + # Validate the shape of the input + assert inp.shape == self.get_normal_input_shape(ind=0), \ + f"Input shape mismatch for {node.input[0]}" + # Reshape the input into folded form + inp = inp.reshape(self.get_folded_input_shape(ind=0)) + # Path to store the intermediate input in numpy format + filename = os.path.join(code_gen_dir, "in.npy") + # Save the folded inputs to file to be used by simulation + np.save(filename, inp) + # Start collecting inputs/outputs to the RTL simulation in a dictionary + # Note: Prepare one output list per head + io_dict = { + "inputs": {}, "outputs": {f"out{i}": [] for i in range(self.heads)} + } + # Type and width of the input tensor + dtype = self.get_input_datatype(ind=0) + width = self.get_instream_width(ind=0) + # Convert inputs to RTL simulation format + io_dict["inputs"]["in"] = npy_to_rtlsim_input(filename, dtype, width) + + # Setup PyVerilator simulation of the node + sim = self.get_rtlsim() + # Reset the RTL simulation + super().reset_rtlsim(sim) + super().toggle_clk(sim) + # Run the RTL Simulation + self.rtlsim_multi_io(sim, io_dict) + + # Enumerate the node outputs + for i, name in enumerate(node.output): + # Collect the output from RTL simulation + out = io_dict["outputs"][f"out{i}"] + # Type and sizes of the output tensor + dtype = self.get_output_datatype(ind=i) + width = self.get_outstream_width(ind=i) + shape = self.get_folded_output_shape(ind=i) + # Path to store the intermediate numpy file + filename = os.path.join(code_gen_dir, f"out{i}.npy") + # Convert from RTL simulation format to numpy format + rtlsim_output_to_npy( + out, filename, dtype, shape, width, dtype.bitwidth() + ) + # Load the generated output numpy file + out = np.load(filename) + # Reshape the folded output and insert into the execution context + context[name] = out.reshape(self.get_normal_output_shape(ind=i)) + + # Executes multi-head slicing in simulation (either python c++ or rtl sim) + def execute_node(self, context, graph): + # Get the configured execution mode + mode = self.get_nodeattr("exec_mode") + # Lookup table mapping execution modes to implementing methods + exec_fns = { + "python": self._execute_node_python, + "cppsim": self._execute_node_cppsim, + "rtlsim": self._execute_node_rtlsim, + } + # Select and execute the function by mode string + exec_fns[mode](context, graph) + + # Verifies the node attributes, inputs and outputs + def verify_node(self): + # TODO: Implement + return [] + + # Note: End of QONNX CustomOp region, below is FINN HWCustomOp stuff + + # Gets the datatype of input at index ind + def get_input_datatype(self, ind=0): + # All inputs (there should only be one) have the same type + return self.dtype + + # Gets the datatype of the output at index ind + def get_output_datatype(self, ind=0): + # All outputs will hae the same type, which is the same as the input + return self.dtype + + # Gets the shape of the input at index ind without folding + def get_normal_input_shape(self, ind=0): + # There is only one input with shape configured as attributes + # Unpack multi-axis inputs 
list to yield a flat tuple as shape + return *self.num_inputs, self.num_elems + + # Gets the shape of the output at index ind without folding + def get_normal_output_shape(self, ind=0): + # Packed layout is currently not implemented + assert not self.packed, "Packed multi-heads are not implemented yet" + # All output have the same shape, which correspond to distributing the + # number of input elements to the heads specified as attributes + # Unpack multi-axis inputs list to yield a flat tuple as shape + return *self.num_inputs, self.num_elems // self.heads + + # Gets the shape of the input at index ind with folding + def get_folded_input_shape(self, ind=0): + # No folding for now, normal and folded shape are the same + return self.get_normal_input_shape(ind=ind) + + # Gets the shape of the output at index ind with folding + def get_folded_output_shape(self, ind=0): + # No folding for now, normal and folded shape are the same + return self.get_normal_output_shape(ind=ind) + + # Widths of the input data stream of the input at index ind + def get_instream_width(self, ind=0): + # Get the number of bits used to represent the input + i_bits = self.get_input_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded input + *_, elems = self.get_folded_input_shape(ind) + # Width of a stream receiving input elements in parallel + return elems * i_bits + + # Widths of the output data stream of the output at index ind + def get_outstream_width(self, ind=0): + # Get the number of bits used to represent the output + o_bits = self.get_output_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded output + *_, elems = self.get_folded_output_shape(ind) + # Width of a stream producing output elements in parallel + return elems * o_bits + + # Gets the number of expected output values, i.e. how many times read() + # could/should be called on any output stream of this operator + def get_number_output_values(self): + # Elements over all but the last dimension of the output folded along + # the embedding dimension. Need to count across the number of heads, as + # RTL simulation actually counts individual inputs, not cycles with + # inputs, i.e., producing N heads outputs per cycle in parallel, count + # N outputs per cycle... + return np.prod(self.get_folded_output_shape()[:-1]) * self.heads + + # Derives the expected cycles for the attention head splitting operation + # given the folding configuration + def get_exp_cycles(self): + # Currently, this implicitly assumes fully parallelized processing + # along the embedding dimension, i.e., always max PE + return np.prod(self.num_inputs) + + +# Merging of attention heads (before output projections) custom operator +class MergeMultiHeads(HWCustomOp): + # Initializes the operator given an onnx graph node + def __init__(self, onnx_node, **kwargs): + # Just forward all arguments to the init method of the CustomOp base + super().__init__(onnx_node, **kwargs) + + # Need to override the default depths of input FIFOs here as these + # depend on the number of heads, which are not known during calls to + # get_nodeattr_types. 
+ if not self.get_nodeattr("inFIFODepths"): + self.set_nodeattr("inFIFODepths", [2 for _ in range(self.heads)]) + + # Defines attributes which must be present on this node + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = HWCustomOp.get_nodeattr_types(self) + # Update attributes dictionary for new custom operator + attrs.update({ + # Number of attention heads + "heads": ("i", True, 1), + # Specifies whether the output is packed as a single output tensor + # or split as multiple output tensors + "packed": ("i", True, 1), + # Data type of input and output elements + "dtype": ("s", True, ""), + # Number of input elements to be split + "num_elems": ("i", True, 1), + # Number of inputs to be processed sequentially + "num_inputs": ("ints", True, [1]), + # Output needs to be squeezed + "squeezed": ("i", True, 0), + # Possible execution modes for simulating this node + # Note: Override to support python mode + "exec_mode": ( + "s", False, "python", {"", "rtlsim", "cppsim", "python"} + ), + + # Input and output FIFO depths for multi-I/O nodes + # Note: Need to override here as there multiple inputs + "inFIFODepths": ("ints", False, []), # Default will be override + "outFIFODepths": ("ints", False, [2]), + }) + # Return updated attribute dictionary + return attrs + + # Number of attention heads attribute as property for convenience + @property + def heads(self): + return self.get_nodeattr("heads") + + # Packed attribute as property for convenience + @property + def packed(self): + # Note: Converts from int to bool + return bool(self.get_nodeattr("packed")) + + # Datatype attribute as property for convenience + @property + def dtype(self): + # Note: Converts from string to QONNX data type + return DataType[self.get_nodeattr("dtype")] + + # Number of elements attribute as property for convenience + @property + def num_elems(self): + return self.get_nodeattr("num_elems") + + # Number of inputs attribute as property for convenience + @property + def num_inputs(self): + return self.get_nodeattr("num_inputs") + + # Squeezed output attribute as property for convenience + @property + def squeezed(self): + # Note: Converts from int to bool + return bool(self.get_nodeattr("squeezed")) + + # Makes an operation compatible with the output shape for shape inference + # Note: Propagates shape forward, i.e., never asks for the shape of the + # output, even if it seems easier. + def make_shape_compatible_op(self, model: ModelWrapper): # noqa + # Squeeze single-element batch dimension from the output? 
+ squeezed = self.squeezed + # Assume unpacked inputs by default, here seq sill be the number of + # input feature maps + seq = self.num_inputs + # Packed inputs a represented by a reshape operation consuming one + # tensor + if self.packed: + # Drop the heads-first dimension from packed inputs + seq = self.num_inputs[1:] + # Distribute the heads into the embedding dimension + dim = self.heads * self.num_elems + # Constant operation producing output of given shape + return super().make_const_shape_op( + [*seq, dim] if squeezed else [*seq, 1, dim] + ) + + # Infers the datatype of the node output + def infer_node_datatype(self, model: ModelWrapper): # noqa + # Get the node wrapped by this custom op + node = self.onnx_node # noqa Duplicate + # Test for changing input datatype + if model.get_tensor_datatype(node.input[0]) != self.dtype: + # Get the new datatype + new_dtype = model.get_tensor_datatype(node.input[0]) + # Issue a warning message + warnings.warn( + f"{node.name}: dtype changing from {self.dtype} to {new_dtype}" + ) + # Set the new datatype attribute + self.set_nodeattr("dtype", new_dtype.name) + # All inputs must have the same datatype + assert all( + model.get_tensor_datatype(inp) == self.dtype for inp in node.input + ), f"{node.name}: All inputs must have the same datatype" + # Merging simply propagates the datatype to the output + model.set_tensor_datatype(node.output[0], self.dtype) + + # Executes multi-head merging in python + def _execute_node_python(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + # Get the input out of the execution context + # Note: Shape must be heads x seq x dim + inp = context[node.input[0]] + # Packed execution boils down to a reshape of the single input to a + # single output + if self.packed: + # Transpose back into sequence first layout then reintegrate the + # heads via reshape + out = inp.transpose(1, 0, 2).reshape( + inp.shape[1], 1, self.heads * inp.shape[-1] + ) + # Split is realized as the concat operation of numpy + else: + # Collect the list of inputs from the execution context and + # concatenate along the last axis + out = np.concatenate([context[i] for i in node.input], axis=-1) + # Reshape to simulate the batch dimensions if it is not present + out = out.reshape(out.shape[0], 1, out.shape[-1]) + # Optionally squeeze the output (remove batch dimension of size 1) + if self.squeezed: + # Squeeze batch dimension via reshape + out = out.reshape(out.shape[0], out.shape[-1]) + # Write the output into the execution context. Force output shape + # which might be squeezed + context[node.output[0]] = out + + # Executes multi-head merging in C++ simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # C++ Simulation needs to be implemented in HLS backend specialization + raise NotImplementedError( + f"exec_mode cppsim of {self.__class__.__name__} is not implemented!" 
+ ) + + # Executes multi-head slicing in RTL simulation + def _execute_node_rtlsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + + # Start collecting inputs/outputs to the RTL simulation in a dictionary + # Note: Prepare one output list per head + io_dict = { + "inputs": {}, "outputs": {"out": []} + } + + # Enumerate the node inputs + for i, name in enumerate(node.input): + # Get the input out of the execution context + # Note: Shape must be either 1 x seq x dim or seq x dim + inp = context[name] + # Validate the shape of the input + assert inp.shape == self.get_normal_input_shape(ind=i), \ + f"Input shape mismatch for {name}" + # Reshape the input into folded form + inp = inp.reshape(self.get_folded_input_shape(ind=i)) + # Path to store the intermediate input in numpy format + filename = os.path.join(code_gen_dir, f"in{i}.npy") + # Save the folded inputs to file to be used by simulation + np.save(filename, inp) + # Type and width of the input tensor + dtype = self.get_input_datatype(ind=i) + width = self.get_instream_width(ind=i) + # Convert inputs to RTL simulation format + io_dict["inputs"][f"in{i}"] = npy_to_rtlsim_input( + filename, dtype, width + ) + + # Setup PyVerilator simulation of the node + sim = self.get_rtlsim() + # Reset the RTL simulation + super().reset_rtlsim(sim) + super().toggle_clk(sim) + # Run the RTL Simulation + self.rtlsim_multi_io(sim, io_dict) + + # Collect the output from RTL simulation + out = io_dict["outputs"]["out"] + # Type and sizes of the output tensor + dtype = self.get_output_datatype(ind=0) + width = self.get_outstream_width(ind=0) + shape = self.get_folded_output_shape(ind=0) + # Path to store the intermediate numpy file + filename = os.path.join(code_gen_dir, "out.npy") + # Convert from RTL simulation format to numpy format + rtlsim_output_to_npy( + out, filename, dtype, shape, width, dtype.bitwidth() + ) + # Load the output numpy file generated by the RTL simulation + out = np.load(filename) + # Reshape the folded output and insert into the execution context + context[node.output[0]] = out.reshape( + self.get_normal_output_shape(ind=0) + ) + + # Executes multi-head slicing in simulation (either python c++ or rtl sim) + def execute_node(self, context, graph): + # Get the configured execution mode + mode = self.get_nodeattr("exec_mode") + # Lookup table mapping execution modes to implementing methods + exec_fns = { + "python": self._execute_node_python, + "cppsim": self._execute_node_cppsim, + "rtlsim": self._execute_node_rtlsim, + } + # Select and execute the function by mode string + exec_fns[mode](context, graph) + + # Verifies the node attributes, inputs and outputs + def verify_node(self): + # TODO: Implement + return [] + + # Note: End of QONNX CustomOp region, below is FINN HWCustomOp stuff + + # Gets the datatype of input at index ind + def get_input_datatype(self, ind=0): + # All inputs (there should only be one) have the same type + return self.dtype + + # Gets the datatype of the output at index ind + def get_output_datatype(self, ind=0): + # All outputs will have the same type, which is the same as the input + return self.dtype + + # Gets the shape of the input at index ind without folding + def get_normal_input_shape(self, ind=0): + # Packed layout is currently not implemented + assert not self.packed, "Packed multi-heads are not implemented 
yet"
+        # There is only one shape configured as attributes, shared by all
+        # inputs
+        #   Unpack multi-axis inputs list to yield a flat tuple as shape
+        return *self.num_inputs, self.num_elems
+
+    # Gets the shape of the output at index ind without folding
+    def get_normal_output_shape(self, ind=0):
+        # All outputs have the same shape, which corresponds to collecting the
+        # number of input elements from the heads specified as attributes
+        #   Unpack multi-axis inputs list to yield a flat tuple as shape
+        return *self.num_inputs, self.num_elems * self.heads
+
+    # Gets the shape of the input at index ind with folding
+    def get_folded_input_shape(self, ind=0):
+        # No folding for now, normal and folded shape are the same
+        return self.get_normal_input_shape(ind=ind)
+
+    # Gets the shape of the output at index ind with folding
+    def get_folded_output_shape(self, ind=0):
+        # No folding for now, normal and folded shape are the same
+        return self.get_normal_output_shape(ind=ind)
+
+    # Widths of the input data stream of the input at index ind
+    def get_instream_width(self, ind=0):
+        # Get the number of bits used to represent the input
+        i_bits = self.get_input_datatype(ind).bitwidth()
+        # Parallelism is the number of elements in the last dimension of the
+        # folded input
+        *_, elems = self.get_folded_input_shape(ind)
+        # Width of a stream receiving input elements in parallel
+        return elems * i_bits
+
+    # Widths of the output data stream of the output at index ind
+    def get_outstream_width(self, ind=0):
+        # Get the number of bits used to represent the output
+        o_bits = self.get_output_datatype(ind).bitwidth()
+        # Parallelism is the number of elements in the last dimension of the
+        # folded output
+        *_, elems = self.get_folded_output_shape(ind)
+        # Width of a stream producing output elements in parallel
+        return elems * o_bits
+
+    # Gets the number of expected output values, i.e. 
how many times read() + # could/should be called on any output stream of this operator + def get_number_output_values(self): + # Elements over all but the last dimension of the output folded along + # the embedding dimension + return np.prod(self.get_folded_output_shape()[:-1]) + + # Derives the expected cycles for the attention head merging operation given + # the folding configuration + def get_exp_cycles(self): + # Currently, this implicitly assumes fully parallelized processing + # along the embedding dimension, i.e., always max PE + return np.prod(self.num_inputs) diff --git a/src/finn/custom_op/fpgadataflow/hls/__init__.py b/src/finn/custom_op/fpgadataflow/hls/__init__.py index 7f46b1519..38a270c56 100644 --- a/src/finn/custom_op/fpgadataflow/hls/__init__.py +++ b/src/finn/custom_op/fpgadataflow/hls/__init__.py @@ -66,6 +66,11 @@ def register_custom_op(cls): import finn.custom_op.fpgadataflow.hls.unsqueeze_hls from finn.custom_op.fpgadataflow.hls.addstreams_hls import AddStreams_hls +from finn.custom_op.fpgadataflow.hls.attention_heads_hls import ( + MergeMultiHeads_hls, + SplitMultiHeads_hls, +) +from finn.custom_op.fpgadataflow.hls.attention_hls import ScaledDotProductAttention_hls from finn.custom_op.fpgadataflow.hls.channelwise_op_hls import ChannelwiseOp_hls from finn.custom_op.fpgadataflow.hls.checksum_hls import CheckSum_hls from finn.custom_op.fpgadataflow.hls.concat_hls import StreamingConcat_hls @@ -82,6 +87,7 @@ def register_custom_op(cls): from finn.custom_op.fpgadataflow.hls.lookup_hls import Lookup_hls from finn.custom_op.fpgadataflow.hls.matrixvectoractivation_hls import MVAU_hls from finn.custom_op.fpgadataflow.hls.pool_hls import Pool_hls +from finn.custom_op.fpgadataflow.hls.replicate_stream_hls import ReplicateStream_hls from finn.custom_op.fpgadataflow.hls.split_hls import StreamingSplit_hls from finn.custom_op.fpgadataflow.hls.streamingdatawidthconverter_hls import ( StreamingDataWidthConverter_hls, @@ -118,3 +124,8 @@ def register_custom_op(cls): custom_op["UpsampleNearestNeighbour_hls"] = UpsampleNearestNeighbour_hls custom_op["MVAU_hls"] = MVAU_hls custom_op["VVAU_hls"] = VVAU_hls + +custom_op["ScaledDotProductAttention_hls"] = ScaledDotProductAttention_hls +custom_op["SplitMultiHeads_hls"] = SplitMultiHeads_hls +custom_op["MergeMultiHeads_hls"] = MergeMultiHeads_hls +custom_op["ReplicateStream_hls"] = ReplicateStream_hls diff --git a/src/finn/custom_op/fpgadataflow/hls/attention_heads_hls.py b/src/finn/custom_op/fpgadataflow/hls/attention_heads_hls.py new file mode 100644 index 000000000..fe67d04c9 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/hls/attention_heads_hls.py @@ -0,0 +1,490 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Numpy math and arrays +import numpy as np + +# Operating system stuff, e.g. 
paths +import os + +# Base class for specializing HW operators as implemented via HLS +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend +# The generic HW custom operator version of the operator as a base class +from finn.custom_op.fpgadataflow.attention_heads import ( # noqa + MergeMultiHeads, SplitMultiHeads +) + + +# HLS Backend specialization of the multi-head attention splitting operator +class SplitMultiHeads_hls( # noqa: Class name does not follow + # CapWords convention + SplitMultiHeads, HLSBackend +): + # Node attributes matching the HLS operator + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = SplitMultiHeads.get_nodeattr_types(self) + # Add the HLSBackend default attributes on top + attrs.update(HLSBackend.get_nodeattr_types(self)) + # Add/Specialize implementation specific attributes here... + # Return the updated attributes dictionary + return attrs + + # Executes multi-head splitting in C++ simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node # noqa Duplicate + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the input out of the execution context + # Note: Shape must be either seq x 1 x dim or seq x dim + inp = context[node.input[0]] + # Validate the shape of the input + assert inp.shape == self.get_normal_input_shape(ind=0), \ + f"Input shape mismatch for {node.input[0]}" + # Reshape the input into folded form + inp = inp.reshape(self.get_folded_input_shape(ind=0)) + # Save the folded inputs to file to be used by simulation + np.save(os.path.join(code_gen_dir, "in.npy"), inp) + + # Execute the precompiled model + super().exec_precompiled_singlenode_model() + + # Enumerate the node outputs + for i, name in enumerate(node.output): + # Load the output numpy file generated by the C++ simulation + out = np.load(os.path.join(code_gen_dir, f"out{i}.npy")) + # Reshape the folded output and insert into the execution context + context[name] = out.reshape(self.get_normal_output_shape(ind=i)) + + # Maximum width of any ap_int used in this operator + def get_ap_int_max_w(self): + # Find the widths of the widest input + # Note: There is just one input. 
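        # Editor's note (illustrative sketch, not part of this patch): the
        # value computed below is just the maximum over all AXI-stream widths.
        # For a hypothetical split into heads=4 with 4 elements of 8 bit per
        # head this would evaluate to
        #     i_bits = 4 * 4 * 8   # single packed input stream  -> 128 bit
        #     o_bits = 4 * 8       # each per-head output stream ->  32 bit
        #     max(i_bits, o_bits)  # -> 128, the widest ap_uint<> required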
+ i_bits_max = self.get_instream_width(ind=0) + # Find the widths of the widest output + # Note: there is one output per head + o_bits_max = max( + (self.get_outstream_width(ind) for ind in range(self.heads)) + ) + # Find the biggest of the inputs/outputs + return max([i_bits_max, o_bits_max]) + + # Note: End of shape and datatype utilities + + # Generates list of C++ includes to be placed at the top of the generated + # code + def global_includes(self): + # Currently nothing to include + self.code_gen_dict["$GLOBALS$"] = [] + + # Generates C++ code of type alias, global constant and macro definitions + def defines(self, var): + # Insert constants and type aliases into the dictionary + self.code_gen_dict["$DEFINES$"] = [ + # Input and output element datatypes + f"using IType = {self.dtype.get_hls_datatype_str()};", + f"using OType = {self.dtype.get_hls_datatype_str()};", + # Datatype of elements packed into the input stream + f"using IPacked = ap_uint<{self.get_instream_width()}>;", + # Datatype of elements packed into the output stream + f"using OPacked = ap_uint<{self.get_outstream_width()}>;", + # Input and output HLS stream datatypes + "using IStream = hls::stream<" + f" ap_uint<{self.get_instream_width()}>" + ">;", + "using OStream = hls::stream<" + f" ap_uint<{self.get_outstream_width()}>" + ">;", + ] + + # Generates C++ code for reading data from .npy (numpy format) for testing + # in C++ simulation + def read_npy_data(self): + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Generate function calls for reading the input files into the input + # streams + self.code_gen_dict["$READNPYDATA$"] = [ + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f'"{code_gen_dir}/in.npy", in_{self.hls_sname()}, false', + ');' + ] + + # Generates C++ code for declaring all streams involved in C++ simulation + # for testing + def strm_decl(self): + # Declare input and output streams + # Note: Assumes stream type aliases to be set in defines + self.code_gen_dict["$STREAMDECLARATIONS$"] = [ + # There is one input datastream + f"IStream in_{self.hls_sname()};", + # There is one output datastream per head + *(f"OStream out{i}_{self.hls_sname()};" for i in range(self.heads)) + ] + + # Generates C++ code for calling the computation part of the operator + def docompute(self): + # Generates the bit-slicing indices string for the ith split of the + # input + def split(i): + # Assemble a C++ indexing/bit-slicing string + return f"({i + 1} * OPacked::width - 1, {i} * OPacked::width)" + + # Generates the name of the ith output stream + def out(i): + return f"out{i}_{self.hls_sname()}" + + # Write the body of the head-splitting top-level function + self.code_gen_dict["$DOCOMPUTE$"] = [ + # Repeat for the number of inputs + # Note: Repeat for all num_inputs dimensions + f"for(std::size_t i = 0; i < {np.prod(self.num_inputs)}; ++i) {{", + # Pipeline the steps of this loop + "#pragma HLS pipeline II=1 style=flp", + # Read the next input element from the stream + f"const auto x = in_{self.hls_sname()}.read();", + # Split the next element from the input stream into the number of + # output elements per head and write into the corresponding stream + *(f"{out(i)}.write(x{split(i)});" for i in range(self.heads)), + # End of for-loop over repetitions body + f"}}" # noqa: f-string symmetry + ] + + # Generates C++ code for reading the output stream 
and converting back to + # numpy format for testing in C++ simulation + def dataoutstrm(self): + # Output data will be stored in numpy files in the # noqa Duplicate + # code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the expected shape of the folded output array formatted as a C++ + # vector initializer + # Note: Valid formatting relies on correct placement of curly braces + # and line breaks: Open/close all three braces on the same line of code + # to avoid '\n' to be inserted into the string + shape = f"""{{{ + ','.join((str(i) for i in self.get_folded_output_shape())) + }}}""" + # Start collecting function calls to write the output data stream + self.code_gen_dict["$DATAOUTSTREAM$"] = [] + + # Generates the name of the ith output stream + def out(i): + return f"out{i}_{self.hls_sname()}" + + # Generate code for each output stream + for i in range(self.heads): + # Append each reading/writing function call + self.code_gen_dict["$DATAOUTSTREAM$"] += [ + # Generate function call reading from stream into the output + # file + # Note: Outputs are always represented as numpy floats + 'apintstream2npy(', + f'{out(i)}, {shape}, "{code_gen_dir}/out{i}.npy", false', + ');' + ] + + # Generates C++ code for saving the output of C++ simulation to a file in + # numpy format + def save_as_npy(self): + # Note: This seems to be empty in ALL HLSCustomOps. Probably it was used + # for something before, which is now integrated into dataoutstrm()? + self.code_gen_dict["$SAVEASCNPY$"] = [] + + # Generates essentially the head of the C++ function from which the IP block + # will be generated during ipgen, i.e. actual synthesis + def blackboxfunction(self): + # Insert function head describing the top level interface of the head + # splitting operator + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + # @formatter:off Prevent Python formatter from messing with C++ + # formatting + # Note: Assumes stream type aliases to be set in defines + f"void {self.onnx_node.name} (", + # Input HLS stream + f"IStream &in_{self.hls_sname()}, ", ",".join([ + # One output HLS stream per head # noqa: Formatting + f"OStream &out{i}_{self.hls_sname()}" for i in range(self.heads) + ]), + ")", + # @formatter:off + ] + + # Generates C++ pragmas to be inserted into the main function of the C++ + # simulation and the ipgen-blackboxfunction as well + def pragmas(self): + # Add HLS interface directives specifying how to create RTL ports for + # the top-level function arguments + self.code_gen_dict["$PRAGMAS$"] = [ + # Connect the input stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=in_{self.hls_sname()}" + ] + # Connect each output stream with an axi stream interface + for i in range(self.heads): + # Add new interface directive for the output stream + self.code_gen_dict["$PRAGMAS$"] += [ + f"#pragma HLS INTERFACE axis port=out{i}_{self.hls_sname()}" + ] + # No block-level I/O protocol for the function return value + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS INTERFACE ap_ctrl_none port=return" + ) + + # Returns the names of input and output interfaces grouped by protocol + def get_verilog_top_module_intf_names(self): + # Start collecting interface names in a dictionary # noqa Duplicate + # starting with clock and reset + intf_names = {"clk": ["ap_clk"], "rst": ["ap_rst_n"]} # noqa + # AXI stream input interfaces + intf_names["s_axis"] = [ + # Just one input stream + (f"in_{self.hls_sname()}", self.get_instream_width_padded(ind=0)), + ] + # AXI stream output 
interfaces
+        intf_names["m_axis"] = [
+            # One output stream per head
+            (f"out{i}_{self.hls_sname()}",
+             self.get_outstream_width_padded(ind=i)) for i in range(self.heads)
+        ]
+        # No AXI-MM, AXI-Lite or protocol-less interfaces
+        intf_names["aximm"] = []
+        intf_names["axilite"] = []
+        intf_names["ap_none"] = []
+        # Return the interface name dictionary
+        return intf_names
+
+
+# HLS Backend specialization of the multi-head attention merging operator
+class MergeMultiHeads_hls(  # noqa: Class name does not follow
+    # CapWords convention
+    MergeMultiHeads, HLSBackend
+):
+    # Node attributes matching the HLS operator
+    def get_nodeattr_types(self):
+        # Start from parent operator class attributes
+        attrs = MergeMultiHeads.get_nodeattr_types(self)
+        # Add the HLSBackend default attributes on top
+        attrs.update(HLSBackend.get_nodeattr_types(self))
+        # Add/Specialize implementation specific attributes here...
+        # Return the updated attributes dictionary
+        return attrs
+
+    # Executes multi-head merging in C++ simulation
+    def _execute_node_cppsim(self, context, graph):  # noqa: graph unused
+        # Get the node wrapped by this custom op
+        node = self.onnx_node
+        # Input data is stored in numpy files in the code generation dictionary
+        code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
+
+        # Enumerate the node inputs
+        for i, name in enumerate(node.input):
+            # Get the input out of the execution context
+            #   Note: Shape must be either 1 x seq x dim or seq x dim
+            inp = context[name]
+            # Validate the shape of the input
+            assert inp.shape == self.get_normal_input_shape(ind=i), \
+                f"Input shape mismatch for {name}"
+            # Reshape the input into folded form
+            inp = inp.reshape(self.get_folded_input_shape(ind=i))
+            # Save the folded inputs to file to be used by simulation
+            np.save(os.path.join(code_gen_dir, f"in{i}.npy"), inp)
+
+        # Execute the precompiled model
+        super().exec_precompiled_singlenode_model()
+
+        # Load the output numpy file generated by the C++ simulation
+        out = np.load(os.path.join(code_gen_dir, "out.npy"))
+        # Reshape the folded output and insert into the execution context
+        context[node.output[0]] = out.reshape(
+            self.get_normal_output_shape(ind=0)
+        )
+
+    # Maximum width of any ap_int used in this operator
+    def get_ap_int_max_w(self):
+        # Find the widths of the widest input
+        #   Note: For merging there is one input per head, all of the same
+        #   width, so checking the first one is sufficient.
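        # Editor's note (illustrative sketch, not part of this patch): for
        # merging the relation is mirrored. Assuming heads=4 input streams of
        # 4 elements at 8 bit each, every input stream is 4 * 8 = 32 bit wide
        # while the single merged output stream packs 4 * 4 * 8 = 128 bit, so
        # the widest ap_uint<> required below is 128 bit.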
+ i_bits_max = self.get_instream_width(ind=0) + # Find the widths of the widest output + # Note: there is one output per head + o_bits_max = max( + (self.get_outstream_width(ind) for ind in range(self.heads)) + ) + # Find the biggest of the inputs/outputs + return max([i_bits_max, o_bits_max]) + +# Note: End of shape and datatype utilities + + # Generates list of C++ includes to be placed at the top of the generated + # code + def global_includes(self): + # Currently nothing to include + self.code_gen_dict["$GLOBALS$"] = [] + + # Generates C++ code of type alias, global constant and macro definitions + def defines(self, var): + # Insert constants and type aliases into the dictionary + self.code_gen_dict["$DEFINES$"] = [ + # Input and output element datatypes + f"using IType = {self.dtype.get_hls_datatype_str()};", + f"using OType = {self.dtype.get_hls_datatype_str()};", + # Datatype of elements packed into the input stream + f"using IPacked = ap_uint<{self.get_instream_width()}>;", + # Datatype of elements packed into the output stream + f"using OPacked = ap_uint<{self.get_outstream_width()}>;", + # Input and output HLS stream datatypes + "using IStream = hls::stream<" + f" ap_uint<{self.get_instream_width()}>" + ">;", + "using OStream = hls::stream<" + f" ap_uint<{self.get_outstream_width()}>" + ">;", + ] + + # Generates C++ code for reading data from .npy (numpy format) for testing + # in C++ simulation + def read_npy_data(self): + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Generate function calls for reading the input files into the input + # streams + self.code_gen_dict["$READNPYDATA$"] = [] + # Generate code for each input stream + for i in range(self.heads): + # Append each reading/writing function call + self.code_gen_dict["$READNPYDATA$"] += [ + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f'"{code_gen_dir}/in{i}.npy", in{i}_{self.hls_sname()}, false', + ');' + ] + + # Generates C++ code for declaring all streams involved in C++ simulation + # for testing + def strm_decl(self): + # Declare input and output streams + # Note: Assumes stream type aliases to be set in defines + self.code_gen_dict["$STREAMDECLARATIONS$"] = [ + # There is one output stream + f"OStream out_{self.hls_sname()};", + # There is one input stream per head + *(f"IStream in{i}_{self.hls_sname()};" for i in range(self.heads)) + ] + + # Generates C++ code for calling the computation part of the operator + def docompute(self): + reversed_reads = ", ".join([ + f"in{i}_{self.hls_sname()}.read()" + for i in reversed(range(self.heads)) + ]) + + # Write the body of the head-splitting top-level function + self.code_gen_dict["$DOCOMPUTE$"] = [ + # Repeat for the number of inputs + # Note: Repeat for all num_inputs dimensions + f"for(std::size_t i = 0; i < {np.prod(self.num_inputs)}; ++i) {{", + # Pipeline the steps of this loop + "#pragma HLS pipeline II=1 style=flp", + # Read the next input element from each input stream and concatenate + # using the comma operator overload of ap_uint, writing into the + # output stream + f"out_{self.hls_sname()}.write(({reversed_reads}));" + # End of for-loop over repetitions body + f"}}" # noqa: f-string symmetry + ] + + # Generates C++ code for reading the output stream and converting back to + # numpy format for testing in C** simulation + def dataoutstrm(self): + # Output data will be stored in 
numpy files in the code generation + # dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the expected shape of the folded output array formatted as a C++ + # vector initializer + # Note: Valid formatting relies on correct placement of curly braces + # and line breaks: Open/close all three braces on the same line of code + # to avoid '\n' to be inserted into the string + shape = f"""{{{ + ','.join((str(i) for i in self.get_folded_output_shape())) + }}}""" + # Generate function call for reading from the output stream into the + # output file + self.code_gen_dict["$DATAOUTSTREAM$"] = [ + # Generate function call reading from stream into the output file + # Note: Outputs are always represented as numpy floats + 'apintstream2npy(', + f'out_{self.hls_sname()}, {shape}, "{code_gen_dir}/out.npy", false', + ');', + ] + + # Generates C++ code for saving the output of C++ simulation to a file in + # numpy format + def save_as_npy(self): + # Note: This seems to be empty in ALL HLSCustomOps. Probably it was used + # for something before, which is now integrated into dataoutstrm()? + self.code_gen_dict["$SAVEASCNPY$"] = [] + + # Generates essentially the head of the C++ function from which the IP block + # will be generated during ipgen, i.e. actual synthesis + def blackboxfunction(self): + # Insert function head describing the top level interface of the head + # splitting operator + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + # @formatter:off Prevent Python formatter from messing with C++ + # formatting + # Note: Assumes stream type aliases to be set in defines + f"void {self.onnx_node.name} (", + # Output HLS stream + f"OStream &out_{self.hls_sname()}, ", ",".join([ + # One input HLS stream per head # noqa: Formatting + f"IStream &in{i}_{self.hls_sname()}" for i in range(self.heads) + ]), + ")", + # @formatter:off + ] + + # Generates C++ pragmas to be inserted into the main function of the C++ + # simulation and the ipgen-blackboxfunction as well + def pragmas(self): + # Add HLS interface directives specifying how to create RTL ports for + # the top-level function arguments + self.code_gen_dict["$PRAGMAS$"] = [ + # Connect the output stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=out_{self.hls_sname()}" + ] + # Connect each input stream with an axi stream interface + for i in range(self.heads): + # Add new interface directive for the input stream + self.code_gen_dict["$PRAGMAS$"] += [ + f"#pragma HLS INTERFACE axis port=in{i}_{self.hls_sname()}" + ] + # No block-level I/O protocol for the function return value + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS INTERFACE ap_ctrl_none port=return" + ) + + # Returns the names of input and output interfaces grouped by protocol + def get_verilog_top_module_intf_names(self): + # Start collecting interface names in a dictionary starting with clock + # and reset + intf_names = {"clk": ["ap_clk"], "rst": ["ap_rst_n"]} # noqa + # AXI stream input interfaces + intf_names["s_axis"] = [ + # One input stream per head + (f"in{i}_{self.hls_sname()}", + self.get_instream_width_padded(ind=i)) for i in range(self.heads) + ] + # AXI stream output interfaces + intf_names["m_axis"] = [ + # Just one output stream + (f"out_{self.hls_sname()}", self.get_outstream_width_padded(ind=0)), + ] + # No AXI-MM, AXI-Lite or protocol-less interfaces + intf_names["aximm"] = [] + intf_names["axilite"] = [] + intf_names["ap_none"] = [] + # Return the interface name dictionary + return intf_names diff --git 
a/src/finn/custom_op/fpgadataflow/hls/attention_hls.py b/src/finn/custom_op/fpgadataflow/hls/attention_hls.py new file mode 100644 index 000000000..332313d8d --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/hls/attention_hls.py @@ -0,0 +1,809 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Numpy math and arrays +import numpy as np +# Operating system stuff, e.g. paths +import os + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper +# Some utils for working with tensors in qonnx +from qonnx.util.basic import interleave_matrix_outer_dim_from_partitions + +# The generic HW custom operator version of the operator as a base class +from finn.custom_op.fpgadataflow.attention import ScaledDotProductAttention +# Base class for specializing HW operators as implemented via HLS +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend +# Convert and pack (numpy) data for C++ code generation +from finn.util.data_packing import numpy_to_hls_code + +# Mapping of memory resource attributes to the corresponding C++ HLS +# pragma directives +RAM_STYLES = { + "auto": "AUTO", "block": "BRAM", "distributed": "LUTRAM", "ultra": "URAM" +} + + +# HLS Backend specialization of the Scale Dot-product Attention Operator +class ScaledDotProductAttention_hls( # noqa: Class name does not follow + # CapWords convention + ScaledDotProductAttention, HLSBackend +): + # Node attributes matching the HLS operator + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = ScaledDotProductAttention.get_nodeattr_types(self) + # Add the HLSBackend default attributes on top + attrs.update(HLSBackend.get_nodeattr_types(self)) + # Add/Specialize implementation specific attributes here... 
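        # Editor's note (illustrative sketch, not part of this patch): an
        # HLS-specific attribute would be registered here in the usual FINN
        # (type, required, default, allowed-values) tuple form, e.g. a
        # hypothetical
        #     attrs.update({"ram_style_thresholds": (
        #         "s", False, "auto", {"auto", "block", "distributed", "ultra"}
        #     )})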
+ # Return the updated attributes dictionary + return attrs + + # Executes the attention operator in C++ mode simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + + # By convention, inputs 0, 1 and 2 correspond to named inputs q, k and v + + # Read the input from the execution context and reshape to match the + # expected folding + q = context[node.input[0]].reshape(self.get_folded_input_shape(ind=0)) + # Save the folded inputs to file to be used by simulation + np.save(os.path.join(code_gen_dir, "q.npy"), q) + + # Read the input from the execution context and reshape to match the + # expected folding + k = context[node.input[1]].reshape(self.get_folded_input_shape(ind=1)) + # Save the folded inputs to file to be used by simulation + np.save(os.path.join(code_gen_dir, "k.npy"), k) + + # Read the input from the execution context and reshape to match the + # expected folding + v = context[node.input[2]].reshape(self.get_folded_input_shape(ind=2)) + # Save the folded inputs to file to be used by simulation + np.save(os.path.join(code_gen_dir, "v.npy"), v) + + # Optionally, the mask may be provided as an input as well + if self.get_nodeattr("mask_mode") == "input": + # Read the input from the execution context and reshape to match the + # expected folding + m = context[node.input[3]].reshape( + self.get_folded_input_shape(ind=3) + ) + # Save the folded inputs to file to be used by simulation + np.save(os.path.join(code_gen_dir, "m.npy"), m) + + # Execute the precompiled model + super().exec_precompiled_singlenode_model() + + # Load the output numpy file generated by the C++ simulation + out = np.load(os.path.join(code_gen_dir, "out.npy")) + # Reshape the folded output and insert into the execution context + context[self.onnx_node.output[0]] = out.reshape( + self.get_normal_output_shape(ind=0) + ) + + # Executes the attention operator in RTL mode simulation + def _execute_node_rtlsim(self, context, graph): # noqa: graph unused + # TODO: Implement rtlsim mode + # Note: Cannot even compile this right now due to missing float ips + raise NotImplementedError( + "exec_mode rtlsim is not implemented yet!" 
+ ) + + # Maximum width of any ap_int used in this operator + def get_ap_int_max_w(self): + # Find the widths of the widest input + i_bits_max = max((self.get_instream_width(ind) for ind in range(3))) + # Find the widths of the widest output + o_bits_max = max((self.get_outstream_width(ind) for ind in range(1))) + # Assume no bits to represent the mask, if there is no mask + m_bits = 0 + # A mask received as input has a bit-width as well + if self.get_nodeattr("mask_mode") in {"input", "const"}: + # Parallelism is the number of elements in the last dimension of the + # folded mask input + _, _, elems = self.get_folded_input_shape(ind=3) + # Get width of the mask datatype + m_bits = elems * DataType[self.get_nodeattr("MType")].bitwidth() + + # Elements per folded key input (second input) + _, _, i_elems = self.get_folded_input_shape(ind=1) + # Elements per folded value input (third input), same as the number of + # output elements + _, _, o_elems = self.get_folded_input_shape(ind=2) + + # Parallelism is the number of elements in the last dimension of the + # folded attention weights + _, _, s_elems = self.get_folded_attention_shape() + # Number of bits used for the attention weights stream + a_bits = s_elems * DataType[self.get_nodeattr("AType")].bitwidth() + + # Maximum bits per tile of the key and value matrix streams + tile_bits_max = max([ + i_elems * s_elems * DataType[self.get_nodeattr("KType")].bitwidth(), + o_elems * s_elems * DataType[self.get_nodeattr("VType")].bitwidth(), + ]) + # Maximum bits per matmul accumulators + acc_bits_max = max([ + # These are not streamed, thus single element width is counted + DataType[self.get_nodeattr("AccQKMatMul")].bitwidth(), + DataType[self.get_nodeattr("AccAVMatMul")].bitwidth(), + ]) + # Maximum bits per matmul outputs + out_bits_max = max([ + # These are the stream widths, which are always >= than individual + # elements + s_elems * DataType[self.get_nodeattr("OutQKMatMul")].bitwidth(), + o_elems * DataType[self.get_nodeattr("OutAVMatMul")].bitwidth(), + ]) + # Aggregate the maximum bit width in both matmul operators over all + # inputs, intermediates and outputs + matmul_bits_max = max([ + tile_bits_max, acc_bits_max, out_bits_max + ]) + + # Find maximum of all (maximal) bit-widths + return max([i_bits_max, o_bits_max, m_bits, a_bits, matmul_bits_max]) + + # Generates list of C++ includes to be placed at the top of the generated + # code + def global_includes(self): + # FINN HLSLIB activation functions: e.g. PassThroughActivation + self.code_gen_dict["$GLOBALS$"] = ['#include "activations.hpp"'] + # Attention operator HLS code + self.code_gen_dict["$GLOBALS$"] += ['#include "attention.hpp"'] + + # Generates C++ parameters file, i.e. activation function thresholds + def generate_params(self, model: ModelWrapper, path): + # The code generation directory is specified as an argument, so this + # will work for both RTL and C++ simulation + code_gen_dir = path + + # Note: The attention operator itself has no weights to be generated as + # a parameter file + + # Start all three activations defaulting to pass-through of the + # accumulator type. + # Note: This might allow type-casts to the output types if they are + # not the same as the accumulators. 
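        # Editor's note (illustrative, not part of this patch): each of these
        # three strings either remains the pass-through default set below or
        # is overwritten further down by a ThresholdsActivation<...> type plus
        # a matching C++ array initializer, all of which end up in the
        # generated params.hpp that is included by the defines() method.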
+ act_qk_matmul = "PassThroughActivation" + act_av_matmul = "PassThroughActivation" + act_a_softmax = "PassThroughActivation" + + # Start all thresholds defaulting to empty default initializer braces + thresholds_qk_matmul = "{}" + thresholds_av_matmul = "{}" + thresholds_a_softmax = "{}" + + # Prepares a threshold tensor as C++ string for code generation + def prepare_thresholds(ts, length, fold, dtype): + # Number of thresholds is given as the last dimension of the + # threshold tensor, first dimension is covering all output elements + num = ts.shape[-1] # noqa + # Partition the thresholds along the length into folds of parallel + # elements + ts = interleave_matrix_outer_dim_from_partitions(ts, length // fold) + # Reshape folded thresholds adding an outer dimension + # TODO: Why? MVAU does this, just copied the behavior. This is + # probably to generate the outer C++ initializer braces {} for + # object construction. Isn't it weird to rely on an artificial + # dimension just to have the code generator produce the correct + # string? + ts = ts.reshape(1, length // fold, fold, num) + # Format the thresholds as C++ array code + # Note: no packing, no variable name/type declaration + return numpy_to_hls_code(ts, dtype, "_", False, True), num + + # Get shape and folding configuration. None of the activations fold + # along the query-key embedding dimension or the query sequence length + (_, _, vdim, kvlen), (embfold, seqfold) = self.shapes, self.folds + + # Query-key matmul can have an optional activation function set to + # thresholding activations via node attribute + if self.get_nodeattr("ActQKMatMul") == "thresholds": + # In this case there will be a thresholds parameter initializer + thresholds = model.get_initializer( + self.get_input_name_by_name("thresholds_qk_matmul") + ) + # Get the datatype of the thresholds + thresholds_dtype = DataType[self.get_nodeattr("AccQKMatMul")] + # Activation value, i.e., bias applied after thresholding activation + bias = self.get_nodeattr("BiasActQKMatMul") + # No support for floating-point bias + assert int(bias) == bias, "BiasActQKMatMul must be integer" + # Convert the bias to integer representation, so it can be used as a + # template argument + bias = int(bias) + # Format the thresholds as C++ array code: QK matmul outputs fold + # along the key-value sequence length dimension + thresholds_qk_matmul, num = prepare_thresholds( + thresholds, kvlen, seqfold, thresholds_dtype + ) + # Get the HLS datatype string corresponding to the thresholds + # datatype for C++ code generation + dtype_str = thresholds_dtype.get_hls_datatype_str() + # Replace default pass-through activation by thresholding activation + # Note: Relies on type and shape definitions generated by the + # "defines" method + act_qk_matmul = "\n".join([ + "ThresholdsActivation<", + " SeqFold," + " KVLen/SeqFold," + f" {num}," + " AccQKMatMul," + " OutQKMatMul," + f" {bias}," + # Note: Not sure why the default comp::less does not work... 
+ f" comp::less_equal<{dtype_str}, {dtype_str}>", + ">" + ]) + + # Softmax can have an optional activation function set to thresholding + # activations via node attribute + if self.get_nodeattr("ActASoftmax") == "thresholds": + # In this case there will be a thresholds parameter initializer + thresholds = model.get_initializer( + self.get_input_name_by_name("thresholds_a_softmax") + ) + # Get the datatype of the thresholds + thresholds_dtype = DataType[self.get_nodeattr("AccASoftmax")] + # Activation value, i.e., bias applied after thresholding activation + bias = self.get_nodeattr("BiasActASoftmax") + # No support for floating-point bias + assert int(bias) == bias, "BiasActASoftmax must be integer" + # Convert the bias to integer representation, so it can be used as a + # template argument + bias = int(bias) + # Format the thresholds as C++ array code: Softmax outputs fold + # along the key-value sequence length dimension + thresholds_a_softmax, num = prepare_thresholds( + thresholds, kvlen, seqfold, thresholds_dtype + ) + # Get the HLS datatype string corresponding to the thresholds + # datatype for C++ code generation + dtype_str = thresholds_dtype.get_hls_datatype_str() + # Replace default pass-through activation by thresholding activation + # Note: Relies on type and shape definitions generated by the + # "defines" method + act_a_softmax = "\n".join([ + "ThresholdsActivation<", + " SeqFold," + " KVLen/SeqFold," + f" {num}," + " AccASoftmax," + " AType," + f" {bias}," + # Note: Not sure why the default comp::less does not work... + f" comp::less_equal<{dtype_str}, {dtype_str}>", + ">" + ]) + + # Attention-value matmul can have an optional activation function set to + # thresholding activations via node attribute + if self.get_nodeattr("ActAVMatMul") == "thresholds": + # In this case there will be a thresholds parameter initializer + thresholds = model.get_initializer( + self.get_input_name_by_name("thresholds_av_matmul") + ) + # Get the datatype of the thresholds + thresholds_dtype = DataType[self.get_nodeattr("AccAVMatMul")] + # Activation value, i.e., bias applied after thresholding activation + bias = self.get_nodeattr("BiasActAVMatMul") + # No support for floating-point bias + assert int(bias) == bias, "BiasActAVMatMul must be integer" + # Convert the bias to integer representation, so it can be used as a + # template argument + bias = int(bias) + # Format the thresholds as C++ array code: AV matmul outputs fold + # along the value embedding dimension + thresholds_av_matmul, num = prepare_thresholds( + thresholds, vdim, embfold, thresholds_dtype + ) + # Get the HLS datatype string corresponding to the thresholds + # datatype for C++ code generation + dtype_str = thresholds_dtype.get_hls_datatype_str() + # Replace default pass-through activation by thresholding activation + # Note: Relies on type and shape definitions generated by the + # "defines" method + act_av_matmul = "\n".join([ + "ThresholdsActivation<", + " EmbFold," + " VDim/EmbFold," + f" {num}," + " AccAVMatMul," + " OutAVMatMul," + f" {bias}," + # Note: Not sure why the default comp::less does not work... 
+ f" comp::less_equal<{dtype_str}, {dtype_str}>", + ">" + ]) + + # Assume no attention mask as a default: Generate C++ code of tag + # instance of "none" mask type + attention_mask = \ + "static const auto attention_mask = attention::mask::NONE" + + # If a causal mask is specified, set the appropriate tag dispatching + # instance + if self.get_nodeattr("mask_mode") == "causal": + # Generate C++ code of tag instance of causal mask type + attention_mask = \ + "static const auto attention_mask = attention::mask::CAUSAL" + + # If a constant mask is specified, array code needs to be generated + if self.get_nodeattr("mask_mode") == "const": + # Attention mask type of folded constant mask array + mask_type = "attention::mask::Const" + # Get the constant mask values + mask = model.get_initializer(self.get_input_name_by_name("M")) + # Num should always be equal to QLen + num = mask.shape[-1] + # Partition the mask along the length into folds of parallel + # elements + mask = interleave_matrix_outer_dim_from_partitions( + mask, kvlen // seqfold + ) + # Reshape folded mask adding an outer dimension + mask = mask.reshape(num, kvlen // seqfold, seqfold).squeeze() + # Format the mask as C++ array code + # Note: no packing, no variable name/type declaration + mask = numpy_to_hls_code(mask, DataType["BINARY"], "_", False, True) + # Generate C++ code initializing the constant mask array + attention_mask = f"static const {mask_type} attention_mask = {mask}" + + # If a mask is provided as input, no object parameters need to be + # generated here + if self.get_nodeattr("mask_mode") == "input": + # Attention mask type of input stream + mask_type = "attention::mask::Input" + # Generate C++ code creating an input stream instance for the mask + # Note: This is just a dummy, the real input stream will be part + # of the operator interface + attention_mask = f"static const {mask_type} attention_mask;" + + # Open a file to store the thresholds parameters as C++ code + with open(f"{code_gen_dir}/params.hpp", "w") as file: + # Write lines of C++ code separated by newlines to the file + file.write("\n".join([ + # Scale factor preceding the softmax activation function to + # dequantize the input to floating-point representation + "static const float dequant_softmax =" + f" {self.get_nodeattr('DequantSoftmax')};", + # Attention mask parameters if "none", "causal" or "const" + f"{attention_mask};", + # Type alias to the generated attention mask for convenience + "using AttentionMask = decltype(attention_mask);", + # Add type definition and threshold initialization of the + # query-key matmul activation + f"using ActQKMatMul = {act_qk_matmul};", + f"ActQKMatMul act_qk_matmul = {thresholds_qk_matmul};", + # Add type definition and threshold initialization of the + # attention-value matmul activation + f"using ActAVMatMul = {act_av_matmul};", + f"ActAVMatMul act_av_matmul = {thresholds_av_matmul};", + # Add type definition and threshold initialization of the + # softmax activation + f"using ActASoftmax = {act_a_softmax};", + f"ActASoftmax act_a_softmax = {thresholds_a_softmax};", + # Append a newline at the end of the file (to avoid problems + # when including, required by C standard?) 
+ "\n" + ])) + + # Generates C++ code of type alias, global constant and macro definitions + def defines(self, var): + # Generate shape definitions from attributes to C++ constant definitions + def shapedefs(*names): + # C++ qualified type to be used for shape constants + shape = "static constexpr std::size_t" + # Generate a C++ constant definition for each of the attributes + # given by argument list names + return ( + f"{shape} {name} = {self.get_nodeattr(name)};" for name in names + ) + + # Generate datatype definitions mapping from QONNX DataType to HLS type + def typedefs(*names): + # Gets the HLS type string for the datatype specified by the named + # attribute + def hls_type(name): + # Looks up the datatype specified for the attribute and + # translates from QONNX to HLS type + return DataType[self.get_nodeattr(name)].get_hls_datatype_str() + + # Generate a C++ type alias definition for each of the attributes + # given by argument list names + return (f"using {name} = {hls_type(name)};" for name in names) + + # Attribute specifying the memory to use for internal buffers + ram_style = self.get_nodeattr("ram_style") + # Attribute specifying the resources to use for implementing MAC + # operations + mac_resource = self.get_nodeattr("mac_resource") + + # Mapping of memory resource attributes to the corresponding C++ tag + # types + mem_resources = { + "auto": "Resource::AUTO", + "block": "Resource::BRAM", + "distributed": "Resource::LUTRAM", + "ultra": "Resource::URAM" + } + # Mapping of compute resource attributes to the corresponding C++ tag + # types + compute_resources = { + "auto": "ap_resource_dflt", + "lut": "ap_resource_lut", + "dsp": "ap_resource_dsp" + } + + # Insert constants and type aliases into the dictionary + self.code_gen_dict["$DEFINES$"] = [ + # Shape constant definitions of attention inputs (query, key and + # value) and folding configuration + *shapedefs( + "QKDim", + "QLen", + "VDim", + "KVLen", + "EmbFold", + "SeqFold" + ), + # Type alias definitions for all input, output and intermediate + # datatypes + *typedefs( + "QType", + "KType", + "VType", + "MType", + "AType", + "OType" + ), + # Type alias definitions for the matmul accumulators and output + # datatypes + *typedefs( + "AccQKMatMul", + "OutQKMatMul", + "AccAVMatMul", + "OutAVMatMul", + "AccASoftmax" + ), + # Type alias definitions for the resource type selection tags + f"using MacResource = {compute_resources[mac_resource]};", + f"using MemResource = {mem_resources[ram_style]};", + # Include the activation function type definitions and parameters + # Note: The typedefs in this header require the typedefs above, + # thus adding this to the global includes is not possible. 
+ '#include "params.hpp"', + # Type alias of the properly configured attention operator class + "using Attention = ScaledDotProductAttention<", + " QKDim,", + " QLen,", + " VDim,", + " KVLen,", + " EmbFold,", + " SeqFold,", + " QType,", + " KType,", + " VType,", + " MType,", + " AType,", + " OType,", # Note: OType and last MatMul out must match + " AccQKMatMul,", + " OutQKMatMul,", + " ActQKMatMul,", + " AccAVMatMul,", + " OType,", # Note: OType and last MatMul out must match + " ActAVMatMul,", + " ActASoftmax,", + " MacResource,", + " MemResource" + ">;", + # Short type aliases of attention input and output streams + "using QStream = Attention::QStream;", + "using KStream = Attention::KStream;", + "using VStream = Attention::VStream;", + "using OStream = Attention::OStream;", + "using MStream = Attention::MStream;", + ] + + # Generates C++ code for reading data from .npy (numpy format) for testing + # in C++ simulation + def read_npy_data(self): + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + + # Generate function calls for reading the input files into the input + # streams + self.code_gen_dict["$READNPYDATA$"] = [ + # Deduce the datatype of elements packed into the query input stream + # TODO: Maybe these type-deductions can be removed by changing the + # order of the template arguments of the npy2apintstream, such + # that type-deduction is handled there? + 'using QPacked = decltype(QStream{}.read());', + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f' "{code_gen_dir}/q.npy", q_{self.hls_sname()}, false', + ');', + + # Deduce the datatype of elements packed into the key input stream + 'using KPacked = decltype(KStream{}.read());', + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f' "{code_gen_dir}/k.npy", k_{self.hls_sname()}, false', + ');', + + # Deduce the datatype of elements packed into the value input stream + 'using VPacked = decltype(VStream{}.read());', + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f' "{code_gen_dir}/v.npy", v_{self.hls_sname()}, false', + ');', + ] + + # If the mask is provided as an input, it needs to be read as well + if self.get_nodeattr("mask_mode") == "input": + # Generate function call for reading the mask file into the input + # stream + self.code_gen_dict["$READNPYDATA$"] += [ + # Deduce the datatype of elements packed into the mask input + # stream + 'using MPacked = decltype(MStream{}.read());', + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f' "{code_gen_dir}/m.npy", m_{self.hls_sname()}, false', + ');', + ] + + # Generates C++ code for declaring all streams involved in C++ simulation + # for testing + def strm_decl(self): + # Declare input (query, key, value) and output streams + self.code_gen_dict["$STREAMDECLARATIONS$"] = [ + # Note: Assumes stream type aliases to be set in defines + f"QStream q_{self.hls_sname()};", + f"KStream k_{self.hls_sname()};", + f"VStream v_{self.hls_sname()};", + f"OStream out_{self.hls_sname()};" + ] + # If the mask is provided as an input, it needs a stream declaration as + # well + if self.get_nodeattr("mask_mode") == "input": + # Append the 
mask stream to the declaration list + self.code_gen_dict["$STREAMDECLARATIONS$"] += [ + # Note: Assumes stream type aliases to be set in defines + f"MStream m_{self.hls_sname()};", + ] + + # Generates C++ code for calling the computation part of the operator + def docompute(self): + # Convert the thresholds RAM style attribute to HLS directive + ram_style_thresholds = RAM_STYLES[ + self.get_nodeattr("ram_style_thresholds") + ] + # Convert the attention mask RAM style attribute to HLS directive + ram_style_mask = RAM_STYLES[self.get_nodeattr("ram_style_mask")] + + # Generates the "BIND_STORAGE" pragma for the threshold activations + # threshold memory of "name" + def bind_threshold_storage(name: str): + return (f"#pragma HLS BIND_STORAGE variable={name}" + f" type=ROM_2P impl={ram_style_thresholds}") + + # Generates the ARRAY_PARTITION pragma for the threshold activations + # threshold memory of "name" and along dimension "dim" + def partition_thresholds_array(name: str, dim: int): + return (f"#pragma HLS ARRAY_PARTITION variable={name}" + f" complete dim={dim}") + + # Collect pragmas which need to be inserted into the DOCOMPUTE code + pragmas = [] + + # If there are thresholds activations following the query-key matmul, + # these need storage and array partition pragmas + if self.get_nodeattr("ActQKMatMul") == "thresholds": + # Add pragma compiler directives to the list of pragmas inserted + # into the DOCOMPUTE + pragmas.extend([ + # Partition the thresholds array along the PE (dim=1) and number + # of thresholds (dim=3) axis for parallel access + partition_thresholds_array( + "attention.qk_matmul.activation.m_thresholds", dim=1 + ), + partition_thresholds_array( + "attention.qk_matmul.activation.m_thresholds", dim=3 + ), + # Implement the thresholds array as a dual-port ROM with the + # RAM-Style selected via attribute + bind_threshold_storage( + "attention.qk_matmul.activation.m_thresholds" + ) + ]) + + # If there are thresholds activations following the attention-value + # matmul, these need storage and array partition pragmas + if self.get_nodeattr("ActAVMatMul") == "thresholds": + # Add pragma compiler directives to the list of pragmas inserted + # into the DOCOMPUTE + pragmas.extend([ + # Partition the thresholds array along the PE (dim=1) and number + # of thresholds (dim=3) axis for parallel access + partition_thresholds_array( + "attention.av_matmul.activation.m_thresholds", dim=1 + ), + partition_thresholds_array( + "attention.av_matmul.activation.m_thresholds", dim=3 + ), + # Implement the thresholds array as a dual-port ROM with the + # RAM-Style selected via attribute + bind_threshold_storage( + "attention.av_matmul.activation.m_thresholds" + ) + ]) + + # If there are thresholds activations following the softmax + # normalization, these need storage and array partition pragmas + if self.get_nodeattr("ActASoftmax") == "thresholds": + # Add pragma compiler directives to the list of pragmas inserted + # into the DOCOMPUTE + pragmas.extend([ + # Partition the thresholds array along the PE (dim=1) and number + # of thresholds (dim=3) axis for parallel access + partition_thresholds_array( + "attention.softmax.activation.m_thresholds", dim=1 + ), + partition_thresholds_array( + "attention.softmax.activation.m_thresholds", dim=3 + ), + # Implement the thresholds array as a dual-port ROM with the + # RAM-Style selected via attribute + bind_threshold_storage( + "attention.softmax.activation.m_thresholds" + ) + ]) + + # If a constant mask is specified, there needs to be storage and 
array + # partition pragmas to be inserted + if self.get_nodeattr("mask_mode") == "const": + # Note: Probably no need for partitioning this array, as the PE + # dimension is packed into the datatype (which is a bitvector with + # one bit per element, i.e., per PE) + # Implement the attention mask array as a dual-port ROM with the + # RAM-Style selected via attribute + pragmas.extend([ + f"#pragma HLS BIND_STORAGE variable=attention_mask" + f" type=ROM_2P impl={ram_style_mask}" + ]) + + # Write the body of the attention top-level function + self.code_gen_dict["$DOCOMPUTE$"] = [ + # Instantiate the attention operator and connect to the generated + # threshold parameters + # Note: Assumes "Attention" to be aliased and configured in defines + # Note: Assumes parameters to be generated in 'generate_params' and + # made available via include/defines before. + "Attention attention {", + " act_qk_matmul, act_av_matmul, act_a_softmax, dequant_softmax", + "};", + # Insert some more pragmas here to be able to configure + # implementation details of components internal to "attention" + *pragmas, + # Connect the attention operator to the input and output streams + "attention(" + f"q_{self.hls_sname()}, " + f"k_{self.hls_sname()}, " + f"v_{self.hls_sname()}, " + f"out_{self.hls_sname()}, " + # TODO: Does not work for "input" mode mask + "attention_mask" + ");", + ] + + # Generates C++ code for reading the output stream and converting back to + # numpy format for testing in C** simulation + def dataoutstrm(self): + # Output data will be stored in numpy files in the code generation + # dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the expected shape of the folded output array formatted as a C++ + # vector initializer + # Note: Valid formatting relies on correct placement of curly braces + # and line breaks: Open/close all three braces on the same line of code + # to avoid '\n' to be inserted into the string + shape = f"""{{{ + ','.join((str(i) for i in self.get_folded_output_shape())) + }}}""" + # Generate function call for reading from the output stream into the + # output file + self.code_gen_dict["$DATAOUTSTREAM$"] = [ + # Deduce the datatype of elements packed into the output stream + 'using OPacked = decltype(OStream{}.read());', + # Generate function call reading from stream into the output file + # Note: Outputs are always represented as numpy floats + 'apintstream2npy(', + f'out_{self.hls_sname()}, {shape}, "{code_gen_dir}/out.npy", false', + ');', + ] + + # Generates C++ code for saving the output of C++ simulation to a file in + # numpy format + def save_as_npy(self): + # Note: This seems to be empty in ALL HLSCustomOps. Probably it was used + # for something before, which is now integrated into dataoutstrm()? + self.code_gen_dict["$SAVEASCNPY$"] = [] + + # Generates essentially the head of the C++ function from which the IP block + # will be generated during ipgen, i.e. 
actual synthesis + def blackboxfunction(self): + # Insert function head describing the top level interface of the + # attention operator + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + # Note: Assumes stream type aliases to be set in defines + f"void {self.onnx_node.name} (", + f" QStream &q_{self.hls_sname()}," + f" KStream &k_{self.hls_sname()}," + f" VStream &v_{self.hls_sname()}," + f" OStream &out_{self.hls_sname()}", + ")", + ] + + # Generates C++ pragmas to be inserted into the main function of the C++ + # simulation and the ipgen-blackboxfunction as well + def pragmas(self): + # Add HLS interface directives specifying how to create RTL ports for + # the top-level function arguments + self.code_gen_dict["$PRAGMAS$"] = [ + # Connect the query input stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=q_{self.hls_sname()}", + # Connect the key input stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=k_{self.hls_sname()}", + # Connect the value input stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=v_{self.hls_sname()}", + # Connect the output stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=out_{self.hls_sname()}", + ] + # No block-level I/O protocol for the function return value + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS INTERFACE ap_ctrl_none port=return" + ) + + # Returns the names of input and output interfaces grouped by protocol + def get_verilog_top_module_intf_names(self): + # Start collecting interface names in a dictionary starting with clock + # and reset + intf_names = {"clk": ["ap_clk"], "rst": ["ap_rst_n"]} # noqa + # AXI stream input interfaces + intf_names["s_axis"] = [ + (f"q_{self.hls_sname()}", self.get_instream_width_padded(ind=0)), + (f"k_{self.hls_sname()}", self.get_instream_width_padded(ind=1)), + (f"v_{self.hls_sname()}", self.get_instream_width_padded(ind=2)) + ] + # AXI stream output interfaces + intf_names["m_axis"] = [ + (f"out_{self.hls_sname()}", self.get_outstream_width_padded(ind=0)) + ] + # No AXI-MM, AXI-Lite or protocol-less interfaces + intf_names["aximm"] = [] + intf_names["axilite"] = [] + intf_names["ap_none"] = [] + # Return the interface name dictionary + return intf_names + + # Prepare for RTL simulation: There is no RTL simulation of the attention + # operator for now + def prepare_rtlsim(self): + # This attribute must be present anyway, but it is ok if it points + # nowhere as long as execute_node doe not ry to execute the rtlsim + self.set_nodeattr("rtlsim_so", "none") diff --git a/src/finn/custom_op/fpgadataflow/hls/replicate_stream_hls.py b/src/finn/custom_op/fpgadataflow/hls/replicate_stream_hls.py new file mode 100644 index 000000000..84631f5dd --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/hls/replicate_stream_hls.py @@ -0,0 +1,253 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Numpy math and arrays +import numpy as np +# Operating system stuff, e.g. 
paths +import os + +# Base class for specializing HW operators as implemented via HLS +from finn.custom_op.fpgadataflow.hlsbackend import HLSBackend +# The generic HW custom operator version of the operator as a base class +from finn.custom_op.fpgadataflow.replicate_stream import ReplicateStream + + +# HLS Backend specialization of the stream-replication operator +class ReplicateStream_hls( # noqa: Class name does not follow + # CapWords convention + ReplicateStream, HLSBackend +): + # Node attributes matching the HLS operator + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = ReplicateStream.get_nodeattr_types(self) + # Add the HLSBackend default attributes on top + attrs.update(HLSBackend.get_nodeattr_types(self)) + # Add/Specialize implementation specific attributes here... + # Return the updated attributes dictionary + return attrs + + # Executes replicating inputs in C++ simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node # noqa Duplicate + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the input out of the execution context + inp = context[node.input[0]] + # Validate the shape of the input + assert inp.shape == self.get_normal_input_shape(ind=0), \ + f"Input shape mismatch for {node.input[0]}" + # Reshape the input into folded form + inp = inp.reshape(self.get_folded_input_shape(ind=0)) + # Save the folded inputs to file to be used by simulation + np.save(os.path.join(code_gen_dir, "in.npy"), inp) + + # Execute the precompiled model + super().exec_precompiled_singlenode_model() + + # Enumerate the node outputs + for i, name in enumerate(node.output): + # Load the output numpy file generated by the C++ simulation + out = np.load(os.path.join(code_gen_dir, f"out{i}.npy")) + # Reshape the folded output and insert into the execution context + context[name] = out.reshape(self.get_normal_output_shape(ind=i)) + + # Maximum width of any ap_int used in this operator + def get_ap_int_max_w(self): + # Find the widths of the widest input + # Note: There is just one input. 
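        # Editor's note (illustrative sketch, not part of this patch): every
        # replica carries the same datatype and folding as the input, so all
        # stream widths coincide here. For a hypothetical PE=4 at 8 bit, the
        # input and each of the num output streams are 4 * 8 = 32 bit wide
        # and the maximum below is simply that common stream width.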
+ i_bits_max = self.get_instream_width(ind=0) + # Find the widths of the widest output + # Note: there is one output per replica + o_bits_max = max( + (self.get_outstream_width(ind) for ind in range(self.num)) + ) + # Find the biggest of the inputs/outputs + return max([i_bits_max, o_bits_max]) + + # Note: End of shape and datatype utilities + + # Generates list of C++ includes to be placed at the top of the generated + # code + def global_includes(self): + # Currently nothing to include + self.code_gen_dict["$GLOBALS$"] = [] + + # Generates C++ code of type alias, global constant and macro definitions + def defines(self, var): + # Insert constants and type aliases into the dictionary + self.code_gen_dict["$DEFINES$"] = [ + # Input and output element datatypes + f"using IType = {self.dtype.get_hls_datatype_str()};", + f"using OType = {self.dtype.get_hls_datatype_str()};", + # Width of single elements to avoid using ::width attribute which is + # not present for datatype float + f"static constexpr auto ElemWidth = {self.dtype.bitwidth()};" + # Datatype of elements packed into the input stream + f"using IPacked = ap_uint<{self.get_instream_width()}>;", + # Datatype of elements packed into the output stream + f"using OPacked = ap_uint<{self.get_outstream_width()}>;", + # Input and output HLS stream datatypes + "using IStream = hls::stream<" + f" ap_uint<{self.get_instream_width()}>" + ">;", + "using OStream = hls::stream<" + f" ap_uint<{self.get_outstream_width()}>" + ">;", + ] + + # Generates C++ code for reading data from .npy (numpy format) for testing + # in C++ simulation + def read_npy_data(self): + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Generate function calls for reading the input files into the input + # streams + self.code_gen_dict["$READNPYDATA$"] = [ + # Generate function call reading from file into the input stream + # Note: Inputs are always represented as numpy floats + 'npy2apintstream(', + f'"{code_gen_dir}/in.npy", in_{self.hls_sname()}, false', + ');' + ] + + # Generates C++ code for declaring all streams involved in C++ simulation + # for testing + def strm_decl(self): + # Declare input and output streams + # Note: Assumes stream type aliases to be set in defines + self.code_gen_dict["$STREAMDECLARATIONS$"] = [ + # There is one input datastream + f"IStream in_{self.hls_sname()};", + # There is one output datastream per replica + *(f"OStream out{i}_{self.hls_sname()};" for i in range(self.num)) + ] + + # Generates C++ code for calling the computation part of the operator + def docompute(self): + # Generates the name of the ith output stream + def out(i): + return f"out{i}_{self.hls_sname()}" + + # Number of iterations required to process the whole folded input stream + # Note: This is all but the PE (last) dimension + num_iter = np.prod(self.get_folded_output_shape()[:-1]) + + # Write the body of the stream replicating top-level function + self.code_gen_dict["$DOCOMPUTE$"] = [ + # Repeat for the number of inputs + # Note: Repeat for all num_inputs dimensions + f"for(std::size_t i = 0; i < {num_iter}; ++i) {{", + # Pipeline the steps of this loop + "#pragma HLS pipeline II=1 style=flp", + # Read the next input element from the stream + f"const auto x = in_{self.hls_sname()}.read();", + # Write the same input element into each output stream + *(f"{out(i)}.write(x);" for i in range(self.num)), + # End of for-loop over repetitions body + f"}}" # noqa: f-string symmetry + ] + + # 
Generates C++ code for reading the output stream and converting back to + # numpy format for testing in C++ simulation + def dataoutstrm(self): + # Output data will be stored in numpy files in the # noqa Duplicate + # code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim") + # Get the expected shape of the folded output array formatted as a C++ + # vector initializer + # Note: Valid formatting relies on correct placement of curly braces + # and line breaks: Open/close all three braces on the same line of code + # to avoid '\n' to be inserted into the string + shape = f"""{{{ + ','.join((str(i) for i in self.get_folded_output_shape())) + }}}""" + # Start collecting function calls to write the output data stream + self.code_gen_dict["$DATAOUTSTREAM$"] = [] + + # Generates the name of the ith output stream + def out(i): + return f"out{i}_{self.hls_sname()}" + + # Generate code for each output stream + for i in range(self.num): + # Append each reading/writing function call + self.code_gen_dict["$DATAOUTSTREAM$"] += [ + # Generate function call reading from stream into the output + # file + # Note: Outputs are always represented as numpy floats + 'apintstream2npy(', + f'{out(i)}, {shape}, "{code_gen_dir}/out{i}.npy", false', + ');' + ] + + # Generates C++ code for saving the output of C++ simulation to a file in + # numpy format + def save_as_npy(self): + # Note: This seems to be empty in ALL HLSCustomOps. Probably it was used + # for something before, which is now integrated into dataoutstrm()? + self.code_gen_dict["$SAVEASCNPY$"] = [] + + # Generates essentially the head of the C++ function from which the IP block + # will be generated during ipgen, i.e. actual synthesis + def blackboxfunction(self): + # Insert function head describing the top level interface of the stream + # replicating operator + self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ + # @formatter:off Prevent Python formatter from messing with C++ + # formatting + # Note: Assumes stream type aliases to be set in defines + f"void {self.onnx_node.name} (", + # Input HLS stream + f"IStream &in_{self.hls_sname()}, ", ",".join([ + # One output HLS stream per replica # noqa: Formatting + f"OStream &out{i}_{self.hls_sname()}" for i in range(self.num) + ]), + ")", + # @formatter:off + ] + + # Generates C++ pragmas to be inserted into the main function of the C++ + # simulation and the ipgen-blackboxfunction as well + def pragmas(self): + # Add HLS interface directives specifying how to create RTL ports for + # the top-level function arguments + self.code_gen_dict["$PRAGMAS$"] = [ + # Connect the input stream with an axi stream interface + f"#pragma HLS INTERFACE axis port=in_{self.hls_sname()}" + ] + # Connect each output stream with an axi stream interface + for i in range(self.num): + # Add new interface directive for the output stream + self.code_gen_dict["$PRAGMAS$"] += [ + f"#pragma HLS INTERFACE axis port=out{i}_{self.hls_sname()}" + ] + # No block-level I/O protocol for the function return value + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS INTERFACE ap_ctrl_none port=return" + ) + + # Returns the names of input and output interfaces grouped by protocol + def get_verilog_top_module_intf_names(self): + # Start collecting interface names in a dictionary # noqa Duplicate + # starting with clock and reset + intf_names = {"clk": ["ap_clk"], "rst": ["ap_rst_n"]} # noqa + # AXI stream input interfaces + intf_names["s_axis"] = [ + # Just one input stream + (f"in_{self.hls_sname()}", 
self.get_instream_width_padded(ind=0)), + ] + # AXI stream output interfaces + intf_names["m_axis"] = [ + # One output stream per replica + (f"out{i}_{self.hls_sname()}", + self.get_outstream_width_padded(ind=i)) for i in range(self.num) + ] + # No AXI-MM, AXI-Lite or protocol-less interfaces + intf_names["aximm"] = [] + intf_names["axilite"] = [] + intf_names["ap_none"] = [] + # Return the interface name dictionary + return intf_names diff --git a/src/finn/custom_op/fpgadataflow/hlsbackend.py b/src/finn/custom_op/fpgadataflow/hlsbackend.py index 4677960ea..1063adb46 100644 --- a/src/finn/custom_op/fpgadataflow/hlsbackend.py +++ b/src/finn/custom_op/fpgadataflow/hlsbackend.py @@ -242,6 +242,9 @@ def compile_singlenode_code(self): builder.append_includes("-I$FINN_ROOT/src/finn/qnn-data/cpp") builder.append_includes("-I$FINN_ROOT/deps/cnpy/") builder.append_includes("-I$FINN_ROOT/deps/finn-hlslib") + # TODO: Is it ok to add this here? Add some specialization to the + # attention operator? Eventually integrate this into the finn-hlslib? + builder.append_includes("-I$FINN_ROOT/deps/attention-hlslib") builder.append_includes("-I$FINN_ROOT/custom_hls") builder.append_includes("-I{}/include".format(os.environ["HLS_PATH"])) builder.append_includes("--std=c++14") diff --git a/src/finn/custom_op/fpgadataflow/replicate_stream.py b/src/finn/custom_op/fpgadataflow/replicate_stream.py new file mode 100644 index 000000000..b593da1c7 --- /dev/null +++ b/src/finn/custom_op/fpgadataflow/replicate_stream.py @@ -0,0 +1,295 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Numpy math and arrays +import numpy as np +# Operating system stuff, e.g. paths +import os +# Python warning subsystem +import warnings + +# Helper for creating ONNX nodes +from onnx import helper as oh + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper +# Derive custom operators form the FINN base custom op +from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp +# Converts inputs/outputs to/from RTL simulation format +from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy + + +# Replicates an input stream to arbitrary many output streams +# See DuplicateStreams_Batch for feeding exactly two streams +class ReplicateStream(HWCustomOp): + # Initializes the operator given an onnx graph node + def __init__(self, onnx_node, **kwargs): + # Just forward all arguments to the init method of the CustomOp base + super().__init__(onnx_node, **kwargs) + + # Need to override the default depths of outputs FIFOs here as these + # depend on the number of replicas, which are not known during calls to + # get_nodeattr_types. 
if not self.get_nodeattr("outFIFODepths"): + self.set_nodeattr("outFIFODepths", [2 for _ in range(self.num)]) + + # Defines attributes which must be present on this node + def get_nodeattr_types(self): + # Start from parent operator class attributes + attrs = HWCustomOp.get_nodeattr_types(self) + # Update attributes dictionary for new custom operator + attrs.update({ + # Number of replicas to produce + "num": ("i", True, 1), + # Data type of input and output elements + "dtype": ("s", True, ""), + # Number of input elements in the last dimension + "num_elems": ("i", True, 1), + # Number of elements in the last dimensions processed in parallel + "PE": ("i", True, 1), + # Number of inputs to be processed sequentially + "num_inputs": ("ints", True, [1]), + # Possible execution modes for simulating this node + # Note: Override to support python mode + "exec_mode": ( + "s", False, "python", {"", "rtlsim", "cppsim", "python"} + ), + # Input and output FIFO depths for multi-I/O nodes + # Note: Need to override here as there are multiple outputs + "inFIFODepths": ("ints", False, [2]), + "outFIFODepths": ("ints", False, []), # Default will be overridden + }) + # Return updated attribute dictionary + return attrs + + # Number of replicas attribute as property for convenience + @property + def num(self): + return self.get_nodeattr("num") + + # Datatype attribute as property for convenience + @property + def dtype(self): + # Note: Converts from string to QONNX data type + return DataType[self.get_nodeattr("dtype")] + + # Number of elements attribute as property for convenience + @property + def num_elems(self): + return self.get_nodeattr("num_elems") + + # Number of parallel processed elements as property for convenience + @property + def pe(self): + return self.get_nodeattr("PE") + + # Number of inputs attribute as property for convenience + @property + def num_inputs(self): + return self.get_nodeattr("num_inputs") + + # Makes an operation compatible with the output shape for shape inference + # Note: Propagates shape forward, i.e., never asks for the shape of the + # output, even if it seems easier.
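For reference, a node carrying the attributes defined above could be constructed as in the following sketch; the tensor names, datatype and sizes are invented for illustration and are not part of this patch.

from onnx import helper as oh

# Replicate a UINT8 stream of 64-element embeddings into two output streams,
# processing 8 elements per cycle over a sequence of 128 inputs
replicate = oh.make_node(
    "ReplicateStream",
    inputs=["act_in"],
    outputs=["act_out0", "act_out1"],
    domain="finn.custom_op.fpgadataflow",
    backend="fpgadataflow",
    name="ReplicateStream_example",
    num=2,
    dtype="UINT8",
    num_elems=64,
    PE=8,
    num_inputs=[128, 1],
)

With these example values, the folding helpers further down in this file give a folded shape of (128, 1, 8, 8), a stream width of 8 x 8 = 64 bits per port and, via get_exp_cycles, 128 * 1 * 8 = 1024 cycles per input feature map.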
+ def make_shape_compatible_op(self, model: ModelWrapper): # noqa + # Get the node wrapped by this custom op + node = self.onnx_node + # Prepare a dummy input to simulate a large input that can be split into + # the desired number and shapes of outputs + mock_input = model.make_new_valueinfo_name() + # Simulate an input of number of replicas many elements + model.set_tensor_shape( + mock_input, [*self.num_inputs, self.num * self.num_elems] + ) + # Simulate behavior via the standard ONNX split operation + return oh.make_node( + "Split", [mock_input], node.output, num_outputs=self.num, axis=-1 + ) + + # Infers the datatype of the node output + def infer_node_datatype(self, model: ModelWrapper): # noqa + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node + # Test for changing input datatype + if model.get_tensor_datatype(node.input[0]) != self.dtype: + # Get the new datatype + new_dtype = model.get_tensor_datatype(node.input[0]) + # Issue a warning message + warnings.warn( + f"{node.name}: dtype changing from {self.dtype} to {new_dtype}" + ) + # Set the new datatype attribute + self.set_nodeattr("dtype", new_dtype.name) + # Propagate the type from the input to each output tensor + for o in node.output: + # Replicating simply propagates the dtype to the output + model.set_tensor_datatype(o, self.dtype) + + # Executes replicating inputs in python + def _execute_node_python(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op + node = self.onnx_node + # Get the input out of the execution context + inp = context[node.input[0]] + # Copy the input into each of the outputs + for o in node.output: + # Insert copy of input into the execution context at output + context[o] = inp + + # Executes replicating inputs in C++ simulation + def _execute_node_cppsim(self, context, graph): # noqa: graph unused + # C++ Simulation needs to be implemented in HLS backend specialization + raise NotImplementedError( + f"exec_mode cppsim of {self.__class__.__name__} is not implemented!" 
+ ) + + # Executes replicating inputs in RTL simulation + def _execute_node_rtlsim(self, context, graph): # noqa: graph unused + # Get the node wrapped by this custom op # noqa Duplicate + node = self.onnx_node + # Input data is stored in numpy files in the code generation dictionary + code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + # Get the input out of the execution context + inp = context[node.input[0]] + # Validate the shape of the input + assert inp.shape == self.get_normal_input_shape(ind=0), \ + f"Input shape mismatch for {node.input[0]}" + # Reshape the input into folded form + inp = inp.reshape(self.get_folded_input_shape(ind=0)) + # Path to store the intermediate input in numpy format + filename = os.path.join(code_gen_dir, "in.npy") + # Save the folded inputs to file to be used by simulation + np.save(filename, inp) + # Start collecting inputs/outputs to the RTL simulation in a dictionary + # Note: Prepare one output list per replica + io_dict = { + "inputs": {}, "outputs": {f"out{i}": [] for i in range(self.num)} + } + # Type and width of the input tensor + dtype = self.get_input_datatype(ind=0) + width = self.get_instream_width(ind=0) + # Convert inputs to RTL simulation format + io_dict["inputs"]["in"] = npy_to_rtlsim_input(filename, dtype, width) + + # Setup PyVerilator simulation of the node + sim = self.get_rtlsim() + # Reset the RTL simulation + super().reset_rtlsim(sim) + super().toggle_clk(sim) + # Run the RTL Simulation + self.rtlsim_multi_io(sim, io_dict) + + # Enumerate the node outputs + for i, name in enumerate(node.output): + # Collect the output from RTL simulation + out = io_dict["outputs"][f"out{i}"] + # Type and sizes of the output tensor + dtype = self.get_output_datatype(ind=i) + width = self.get_outstream_width(ind=i) + shape = self.get_folded_output_shape(ind=i) + # Path to store the intermediate numpy file + filename = os.path.join(code_gen_dir, f"out{i}.npy") + # Convert from RTL simulation format to numpy format + rtlsim_output_to_npy( + out, filename, dtype, shape, width, dtype.bitwidth() + ) + # Load the generated output numpy file + out = np.load(filename) + # Reshape the folded output and insert into the execution context + context[name] = out.reshape(self.get_normal_output_shape(ind=i)) + + # Executes replicating inputs in simulation (either python c++ or rtl sim) + def execute_node(self, context, graph): + # Get the configured execution mode + mode = self.get_nodeattr("exec_mode") + # Lookup table mapping execution modes to implementing methods + exec_fns = { + "python": self._execute_node_python, + "cppsim": self._execute_node_cppsim, + "rtlsim": self._execute_node_rtlsim, + } + # Select and execute the function by mode string + exec_fns[mode](context, graph) + + # Verifies the node attributes, inputs and outputs + def verify_node(self): + # TODO: Implement + return [] + + # Note: End of QONNX CustomOp region, below is FINN HWCustomOp stuff + + # Gets the datatype of input at index ind + def get_input_datatype(self, ind=0): + # All inputs (there should only be one) have the same type + return self.dtype + + # Gets the datatype of the output at index ind + def get_output_datatype(self, ind=0): + # All outputs will hae the same type, which is the same as the input + return self.dtype + + # Gets the shape of the input at index ind without folding + def get_normal_input_shape(self, ind=0): + # There is only one input with shape configured as attributes + # Unpack multi-axis inputs list to yield a flat tuple as shape + return 
*self.num_inputs, self.num_elems + + # Gets the shape of the output at index ind without folding + def get_normal_output_shape(self, ind=0): + # All outputs have the same shape, which is the same as the input + # Unpack multi-axis inputs list to yield a flat tuple as shape + return *self.num_inputs, self.num_elems + + # Gets the shape of the input at index ind with folding + def get_folded_input_shape(self, ind=0): + # Valid folding requires the PE to divides the number of elements + assert self.num_elems % self.pe == 0, "PE must divide num_elems" + # Folding along the last dimension + return *self.num_inputs, self.num_elems // self.pe, self.pe + + # Gets the shape of the output at index ind with folding + def get_folded_output_shape(self, ind=0): + # Valid folding requires the PE to divides the number of elements + assert self.num_elems % self.pe == 0, "PE must divide num_elems" + # Folding along the last dimension + return *self.num_inputs, self.num_elems // self.pe, self.pe + + # Widths of the input data stream of the input at index ind + def get_instream_width(self, ind=0): + # Get the number of bits used to represent the input + i_bits = self.get_input_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded input + *_, elems = self.get_folded_input_shape(ind) + # Width of a stream receiving input elements in parallel + return elems * i_bits + + # Widths of the output data stream of the output at index ind + def get_outstream_width(self, ind=0): + # Get the number of bits used to represent the output + o_bits = self.get_output_datatype(ind).bitwidth() + # Parallelism is the number of elements in the last dimension of the + # folded output + *_, elems = self.get_folded_output_shape(ind) + # Width of a stream producing output elements in parallel + return elems * o_bits + + # Gets the number of expected output values, i.e. how many times read() + # could/should be called on any output stream of this operator + def get_number_output_values(self): + # Elements over all but the last dimension of the output folded along + # the embedding dimension. Need to count across the number of replicas, + # as RTL simulation actually counts individual outputs, not cycles with + # outputs, i.e., producing N replica outputs per cycle in parallel, + # count N outputs per cycle... 
+ return np.prod(self.get_folded_output_shape()[:-1]) * self.num + + # Derives the expected cycles for the stream replication operation given the + # folding configuration + def get_exp_cycles(self): + # Number of iterations required to process the whole folded input stream + # Note: This is all but the PE (last, parallelized) dimension + return np.prod(self.get_folded_output_shape()[:-1]) diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py index ddc1d1f99..3a86a2a7d 100644 --- a/src/finn/custom_op/fpgadataflow/templates.py +++ b/src/finn/custom_op/fpgadataflow/templates.py @@ -136,13 +136,15 @@ set config_proj_part "$FPGAPART$" set config_bnnlibdir "$::env(FINN_ROOT)/deps/finn-hlslib" puts "finn-hlslib dir: $config_bnnlibdir" +set config_attentionlibdir "$::env(FINN_ROOT)/deps/attention-hlslib" +puts "attention-hlslib dir: $config_attentionlibdir" set config_customhlsdir "$::env(FINN_ROOT)/custom_hls" puts "custom HLS dir: $config_customhlsdir" set config_toplevelfxn "$TOPFXN$" set config_clkperiod $CLKPERIOD$ open_project $config_proj_name -add_files $config_hwsrcdir/top_$TOPFXN$.cpp -cflags "-std=c++14 -I$config_bnnlibdir -I$config_customhlsdir" +add_files $config_hwsrcdir/top_$TOPFXN$.cpp -cflags "-std=c++14 -I$config_bnnlibdir -I$config_customhlsdir -I$config_attentionlibdir" set_top $config_toplevelfxn open_solution sol1 diff --git a/src/finn/transformation/fpgadataflow/attention.py b/src/finn/transformation/fpgadataflow/attention.py new file mode 100644 index 000000000..0b68d0588 --- /dev/null +++ b/src/finn/transformation/fpgadataflow/attention.py @@ -0,0 +1,658 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. 
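The transformation module that follows converts the exported attention pattern into the ScaledDotProductAttention hardware operator. As a rough usage sketch (the surrounding steps, file names and call order are assumptions, not something this patch prescribes):

from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.fpgadataflow.attention import (
    AbsorbMultiThresholdIntoScaledDotProductAttention,
    InferScaledDotProductAttention,
)

# Hypothetical exported model containing the MatMul -> Softmax -> MatMul pattern
model = ModelWrapper("transformer_block.onnx")
# Replace the matched pattern by the ScaledDotProductAttention operator
model = model.transform(InferScaledDotProductAttention())
# Optionally fuse a MultiThreshold following the attention output into the node
model = model.transform(AbsorbMultiThresholdIntoScaledDotProductAttention())
model.save("transformer_block_attention.onnx")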
+ +# Standard math functions +import math + +# Need numpy for modifying the onnx graph tensors, which are numpy style arrays +import numpy as np + +# Output warning messages +import warnings + +# Utility for handling ONNX nodes and tensors +from onnx import NodeProto +from onnx import helper as oh + +# QONNX datatypes +from qonnx.core.datatype import BaseDataType, DataType + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# Convert ONNX nodes to QONNX custom ops +from qonnx.custom_op.registry import getCustomOp + +# QONNX graph transformation base class +from qonnx.transformation.base import Transformation + +# Transformations running qonnx datatype inference +from qonnx.transformation.infer_datatypes import InferDataTypes + +# Transformation running onnx shape inference +from qonnx.transformation.infer_shapes import InferShapes + +# Gets items from protobuf by name +from qonnx.util.basic import get_by_name, remove_by_name + +# Utility function for transforming ONNX graphs +from finn.transformation.util import ( + all_upstream_to_matmul, + is_add, + is_join_matmul, + is_matmul, + is_mul, + is_softmax, + op_types, +) + + +# Convert the operator pattern corresponding to scaled dot-product attention to +# the hardware custom operator node +class InferScaledDotProductAttention(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # This transformation is triggered by finding a join-node MatMul + if is_join_matmul(node, model): + # If there are more than two branches feeding the MatMul, this + # is probably not attention, softly skip the node + if len(node.input) != 2: + continue + # Follow both branches upstream looking for the next MatMul + lhs, rhs = all_upstream_to_matmul(node, model) + # Exactly one of the branches is supposed to contain a Softmax + # operation + if ("Softmax" in op_types(lhs)) == ("Softmax" in op_types(rhs)): + # TODO: Near match. But what is this? just skip? + continue + # By convention and following the equation, the left hand side + # of attention is the attention matrix, i.e., the one containing + # Softmax and terminating in a join-node MatMul + if "Softmax" not in op_types(lhs): + # Softmax must currently be on the right hand side, swap the + # order + lhs, rhs = rhs, lhs + # The left hand side, i.e, attention matrix must terminate in a + # join-node MatMul involving the query and key input + if not is_join_matmul(lhs[-1], model): + # TODO: Near match. But what is this? just skip? + continue + # Get shapes of input tensors, expect the second inputs, i.e., + # the keys to be transposed + qh, ql, qe = model.get_tensor_shape(lhs[-1].input[0]) + kh, ke, kl = model.get_tensor_shape(lhs[-1].input[1]) + # The input shapes of the two matmul inputs must be compatible, + # i.e., they must have matching embedding dimension + if (qh, True, qe) != (kh, True, ke): + # Issue a warning of near match of the supported attention + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Mismatch in head or embedding dim at {lhs[-1].name}: " + f" {(qh, ql, qe)} vs. 
{(kh, kl, ke)}" + ) + # @formatter:on + # Skip transforming this instance + continue + # There must be a Transpose feeding the key input + transpose = model.find_producer(lhs[-1].input[1]) + # The transform applies only to transpose with exactly one input + if transpose is None or len(transpose.input) != 1: + # Issue a warning of near match of the supported attention + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Missing Transpose near {lhs[-1].name}: " + f" {op_types([transpose])[0]}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Skip this node if the transpose output forks into multiple + # branches + if model.is_fork_node(transpose): + # Issue a warning of near match of the supported attention + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Fork Transpose near {node.name}: {transpose.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # The input shape of the transpose must match the transpose + # of the key matrix + # @formatter:off + assert model.get_tensor_shape(transpose.input[0]) == [ + kh, kl, ke + ] + # @formatter:on + # Collect the input tensors to the attention operation, i.e., + # the query, key and value tensors + q, k, v = lhs[-1].input[0], transpose.input[0], rhs[0].output[0] + # Validate that the values are actually consumed by the final + # matmul. For queries and keys this should all be given, as we + # just walked upwards the graph. + assert node in model.find_consumers(v) + + # Get the (optional) Softmax activation function + act_a_softmax = lhs[0] if is_softmax(lhs[1]) else None + # Get the (optional) query-key matmul activation function + act_qk_matmul = lhs[-2] if is_matmul(lhs[-1]) else None + + # There might be no activation function between qk matmul and + # softmax normalization + if is_mul(act_qk_matmul) or is_softmax(act_qk_matmul): + # Remove the detected activation function node from the + # pattern candidates + act_qk_matmul = None + + # Check whether the node is a supported type of activation + def is_supported_activation(n: NodeProto): # noqa: Shadows name + # Currently, only none-type and MultiThreshold activations + # are supported + return n is None or n.op_type in {"MultiThreshold"} + + # Get the (optional) output matmul activation function + act_av_matmul = model.find_direct_successors(node) + # If the final matmul is a fork node, this needs to be handled + # separately + if act_av_matmul is not None and len(act_av_matmul) > 1: + # Assume no activation in this case + act_av_matmul = [None] + # Unwrap the output activation from the list + act_av_matmul, = act_av_matmul + # The final activation can be omitted if it is not supported as + # it might just be part of the next operator pattern + if not is_supported_activation(act_av_matmul): + # Remove by setting to None (will be ignored by the next + # steps) + act_av_matmul = None + # List all activations for validation and further processing + # Note: Order matters! 
+ acts = [act_qk_matmul, act_a_softmax, act_av_matmul] + # Skip this node if any activation is not supported + if not all(is_supported_activation(act) for act in acts): + # Issue a warning of near match of the supported attention + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported activation near {node.name}: " + f" One of {', '.join(op_types(acts))}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Check whether there is a de-quantizer scale factor preceding + # the Softmax operator + dequant_softmax = lhs[2] if is_softmax(lhs[1]) else None + + # If there is no dequant softmax yet, check alternative pattern + if dequant_softmax is None: + # Alternatively, there might not be a quantizer following + # the softmax + dequant_softmax = lhs[1] if is_softmax(lhs[0]) else None + + # Assume no attention mask by default + mask, mask_mode, mask_dtype = [], 'none', DataType["BINARY"] + # If there is an elementwise add operation where we have + # expected the dequantizer, this might be an attention mask + if is_add(dequant_softmax): + # Remember the candidate of the masking operation + maybe_mask = dequant_softmax + # If there is a mask candidate, the dequantizer, must be + # right before + dequant_softmax = model.find_direct_predecessors( + dequant_softmax + ) + # The attention mask may not have multiple producers + if len(dequant_softmax) != 1: + # Issue a warning of near match of the supported + # attention pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported de-quantizer near {maybe_mask.name}: " + f" {op_types(dequant_softmax)}" + ) + # @formatter:on + # Skip transforming this instance + continue + # There is a single producer, which is probably the + # dequantizer + dequant_softmax, = dequant_softmax + + # The mask can be an initializer or provided as an input. If + # it is given as an initializer, it can either be a causal + # mask or some arbitrary pattern. 
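To make the two supported initializer cases concrete, here is a small numpy sketch of what the helper checks defined next would classify; the 3x3 size is arbitrary.

import numpy as np

# A 3x3 causal mask as typically exported: zeros on and below the diagonal,
# -inf strictly above it; is_causal() below matches this, so the mask input
# can be dropped in favour of the "causal" mode flag
causal = np.triu(-np.inf * np.ones((3, 3)), 1)
# Any other {0, -inf} pattern passes valid_mask() but not is_causal(); it is
# kept as a "const" input, stored as a binary mask instead of infinities
const_mask = (causal == -np.inf).astype(np.float32)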
+ + # Check whether a tensor is a valid mask tensor + def valid_mask(tensor): + # Valid masks contain only two types of values, i.e., + # zero for not masked and -inf for masked slots + return all( + x in {0.0, -np.inf} for x in np.unique(tensor) + ) + + # Check whether a tensor describes a causal attention mask + def is_causal(tensor): + # Generate a causal mask of the same size + causal = np.triu(-np.inf * np.ones_like(tensor), 1) + # Compare candidate against the causal mask + return (tensor == causal).all() # noqa: 'all' + + # Try to get the initializer of the masking operation + mask_tensor = model.get_initializer(maybe_mask.input[1]) + # Check whether this is constant mask known at export time + if mask_tensor is not None: + # We have a constant mask and need to validated that it + # only contains valid values + if not valid_mask(mask_tensor): + # Issue a warning of near match of the supported + # attention pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near" + f" match: Invalid values in mask near" + f" {maybe_mask.name}: {np.unique(mask_tensor)}" + ) + # @formatter:on + # Skip transforming this instance + continue + # If this is a causal mask, just set the flag and drop + # the input as the behavior can be generated on the fly + if is_causal(mask_tensor): + # Set the mode flag + mask_mode = "causal" + # This is a constant but non-causal mask which needs to + # be kept as an input to the operator + else: + # Keep the input and set the mode flag + mask, mask_mode = [maybe_mask.input[1]], "const" + # Convert the mask to a binary mask getting rid of + # explicitly storing the infinities + mask_tensor = (mask_tensor == -np.inf) + # Set the initializer to the binary mask still using + # float as the container type + model.set_initializer( + *mask, mask_tensor.astype(np.float32) + ) + # Set the quantization type annotation to binary + model.set_tensor_datatype(*mask, DataType["BINARY"]) + # Dynamic input mask, cannot be validated beforehand + else: + # # Keep the input and set the corresponding mode flag + # mask, mask_mode = [maybe_mask.input[1]], "input" + # # Keep track of the datatype of the mask + # mask_dtype = model.get_tensor_datatype(*mask) + + # Handling dynamic masks is more difficult and there is + # no solution for now. 
+ # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported dynamic mask near {maybe_mask.name}: " + f" {mask}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Currently, only elementwise Mul is supported as de-quantizer + if not is_mul(dequant_softmax): + # Issue a warning of near match of the supported attention + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported de-quantizer near {lhs[1].name}: " + f" {dequant_softmax.op_type}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # If there is a dequant scale factor, try to lift it from + # initializer to node attribute + if dequant_softmax is not None: + # Get the initializer tensor + scale = model.get_initializer(dequant_softmax.input[1]) + # This must be an initializer, the attention operator + # currently does not handle any dynamically produced scale + # factors + if scale is None: + # Issue a warning of near match of the supported + # attention pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Non-constant dequantizer near {node.name}: " + f" {dequant_softmax.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + # Currently, only scalar dequantizer scale factors are + # supported + if not all(x == 1 for x in scale.shape): + # Issue a warning of near match of the supported + # attention pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Non-scalar dequantizer near {node.name}: " + f" {dequant_softmax.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + # Extract the single float value of the tensor + dequant_softmax = float(scale.item()) + # Insert default scale if the is no dequantizer present + else: + # Default is identity scale + dequant_softmax = 1.0 + + # The last node of the attention operator is either the detected + # matmul or the following, optional activation function + last = act_av_matmul if act_av_matmul is not None else node + + # Tensor names of the threshold inputs + # Note: order matters + thresholds = [ + # TODO: Fix condition once more activation types are + # supported, currently there are only none and thresholds + act.input[1] for act in acts if act is not None + ] + + # Convert activation function types to string representation + def act_op_type_str(act): + # Only MultiThreshold is supported currently + if act is not None and act.op_type == "MultiThreshold": + # The attention custom op uses "thresholds" to identify + return "thresholds" + # All other types are not supported + return "none" + + # The value tensor shape must be compatible with the attention + # matrix + assert model.get_tensor_shape(v)[:2] == [qh, kl] + + # Output type of the first matmul + out_qk_matmul = lhs[-1].output[0] + # Extend the output type to include the optional thresholding + # activation + if act_qk_matmul is not None: + # Single output tensor of the activation function + out_qk_matmul = act_qk_matmul.output[0] + + # Extract output bias of the thresholding activation functions + def out_bias(act): + # Does only apply to thresholding activations + if act is not None and act.op_type == "MultiThreshold": + # Extract via interpreting the node as QONNX custom op + return getCustomOp(act).get_nodeattr("out_bias") + # Default bias if no bias + return 0.0 + + # Fixed node attributes and extracted input/output/initializer + # tensor 
names + kwargs = { + # Refer to this operator type by its name + "op_type": "ScaledDotProductAttention", + # Execution will try to look up the implementation in the + # package + # referred to by the domain + "domain": "finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + "backend": "fpgadataflow", + # Named inputs and activation thresholds extracted from the + # graph pattern + "inputs": [q, k, v, *mask, *thresholds], + # Named model output extracted from the graph pattern + "outputs": last.output, + # Set the attribute specifying how to handel the optional + # attention mask + "mask_mode": mask_mode, + # Give node name derived from the operator type and the name + # of the triggering node to be removed + "name": f"ScaledDotProductAttention_{node.name}" + } + + # Extract the node attributes of the attention operator from + # all constituent nodes + node_attrs = { + # Number of attention heads + "Heads": qh, + # Embedding dimension of queries and keys + "QKDim": qe, + # Length of the query sequence + "QLen": ql, + # Embedding dimension of the values + "VDim": model.get_tensor_shape(v)[2], + # Length of the key and value sequence + "KVLen": kl, + + # Folding along the embedding dimensions + # Note: Assume biggest folding possible fitting both + # embedding dimensions + "EmbFold": math.gcd(qe, model.get_tensor_shape(v)[2]), + # Folding along the sequence dimensions + # Note: Assume biggest folding possible fitting both + # sequence dimensions + "SeqFold": math.gcd(ql, kl), + + # Datatype of query matrix elements + "QType": model.get_tensor_datatype(q), + # Datatype of key matrix elements + "KType": model.get_tensor_datatype(k), + # Datatype of value matrix elements + "VType": model.get_tensor_datatype(v), + # # Datatype of mask matrix elements + "MType": mask_dtype.name, + # Datatype of attention weights elements + "AType": model.get_tensor_datatype(lhs[0].output[0]), + # Datatype of output elements + "OType": model.get_tensor_datatype(last.output[0]), + + # Datatype of accumulator elements of the first matmul + "AccQKMatMul": model.get_tensor_datatype(lhs[-1].output[0]), + # Datatype of output elements of the first matmul + # Note: Can be extracted from the left hand side + # intermediate outputs + "OutQKMatMul": model.get_tensor_datatype(out_qk_matmul), + # Activation function type following the first matmul + "ActQKMatMul": act_op_type_str(act_qk_matmul), + # Output bias to be applied to the thresholding activation + # following the Query x Key multiplication + "BiasActQKMatMul": out_bias(act_qk_matmul), + + # Datatype of accumulator elements of the second matmul + "AccAVMatMul": model.get_tensor_datatype(node.output[0]), + # Datatype of output elements of the second matmul + # Note: Always the same as the OType + "OutAVMatMul": model.get_tensor_datatype(last.output[0]), + # Activation function type following the second matmul + "ActAVMatMul": act_op_type_str(act_av_matmul), + # Output bias to be applied to the thresholding activation + # following the Attention x Value multiplication + "BiasActAVMatMul": out_bias(act_av_matmul), + + # Softmax may be preceded by a de-quantizer scalar + # multiplication + "DequantSoftmax": dequant_softmax, + # Datatype of softmax normalization before applying + # activation or type cast. This is called Acc to stick to + # the naming scheme of the MatMul operators before. 
+ # Note: Currently this is ALWAYS floats + "AccASoftmax": "FLOAT32", + # Activation function type following the softmax + # normalization of the attention weights + "ActASoftmax": act_op_type_str(act_a_softmax), + # Output bias to be applied to the thresholding activation + # following the softmax normalization of the attention + # weights + "BiasActASoftmax": out_bias(act_a_softmax), + } + + # Converts QONNX datatypes to their name (as a string) + def maybe_name(value): + # All QONNX datatypes are instances of the BaseDataType + if isinstance(value, BaseDataType): + # Convert to the name by referring to the datatypes name + # attribute + return value.name + # Everything else is just assumed to be in the right format + return value + + # Convert all node attributes DataTypes to string + # representations of their names + node_attrs = { + key: maybe_name(value) for key, value in node_attrs.items() + } + + # Create a new custom node replacing the scaled dot-product + # attention pattern + attention = oh.make_node(**kwargs, **node_attrs) + # Insert the new node into the graph + graph.node.insert(index, attention) + # Collect all nodes comprising the original pattern + nodes = [node, transpose, *lhs, act_av_matmul] + # Remove all nodes of the original pattern + for n in nodes: + # Do not try to remove non-existing nodes + if n is not None: + graph.node.remove(n) + # The graph has been modified + graph_modified = True + # After rewiring need to re-do the shape annotations + model = model.transform(InferShapes()) # noqa: Shadows model + # As attention mask datatype might have been changed, it might be + # necessary to re-do the datatype annotations + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified + + +# Absorbs a MultiThreshold into ScaledDotProductAttention if there is not +# already an activation included +class AbsorbMultiThresholdIntoScaledDotProductAttention(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Any MultiThreshold is a candidate node + if node.op_type == "MultiThreshold": + # Cannot be a join-node + if model.is_join_node(node): + # Softly skip transforming this node + continue + # Now we know there is only one producer operation preceding the + # multi-threshold node + attention = model.find_direct_predecessors(node) + # The first node in the graph might have no predecessor + if attention is None: + # Skip this node + continue + # Unpack the single predecessor from the list + attention = attention[0] + # Predecessor must actually be a ScaledDotProductAttention for + # this transform to apply + if not attention.op_type == "ScaledDotProductAttention": + # Skip transforming this instance, probably no need to warn + continue + # The attention operation may not fork for this transformation + # to be applicable + if model.is_fork_node(attention): + # Softly skip transforming this, will result in standalone + # thresholds + continue + + # Check whether the attention operation already has an output + # activation + if getCustomOp(attention).get_nodeattr("ActAVMatMul") != "none": + # Issue a warning to make the user aware of this mismatch + 
# pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f" {attention.name} already has an activation:" + f" {get_by_name(attention.attribute, 'ActAVMatMul').s}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Datatype of the thresholding output, which will be the new + # output datatype of the attention operator + dtype = getCustomOp(node).get_nodeattr("out_dtype") + # Output bias after the thresholding, needs to be absorbed into + # the attention operator as well + out_bias = getCustomOp(node).get_nodeattr("out_bias") + + # Collect new attributes + attrs = { + # Datatype of output elements of the second matmul + # Note: Always the same as the OType + "OutAVMatMul": dtype, + # Attention operator output type must be the same as the + # output type of the last matmul + "OType": dtype, + # Activation function type following the second matmul + "ActAVMatMul": "thresholds", + # Output bias to be applied to the thresholding activation + # following the Attention x Value multiplication + "BiasActAVMatMul": out_bias, + } + + # Run over all attributes to be changed + for key, value in attrs.items(): + # Remove the existing attribute + remove_by_name(attention.attribute, key) + # Insert a new attribute with the same name + attention.attribute.append(oh.make_attribute(key, value)) + + # Append the new threshold tensor as the last input + attention.input.append(node.input[1]) + # Annotate the new thresholds tensor datatype + model.set_tensor_datatype( + node.input[1], model.get_tensor_datatype(node.input[0]) + ) + # Rewire the output of the attention operator to skip the + # thresholds node + attention.output[0] = node.output[0] + # Remove the thresholding node + graph.node.remove(node) + # The graph has been modified + graph_modified = True + # Break the loop after adding and removing nodes to start over + # with a clean index + break + # After rewiring need to re-do the shape annotations + model = model.transform(InferShapes()) # noqa: Shadows model + # As attention mask datatype might have been changed, it might be + # necessary to re-do the datatype annotations + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified diff --git a/src/finn/transformation/fpgadataflow/attention_heads.py b/src/finn/transformation/fpgadataflow/attention_heads.py new file mode 100644 index 000000000..0b5bc0e7d --- /dev/null +++ b/src/finn/transformation/fpgadataflow/attention_heads.py @@ -0,0 +1,828 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. 
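The head-handling passes defined in this file are intended to leave nothing but the attention operator between splitting and merging of heads. A hedged sketch of a possible ordering (the exact sequence and file names are assumptions, not mandated by this patch):

from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.fpgadataflow.attention_heads import (
    InferMultiHeads,
    MoveMergeMultiHeadsPastMultiThreshold,
    MoveSplitMultiHeadsPastMultiThreshold,
)

model = ModelWrapper("transformer_block.onnx")  # hypothetical input model
# Turn the Reshape/Transpose patterns into SplitMultiHeads/MergeMultiHeads
model = model.transform(InferMultiHeads())
# Move standalone thresholds out of the region between split and merge
model = model.transform(MoveSplitMultiHeadsPastMultiThreshold())
model = model.transform(MoveMergeMultiHeadsPastMultiThreshold())
# ...typically followed by InferScaledDotProductAttention as sketched earlier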
+ +# Make copies and deep copies of python objects +import copy + +# Need numpy for modifying the onnx graph tensors, which are numpy style arrays +import numpy as np + +# Output warning messages +import warnings + +# Utility for handling ONNX nodes and tensors +from onnx import NodeProto +from onnx import helper as oh + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# QONNX graph transformation base class +from qonnx.transformation.base import Transformation + +# QONNX graph transformations for renaming and cleaning up +from qonnx.transformation.general import GiveUniqueParameterTensors + +# Transformation running qonnx datatype inference +from qonnx.transformation.infer_datatypes import InferDataTypes + +# Transformation running onnx shape inference +from qonnx.transformation.infer_shapes import InferShapes + +# Gets items from protobuf by name +from qonnx.util.basic import get_by_name, remove_by_name + +# Utility function for transforming ONNX graphs +from finn.transformation.util import ( + is_reshape_transpose, + is_transpose_reshape, + op_types, +) + + +# Infers reshaping of attention heads, i.e., converts the Reshape and transpose +# patterns to the SplitMultiHeads and MergeMultiHeads hardware custom operators. +class InferMultiHeads(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Head-slicing reshaping is triggered by detecting a reshape + # operation followed by a transpose + if is_reshape_transpose(node, model): + # Get the single successor node + transpose = model.find_direct_successors(node)[0] + + # Get the input and output tensor names to the pattern + inp = node.input[0] + mid = node.output[0] + end = transpose.output[0] + + # Get the shape of the input tensor for inferring the number of + # heads and correctly propagating shapes + shape = model.get_tensor_shape(inp) + # Determine the rank of the input tensor to support batched and + # non-batched inputs + rank = len(shape) + + # Can only handle 3-dimensional (2-dimensional) layouts for now + if rank not in {2, 3}: + # Issue a warning of near match of the supported head + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported shape near {transpose.name}: {inp}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # The input shape determines the sequence length + seq, _, dim = shape if (rank == 3) else (shape[0], 1, shape[1]) + + # Can only handle 3-dimensional (2-dimensional) layouts for now + if len(model.get_tensor_shape(mid)) != 3: + # Issue a warning of near match of the supported head + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported shape near {transpose.name}: {mid}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # The intermediate shape must be the same as specified as the + # second input to the reshape operation + assert (model.get_tensor_shape(mid) # noqa + == model.get_initializer(node.input[1])).all() # noqa + # Expected layout after reshape is "head last" + _, heads, _ = model.get_tensor_shape(mid) + + # Get the (optional) permutation indices of the transpose in + # 
case it is a multi-axis transpose + perm = get_by_name(transpose.attribute, "perm") + # Convert permutation indices to list of integers if it is + # given + perm = perm.ints if perm is not None else None + + # Transpose must either keep or flip the sequence and embedding + # dimensions + if perm not in [[1, 0, 2], [1, 2, 0]]: + # Issue a warning of near match of the supported head + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported permutation near {transpose.name}: {perm}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Check whether the transpose only permutes to head first or + # additionally transposes sequence and embedding dimension as + # well + keep_transpose = (perm == [1, 2, 0]) + + # Start assuming there is no middle node, as the transpose is + # removed + maybe_mid = end + + # Insert a new transpose node if the sequence and embedding + # dimensions are flipped + if keep_transpose: + # Construct a new intermediate tensor using the current one + # as template + maybe_mid = mid + # Construct a new Transpose with attributes inferred from + # the detected graph patter + new_transpose = oh.make_node(**{ + "op_type": "Transpose", + # Named inputs extracted from the graph pattern + "inputs": [maybe_mid], + # Named outputs extracted from the graph pattern + "outputs": [end], + # Give node name derived from the operator type and the + # name of the triggering node to be removed + "name": f"MultiHeads_Transpose_{node.name}", + # Permute the last two dimensions + "perm": [0, 2, 1] + }) + # Insert the new node into the graph + graph.node.insert(index + 1, new_transpose) + # Change the shape of the intermediate tensor to reflect + # partial reshaping + model.set_tensor_shape( + maybe_mid, (heads, seq, dim // heads) + ) + + # Fixed node attributes and extracted input/output/initializer + # tensor names + kwargs = { + # Refer to this operator type by its name + "op_type": "SplitMultiHeads", + # Execution will try to look up the implementation in the + # package referred to by the domain + "domain": "finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + "backend": "fpgadataflow", + # Named inputs extracted from the graph pattern + "inputs": [inp], + # Named outputs extracted from the graph pattern + "outputs": [maybe_mid], + # Give node name derived from the operator type and the name + # of the triggering node to be removed + "name": f"SplitMultiHeads_{node.name}", + # Number of attention heads inferred + "heads": heads, + # Inferred multi-heads produce packed tensors + "packed": True, + # Datatype of inputs and outputs + "dtype": model.get_tensor_datatype(node.input[0]).name, + # Number of input elements, i.e., embedding dimension + "num_elems": dim, + # Number of embeddings in the whole input sequence/feature + # map + "num_inputs": [seq, 1] if (rank == 3) else [seq] + } + + # Create a new custom node replacing the multi head reshape + heads = oh.make_node(**kwargs) + # Insert the new node into the graph + graph.node.insert(index, heads) + # Collect all nodes comprising the original pattern + nodes = [node, transpose] + # Remove all nodes of the original pattern + for n in nodes: + # Do not try to remove non-existing nodes + if n is not None: + graph.node.remove(n) + # The graph has been modified + graph_modified = True + + # Head-merging reshaping is triggered by detecting a transpose + # operation followed by a reshape + if 
is_transpose_reshape(node, model): + # Get the single successor node + reshape = model.find_direct_successors(node)[0] + + # Get the input and output tensor names to the pattern + inp = node.input[0] + end = reshape.output[0] + + # Get the shape of the input tensor for inferring the number of + # heads and correctly propagating shapes + shape = model.get_tensor_shape(inp) + # Determine the rank of the input tensor to support batched and + # non-batched inputs + rank = len(shape) + + # Can only handle 3-dimensional (2-dimensional) layouts for now + if rank not in {3}: + # Issue a warning of near match of the supported head + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported shape near {reshape.name}: {inp}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # The input shape determines the heads, sequence length and + # embedding dimension + heads, seq, dim = shape + + # Get the (optional) permutation indices of the transpose in + # case it is a multi-axis transpose + perm = get_by_name(node.attribute, "perm") + # Convert permutation indices to list of integers if it is given + perm = perm.ints if perm is not None else None + + # Transpose must flip the heads and sequence dimensions + if perm not in [[1, 0, 2]]: + # Issue a warning of near match of the supported head + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Unsupported permutation near {node.name}: {perm}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Shape of the final output of the operator pattern + out_shape = model.get_tensor_shape(end) + + # The output of the reshape must be the same as specified as the + # second input to the reshape operation + assert (out_shape # noqa + == model.get_initializer(reshape.input[1])).all() + + # The final output shape must match the expectation of + # reintegrating the heads back into the embeddings + if out_shape not in [[seq, heads * dim], [seq, 1, heads * dim]]: + # Issue a warning to make the user aware of this mismatch + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Output shape mismatch near: {reshape.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Fixed node attributes and extracted input/output/initializer + # tensor names + kwargs = { + # Refer to this operator type by its name + "op_type": "MergeMultiHeads", + # Execution will try to look up the implementation in the + # package referred to by the domain + "domain": "finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + "backend": "fpgadataflow", + # Named inputs extracted from the graph pattern + "inputs": [inp], + # Named outputs extracted from the graph pattern + "outputs": [end], + # Give node name derived from the operator type and the name + # of the triggering node to be removed + "name": f"MergeMultiHeads_{node.name}", + # Number of attention heads inferred + "heads": heads, + # Remember, whether the output needs to be squeezed + "squeezed": out_shape == [seq, heads * dim], + # Inferred multi-heads produce packed tensors + "packed": True, + # Datatype of inputs and outputs + "dtype": model.get_tensor_datatype(node.input[0]).name, + # Number of input elements, i.e., embedding dimension + "num_elems": dim, + # Number of embeddings in the whole input sequence/feature + # map + "num_inputs": [heads, seq], + } + + # 
Create a new custom node replacing the multi head reshape + heads = oh.make_node(**kwargs) + # Insert the new node into the graph + graph.node.insert(index, heads) + # Collect all nodes comprising the original pattern + nodes = [node, reshape] + # Remove all nodes of the original pattern + for n in nodes: + # Do not try to remove non-existing nodes + if n is not None: + graph.node.remove(n) + # The graph has been modified + graph_modified = True + # After rewiring need to re-do the shape annotations + model = model.transform(InferShapes()) # noqa: Shadows from outer scope + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified + + +# Move SplitMultiHeads operation past MultiThreshold operation. This is required +# as a precondition for later unrolling the attention heads, as there may not be +# any other operations between splitting and merging the attention heads, +# besides the actual attention operator. +class MoveSplitMultiHeadsPastMultiThreshold(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Transformation applies to SplitMultiHeads operation (not Merge) + if node.op_type == "SplitMultiHeads": + # Slicing should not fork or join + if model.is_fork_node(node) or model.is_join_node(node): + # Issue a warning to make the user aware of this mismatch + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Slicing may not join or fork: {node.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + # Now we know there is only one consumer operation following the + # slice node + thresholds_node = model.find_direct_successors(node)[0] # noqa + # Successor must actually be a MultiThresholds for this + # transform to apply + if not thresholds_node.op_type == "MultiThreshold": + # Skip transforming this instance, probably no need to warn + continue + + # Thresholds should not fork or join either + if (model.is_fork_node(thresholds_node) + or model.is_join_node(thresholds_node)): + # Issue a warning to make the user aware of this mismatch + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"MultiThreshold may not join or fork:" + f" {thresholds_node.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Get the thresholds tensor, which must be an initializer at + # the second input + thresholds = model.get_initializer(thresholds_node.input[1]) + # This is indeed an error, no way to recover from this, so + # assertion is fine + assert thresholds is not None, \ + f"Missing threshold tensor for {thresholds_node.name}" + + # The slice node should have an attribute specifying the number + # of heads + heads = get_by_name(node.attribute, "heads") + # Heads must be present, otherwise this is an errr + assert heads is not None, \ + f"Missing number of heads for {node.name}" + # Convert heads attribute proto to integer + heads = heads.i + + # Repeat the thresholds for each head along the channel + # dimension + thresholds = np.concatenate(heads * [thresholds]) + # Update the thresholds tensor to simply repurpose the existing + # node + 
model.set_initializer(thresholds_node.input[1], thresholds) + + # Get names of all tensors involved in connecting the nodes + inp = node.input[0] + mid = node.output[0] + out = thresholds_node.output[0] + + # The middle tensor is now produced by the multi-threshold, + # which does not change the shape. Propagate the shape of the + # input tensor + model.set_tensor_shape(mid, model.get_tensor_shape(inp)) + # As the middle tensor is now produced by the multi-threshold, + # the datatype needs to be taken from the output tensor + model.set_tensor_datatype(mid, model.get_tensor_datatype(out)) + # Remove the datatype attribute before setting the new + # datatype + remove_by_name(node.attribute, "dtype") + # Insert new datatype attribute + node.attribute.append( + oh.make_attribute( + "dtype", model.get_tensor_datatype(out).name + ) + ) + + # Rewire the nodes locally switching order. Reuses all the + # exising tensors. + thresholds_node.input[0] = inp + thresholds_node.output[0] = mid + node.input[0] = mid + node.output[0] = out + + # Graph has been modified, required additional transformations + # to be run + graph_modified = True + # After rewiring need to re-do the shape annotations + model = model.transform(InferShapes()) # noqa: Shadows from outer scope + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified + + +# Move MergeMultiHeads operation past MultiThreshold operation to avoid merging +# excessively large streams and maybe even allow absorbing the thresholds into +# the attention operator. +class MoveMergeMultiHeadsPastMultiThreshold(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Transformation applies to MergeMultiHeads operation + if node.op_type == "MergeMultiHeads": + # Merging should not fork, but it may join + if model.is_fork_node(node): + # Issue a warning to make the user aware of this mismatch + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"Slicing may not fork: {node.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + # Now we know there is only one consumer operation following the + # slice node + thresholds_node = model.find_direct_successors(node)[0] # noqa + # Successor must actually be a MultiThresholds for this + # transform to apply + if not thresholds_node.op_type == "MultiThreshold": + # Skip transforming this instance, probably no need to warn + continue + + # Thresholds must not fork or join either + if (model.is_fork_node(thresholds_node) + or model.is_join_node(thresholds_node)): + # Issue a warning to make the user aware of this mismatch + # pattern + # @formatter:off + warnings.warn( + f"{self.__class__.__name__}: Skipping near match: " + f"MultiThreshold may not join or fork:" + f" {thresholds_node.name}" + ) + # @formatter:on + # Skip transforming this instance + continue + + # Get the thresholds tensor, which must be an initializer at + # the second input + thresholds = model.get_initializer(thresholds_node.input[1]) + # This is indeed an error, no way to recover from this, so + # assertion is fine + assert thresholds is not None, \ + f"Missing threshold tensor for 
{thresholds_node.name}" + + # The merge node should have an attribute specifying the number + # of heads + heads = get_by_name(node.attribute, "heads") + # Heads must be present, otherwise this is an errr + assert heads is not None, \ + f"Missing number of heads for {node.name}" + # Convert heads attribute proto to integer + heads = heads.i + + # Split the thresholds for each head along the channel dimension + # Note: This is a list of thresholds per head now + thresholds = np.split(thresholds, heads) + + # Need to insert a new thresholding operation at each input of + # the multi-head merging + for i, inp in enumerate(node.input): + # Start by making a full copy of the original thresholds + # node + new_thresholds = copy.deepcopy(thresholds_node) + # The input to the original merging node becomes the first + # input to the new thresholds node + new_thresholds.input[0] = inp + # Create a new input tensor name for the thresholds + new_thresholds.input[1] = model.make_new_valueinfo_name() + # Annotate the new thresholds input with the new shape of + # the split thresholds + model.set_tensor_shape( + new_thresholds.input[1], thresholds[i].shape + ) + # Set the initializer input to the split thresholds + model.set_initializer( + new_thresholds.input[1], thresholds[i] + ) + # Create a new output tensor name + new_thresholds.output[0] = model.make_new_valueinfo_name() + # Annotate the new output with the shape of the input + model.set_tensor_shape( + new_thresholds.output[0], model.get_tensor_shape(inp) + ) + # Connect the new output tensor to the corresponding input + # of the merge node + node.input[i] = new_thresholds.output[0] + # Connect the output of the merging node to successor of the + # original thresholding node + node.output[0] = thresholds_node.output[0] + # Insert the thresholding node into the graph + graph.node.insert(index + i - 1, new_thresholds) + # Remove the original thresholds node + graph.node.remove(thresholds_node) + # Graph has been modified, required additional transformations + # to be run + graph_modified = True + # Break the loop after adding and removing nodes to start over + # with a clean index + break + # After rewiring need to re-do the shape annotations + model = model.transform(InferShapes()) # noqa: Shadows from outer scope + # Re-do the datatype annotations after inserting new tensors without and + # moving tensors with existing annotations + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified + + +# Detects multi-head attention pattern, i.e., scaled dot-product attention +# between head splitting and merging +def is_multi_head_attention(node: NodeProto, model: ModelWrapper): # noqa + # The anchor node must be scaled dot product attention + if node.op_type == "ScaledDotProductAttention": + # Get the nodes feeding the attention operation + predecessors = model.find_direct_predecessors(node) + # There must be exactly three predecessors of type head-splitting + # Note: there must be nothing in between splitting and the attention + # itself + if op_types(predecessors) == 3 * ["SplitMultiHeads"]: + # Get the node fed by the attention operation + successors = model.find_direct_successors(node) + # There must be exactly onde successor of type head-merging + # Note: there must be nothing in between attention and the merging + if op_types(successors) == 1 * ["MergeMultiHeads"]: + # Get the shape of the input tensor for inferring the number of + # 
heads and correctly propagating shapes + shape = model.get_tensor_shape(node.input[0]) + # Determine the rank of the input tensor to support batched and + # non-batched inputs + rank = len(shape) + # The input shape determines the sequence length + heads, _, _ = shape if (rank == 3) else (1, shape[0], shape[1]) + # Pattern detected, if there are actually multiple heads + return heads > 1 + # Pattern not detected + return False + + +# Unrolls multiple attention heads in the onnx graph to be implemented in +# parallel +class UnrollMultiHeadAttention(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Apply transformation to nodes which match the multi-head attention + # pattern + if is_multi_head_attention(node, model): + # Get the splitting nodes fed by the attention operation + split0, split1, split2 = model.find_direct_predecessors(node) + # Get the single merging node + merge0, = model.find_direct_successors(node) + # Get the number of heads produced by an arbitrary splitters + heads = get_by_name(split0.attribute, "heads").i + # Get the number of input elements to the heads splitting + # Note: Embedding dims might actually differ per input stream, + # e.g., for cross-attention + dim0 = get_by_name(split0.attribute, "num_elems").i + dim1 = get_by_name(split1.attribute, "num_elems").i + dim2 = get_by_name(split2.attribute, "num_elems").i + # get the number of input features per splitting + # Note: Feature map sizes might actually differ per input + # stream, e.g., for cross-attention + ins0 = get_by_name(split0.attribute, "num_inputs").ints + ins1 = get_by_name(split1.attribute, "num_inputs").ints + ins2 = get_by_name(split2.attribute, "num_inputs").ints + # Validate the number of heads matches between all slice and + # merge nodes + for n in [split0, split1, split2, merge0]: + # All heads must match, otherwise this is a failure from + # which we cannot recover + assert get_by_name(n.attribute, "heads").i == heads, \ + f"Differing number of heads at {node.name} and {n.name}" + # Remove the original node from the graph + graph.node.remove(n) + + # TODO: Clean up the following code + + # Create replicas of the splitting nodes with expanded output + # list + split0 = oh.make_node( + # Refer to this operator type by its name + op_type="SplitMultiHeads", + # Execution will try to look up the implementation in the + # package referred to by the domain + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + backend="fpgadataflow", + # Connect to the same input as the original + inputs=split0.input, + # Generate new output tensor names for each head + outputs=[ + model.make_new_valueinfo_name() for _ in range(heads) + ], + # Attribute specifying the number of heads + heads=heads, + # Unrolled heads do not produce packed tensors + packed=False, + # Datatype of inputs and outputs + dtype=get_by_name(split1.attribute, "dtype").s, + # Number of input elements, i.e., embedding dimension + num_elems=dim0, + # Number of embeddings in the whole input sequence/feature + # map + num_inputs=[*ins0] + ) + split1 = oh.make_node( + # Refer to this operator type by its name + op_type="SplitMultiHeads", + # Execution will try to 
look up the implementation in the + # package referred to by the domain + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + backend="fpgadataflow", + # Connect to the same input as the original + inputs=split1.input, + # Generate new output tensor names for each head + outputs=[ + model.make_new_valueinfo_name() for _ in range(heads) + ], + # Attribute specifying the number of heads + heads=heads, + # Unrolled heads do not produce packed tensors + packed=False, + # Datatype of inputs and outputs + dtype=get_by_name(split1.attribute, "dtype").s, + # Number of input elements, i.e., embedding dimension + num_elems=dim1, + # Number of embeddings in the whole input sequence/feature + # map + num_inputs=[*ins1] + ) + split2 = oh.make_node( + # Refer to this operator type by its name + op_type="SplitMultiHeads", + # Execution will try to look up the implementation in the + # package referred to by the domain + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + backend="fpgadataflow", + # Connect to the same input as the original + inputs=split2.input, + # Generate new output tensor names for each head + outputs=[ + model.make_new_valueinfo_name() for _ in range(heads) + ], + # Attribute specifying the number of heads + heads=heads, + # Unrolled heads do not produce packed tensors + packed=False, + # Datatype of inputs and outputs + dtype=get_by_name(split2.attribute, "dtype").s, + # Number of input elements, i.e., embedding dimension + num_elems=dim2, + # Number of embeddings in the whole input sequence/feature + # map + num_inputs=[*ins2] + ) + # Create replica of the merging node with expanded input list + merge0 = oh.make_node( + # Refer to this operator type by its name + op_type="MergeMultiHeads", + # Execution will try to look up the implementation in the + # package referred to by the domain + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + backend="fpgadataflow", + # Generate new input tensor names for each head + inputs=[ + model.make_new_valueinfo_name() for _ in range(heads) + ], + # Connect to the same input as the original + outputs=merge0.output, + # Attribute specifying the number of heads + heads=heads, + # Attribute specifying whether the output needs to be + # squeezed + squeezed=get_by_name(merge0.attribute, "squeezed").i, + # Unrolled heads do not produce packed tensors + packed=False, + # Datatype of inputs and outputs + dtype=get_by_name(merge0.attribute, "dtype").s, + # Number of input elements, i.e., embedding dimension + num_elems=get_by_name(merge0.attribute, "num_elems").i, + # Number of embeddings in the whole input sequence/feature + # map + # Note: Drop head-first head dimension of previously packed + # input + num_inputs=get_by_name( + merge0.attribute, "num_inputs").ints[1:] + ) + + # Replicate the attention operator for each head + for i in range(heads): + # Start by making a full copy of the original node + attention = copy.deepcopy(node) + # Get the original shape of each input to remove the head + # number + _, seq, dim = model.get_tensor_shape(attention.input[0]) + model.set_tensor_shape(split0.output[i], (1, seq, dim)) + _, seq, dim = model.get_tensor_shape(attention.input[1]) + model.set_tensor_shape(split1.output[i], (1, seq, dim)) + _, seq, dim = model.get_tensor_shape(attention.input[2]) + model.set_tensor_shape(split2.output[i], (1, seq, dim)) + + # Propagate the 
original datatype to each of the head inputs + dtype = model.get_tensor_datatype(attention.input[0]) + model.set_tensor_datatype(split0.output[i], dtype) + dtype = model.get_tensor_datatype(attention.input[1]) + model.set_tensor_datatype(split1.output[i], dtype) + dtype = model.get_tensor_datatype(attention.input[2]) + model.set_tensor_datatype(split2.output[i], dtype) + + # Connect the inputs of the replica to the output of each + # of the new slice operators + attention.input[0] = split0.output[i] + attention.input[1] = split1.output[i] + attention.input[2] = split2.output[i] + + # Get the original shape the output to remove the head + # number + _, seq, dim = model.get_tensor_shape(attention.output[0]) + model.set_tensor_shape(merge0.input[i], (1, seq, dim)) + + # Propagate the original datatype to each of the head + # outputs + dtype = model.get_tensor_datatype(attention.output[0]) + model.set_tensor_datatype(merge0.input[i], dtype) + + # Connect the output of the attention replica to the input + # of the new merge operator + attention.output[0] = merge0.input[i] + # Insert the new node into the graph + graph.node.insert(index + i + 1, attention) + # Insert the new slice and merge nodes into the graph + for i, n in enumerate([split0, split1, split2, merge0]): + # Insert the new node into the graph at index offset by + # number of heads + graph.node.insert(index + heads + i + 1, n) + # Remove the original attention operator from the graph + graph.node.remove(node) + # The graph has been modified, needs to be reported back to the + # caller + graph_modified = True + # After rewiring need to re-do the shape annotations + model = model.transform(InferShapes()) # noqa: Shadows model + # By replicating the attention operator, multiple instances refer to the + # same initializer, replace these by a unique one for each head + model = model.transform(GiveUniqueParameterTensors()) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py index 9afd86857..02fa2e0f2 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py @@ -1498,7 +1498,11 @@ def apply(self, model): (WMEM * PE * SIMD) is violated.""" ) # see if we have any following thresholds - consumer = model.find_consumer(mm_output) + consumers = model.find_consumers(mm_output) + # Only a single consumer node can be absorbed. Absorbing one + # branch of a forking matmul would lead to detached nodes + # breaking the graph. + consumer = consumers[0] if len(consumers) == 1 else None if consumer is not None and consumer.op_type == "MultiThreshold": # TODO ensure integer thresholds? # create MVTU (i.e. including activation) @@ -1622,7 +1626,11 @@ def apply(self, model): (WMEM * PE * SIMD) is violated.""" ) # see if we have any following thresholds - consumer = model.find_consumer(mm_output) + consumers = model.find_consumers(mm_output) + # Only a single consumer node can be absorbed. Absorbing one + # branch of a forking matmul would lead to detached nodes + # breaking the graph. + consumer = consumers[0] if len(consumers) == 1 else None if consumer is not None and consumer.op_type == "MultiThreshold": # TODO ensure integer thresholds? # create MVTU (i.e. 
including activation) @@ -1776,7 +1784,11 @@ def apply(self, model): # create node with pe=channels as default pe = channels # see if we have any following thresholds - consumer = model.find_consumer(mm_output) + consumers = model.find_consumers(mm_output) + # Only a single consumer node can be absorbed. Absorbing one + # branch of a forking matmul would lead to detached nodes + # breaking the graph. + consumer = consumers[0] if len(consumers) == 1 else None if consumer is not None and consumer.op_type == "MultiThreshold": # create VVAU (i.e. including activation) mt_output = consumer.output[0] diff --git a/src/finn/transformation/fpgadataflow/replicate_stream.py b/src/finn/transformation/fpgadataflow/replicate_stream.py new file mode 100644 index 000000000..fa7fd6a27 --- /dev/null +++ b/src/finn/transformation/fpgadataflow/replicate_stream.py @@ -0,0 +1,110 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Utility for handling ONNX nodes and tensors +from onnx import TensorProto +from onnx import helper as oh + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# QONNX graph transformation base class +from qonnx.transformation.base import Transformation + +# Transformations running qonnx datatype inference +from qonnx.transformation.infer_datatypes import InferDataTypes + +# Transformation running onnx shape inference +from qonnx.transformation.infer_shapes import InferShapes + + +# Inserts the ReplicateStream hardware operator on tensors with multiple +# consumers +class InferReplicateStream(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # Check each output of the node, as there might be multiple distinct + # outputs, each feeding multiple consumers + for out in node.output: + # Get the list of all consumers of this output tensor + consumers = model.find_consumers(out) + # No need to replicate if there is just one or no consumer + if consumers is None or len(consumers) <= 1: + # Check next output tensor + continue + # Ok, now we have multiple consumers of a single output tensor + # which requires streams to be replicated for HLS synthesis + # Get the shape of the original output tensor + out_shape = model.get_tensor_shape(out) + # Generate a list of unique replicas of the output tensor, one + # for each consumer + replicas = [model.make_new_valueinfo_name() for _ in consumers] + # Create an instance of the ReplicateStream operator for this + # output + replicate_stream = oh.make_node( + # Name of the operator class as it can be found within FINN + "ReplicateStream", + # Execution will try to look up the implementation in the + # package referred to by the domain + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from + # HLSCustomOp + backend="fpgadataflow", + # Connect to the original output tensor + inputs=[out], + # Connect to a unique output tensor for each consumer + outputs=replicas, + # The operator needs to now the number of replicas as an + # attribute + num=len(replicas), + # Number of input elements in the last dimension + num_elems=out_shape[-1], + # 
Number of elements to process in parallel: default fully + # sequential + PE=1, + # Number of inputs to be processed sequentially + num_inputs=out_shape[:-1], + # Infer the datatype from the original output + dtype=model.get_tensor_datatype(out).name, + # Derive a node name based on the original node name + name=f"ReplicateStream_{node.name}" + ) + # Insert the replicate operator into the graph right behind the + # current node + graph.node.insert(index + 1, replicate_stream) + # Need to modify each consumer to have the replica as input + for replica, consumer in zip(replicas, consumers): + # Properly construct a value info object for the new tensor + # replica + model.graph.value_info.append(oh.make_tensor_value_info( + replica, TensorProto.FLOAT, out_shape + )) + # Find the first input of the consumer corresponding to the + # original output tensor + for i, inp in enumerate(consumer.input): + # Check whether this input is the original output + if inp == out: + # Connect this input to the replica of the output + consumer.input[i] = replica + # Break here as multiple inputs to the node might + # connect to the original output, but each gets its + # own replica. + break + # The graph has been modified, needs to be reported back to the + # caller + graph_modified = True + # After rewiring need to re-do the shape annotations + model = model.transform(InferShapes()) # noqa: Shadows model + # As new tensor value infos have been inserted, it is necessary to re-do + # the datatype annotations + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the graph actually + # has been transformed + return model, graph_modified diff --git a/src/finn/transformation/fpgadataflow/set_folding.py b/src/finn/transformation/fpgadataflow/set_folding.py index 9a6526942..7b4f4a4fe 100644 --- a/src/finn/transformation/fpgadataflow/set_folding.py +++ b/src/finn/transformation/fpgadataflow/set_folding.py @@ -120,6 +120,7 @@ def apply(self, model): "GlobalAccPool_hls", "Thresholding_hls", "Thresholding_rtl", + "ReplicateStream_hls", *ELEMENTWISE_BINARY_OPS, "Squeeze_hls", "Unsqueeze_hls", diff --git a/src/finn/transformation/squeeze.py b/src/finn/transformation/squeeze.py new file mode 100644 index 000000000..d9c657b05 --- /dev/null +++ b/src/finn/transformation/squeeze.py @@ -0,0 +1,404 @@ +# QONNX wrapper of ONNX model graphs +# For array handling +import numpy as np + +# Python warning subsystem +import warnings + +# Helper for creating ONNX nodes +from onnx import helper as oh + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# QONNX graph transformation base class +from qonnx.transformation.base import Transformation + +# Transformations running qonnx datatype inference +from qonnx.transformation.infer_datatypes import InferDataTypes + +# Transformation running onnx shape inference +from qonnx.transformation.infer_shapes import InferShapes + +# Reuse node removal and rewiring from qonnx +from qonnx.transformation.remove import remove_node_and_rewire + +# Gets items from protobuf by name +from qonnx.util.basic import get_by_name, remove_by_name + +# Small utility functions for graph transformations +from .util import is_threshold + + +# Squeezes, i.e., removes, dimensions of size 1 +# Note: Use this transformation with great care, it currently serves only the +# purpose of turning the not well-supported 3d data layouts encountered in +# transformer models with batch dimension of size 1 into 2d data layouts where +# the sequence 
dimension is treated as a batch dimension. Everything else is +# not tested, it might break the model or simply lack support for certain node +# op-types. +class Squeeze(Transformation): + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Get the model graph out of the model wrapper object + graph = model.graph + # # Keep track of whether the graph has been modified + # graph_modified = False + # Iterate all nodes in the graph keeping track of the index + for index, node in enumerate(graph.node): + # There should not be any squeeze or unsqueeze operations in the + # graph as these would interfere with this transformation + if node.op_type in {"Squeeze", "Unsqueeze"}: + # Issue a warning to make the user aware of this potential issue + # fmt: off + warnings.warn( + f"Squeezing graph containing {node.op_type}: {node.name}" + ) + # fmt: on + + # Validate slice not slicing along squeezed dimension + if node.op_type == "Slice": + # Axes to slice along is supplied as the 4th input to the node + axes = model.get_initializer(node.input[3]) + # If this is an initializer, there are constant axes to slice + if axes is not None: + # Get the shape of the input, assuming the input from + # upstream to be the 1st input + shape = model.get_tensor_shape(node.input[0]) + # Slice might operate on multiple axes + for axis in axes: + # Axis must not refer to a dimension of size 1 + # fmt: off + assert shape[axis] > 1, \ + f"Slice along dimension to be squeezed: {node.name}" + # fmt: on + + # Need to adapt reshape operations to drop dimensions of size 1 + if node.op_type == "Reshape": + # Second input to the reshape operation is the target shape + shape = model.get_initializer(node.input[1]) + # If the initializer is present, this is a constant shape + # reshape which can be replaced by the squeezed shape + if shape is not None: + # Squeeze the shape by removing all dimensions with size 1 + # fmt: off + new_shape = np.asarray([ + size for size in shape if size != 1 + ]) + # fmt: on + # Reassign the squeezed tensor + model.set_initializer(node.input[1], new_shape) + # Track whether the shape actually changed + if len(new_shape) != len(shape): + # Is never reset back to False during iteration + # graph_modified = True + pass + + # Need to drop dimensions of size 1 from transpose permutation list + if node.op_type == "Transpose": + # Get the (optional) permutation indices of the transpose in + # case it is a multi-axis transpose + perm = get_by_name(node.attribute, "perm") + # If the permutation indices are given, we need to remove all + # dimension of size 1 from these + if perm is not None: + # Convert permutation indices to list of integers + perm = perm.ints + # Get the shape of the input tensor to seek for input + # dimensions of size 1 + shape = model.get_tensor_shape( + # fmt: off + node.input[0], fix_missing_init_shape=True + # fmt: on + ) + # Keep track of new axis enumeration, skipping dimensions of + # size 1 + mapping, new_axis = {}, 0 + # Enumerate the sizes per axis + for axis, size in enumerate(shape): + # Insert mapping from old to new axis + mapping[axis] = new_axis + # Only advance the new axis index for dimensions not to + # be squeezed + new_axis += size > 1 + # Filter and remap the axis enumeration of the permutation + new_perm = [ + # fmt: off + mapping[axis] for axis in perm if shape[axis] > 1 + # fmt: on + ] + # Track whether the permutations actually changed + if len(new_perm) != len(perm) or new_perm != perm: + # # Is never reset back to 
False during iteration + # graph_modified = True + pass + # Remove the permutation attribute before setting the new + # permutation + remove_by_name(node.attribute, "perm") + # Insert new permutation attribute + node.attribute.append(oh.make_attribute("perm", new_perm)) + + # Need to squeeze the number of inputs to multi-head splitting + if node.op_type == "SplitMultiHeads": + # Get number of input feature maps to the merging operation + num_inputs = get_by_name(node.attribute, "num_inputs") # noqa + # Squeeze all dimensions of size 1 + new_num_inputs = [size for size in num_inputs.ints if size != 1] + # Update the attribute by removing and reinserting + remove_by_name(node.attribute, "num_inputs") + node.attribute.append( + # fmt: off + oh.make_attribute("num_inputs", new_num_inputs) + # fmt: on + ) + # Track whether the number of inputs actually changed + if len(new_num_inputs) != len(num_inputs.ints): + # # Is never reset back to False during iteration + # graph_modified = True + pass + + # Need to adjust the index of the split axis by the amount of + # squeezed axes before + if node.op_type == "Split": + # Get the axis attribute from the Split operator + axis = get_by_name(node.attribute, "axis") + # Convert to integer or substitute default 0 according to ONNX + # reference + axis = axis.i if axis is not None else 0 + # Get the shape of the input tensor to the split operation + shape = model.get_tensor_shape(node.input[0]) + # Subtract the number of squeezed, i.e, size=1, axes before axis + axis = axis - sum(size == 1 for size in shape[:axis]) + # Update the attribute by removing and reinserting + remove_by_name(node.attribute, "axis") + node.attribute.append(oh.make_attribute("axis", axis)) + + # Need to set the squeezed output mode of multi-head merging + if node.op_type == "MergeMultiHeads": + # Remove the squeezed attribute + remove_by_name(node.attribute, "squeezed") + # Set squeezed mode attribute + node.attribute.append(oh.make_attribute("squeezed", True)) + # Get number of input feature maps to the merging operation + num_inputs = get_by_name(node.attribute, "num_inputs") # noqa + # Squeeze all dimensions of size 1 + new_num_inputs = [size for size in num_inputs.ints if size != 1] + # Update the attribute by removing and reinserting + remove_by_name(node.attribute, "num_inputs") + node.attribute.append( + # fmt: off + oh.make_attribute("num_inputs", new_num_inputs) + # fmt: on + ) + # Track whether the number of inputs actually changed + if len(new_num_inputs) != len(num_inputs.ints): + # # Is never reset back to False during iteration + # graph_modified = True + pass + + # Need to patch the Im2Col operator when squeezing as this cannot + # operate on other data layouts than 4-dimensional layouts + if node.op_type == "Im2Col": + # Do not squeeze the same operation twice + if get_by_name(node.attribute, "squeezed"): + continue + # Add a new marker attribute to not squeeze this node again + node.attribute.append(oh.make_attribute("squeezed", True)) + # Get the shape of the input tensor to seek for input + # dimensions of size 1 + shape = model.get_tensor_shape( + # fmt: off + node.input[0], fix_missing_init_shape=True + # fmt: on + ) + # Skip if there is no shape + if shape is None: + continue + # Get the axes to be squeezed, i.e., dimensions of size 1 + axes = [dim for dim, size in enumerate(shape) if size == 1] + # To be compatible with ONNX opset >= 13, the axes to + # unsqueeze/squeeze need to be provided as an input + axes_input = model.make_new_valueinfo_name() + # Set the 
axes as an initializer list + model.set_initializer(axes_input, np.asarray(axes)) + # Instantiate an unsqueeze operation adapting from the squeezed + # layout back to the 4-dimensional layout + unsqueeze = oh.make_node( + # Unsqueeze ONNX operators + "Unsqueeze", + # Inherit the inputs from the Im2Col operation + inputs=[node.input[0], axes_input], + # Create a new output tensor + outputs=[model.make_new_valueinfo_name()], + # Specify the axes to unsqueeze + axes=axes, + ) + # Instantiate a squeeze operator adapting from unsqueezed + # 4-dimensional layout back to the squeezed layout + squeeze = oh.make_node( + # Squeeze ONNX operators + "Squeeze", + # Create a new input tensor + inputs=[model.make_new_valueinfo_name(), axes_input], + # Inherit the output tensor from the Im2Col operation + outputs=node.output, + # Specify the axes to squeeze + axes=axes, + ) + # Rewire the input/output to/from the Im2Col operator to connect + # the Unsqueeze/Squeeze wrapper + node.input[0] = unsqueeze.output[0] + node.output[0] = squeeze.input[0] + # Insert the new nodes + graph.node.insert(index, unsqueeze) + graph.node.insert(index, squeeze) + # # The graph has now been modified. This is never reset back to + # # False during iteration + # graph_modified = True + + # Iterate the graph once again to get rid of existing Squeeze/Unsqueeze + # Note: This needs to be done after all other operations to not mess + # with the shape annotations + for index, node in enumerate(graph.node): + # Squeeze and Unsqueeze can be handled the same + if node.op_type in {"Squeeze", "Unsqueeze"}: + # Do not touch the Unsqueeze/Squeeze surrounding the Im2Col + # operation + if "Im2Col" not in [ + n.op_type + for n in [ + *model.find_direct_predecessors(node), + *model.find_direct_successors(node), + ] + ]: + # Remove existing Squeeze/Unsqueeze from the graph as these + # will not have any effect anymore + remove_node_and_rewire(model, node) + + # Get the names of all global input tensors to insert a Squeeze + # operation in front + global_inputs = [inp.name for inp in model.graph.input] + # Insert Squeeze operators at each global input + for inp in global_inputs: + # Get the shape of the tensor to seek for dimensions of size 1 + shape = model.get_tensor_shape( # noqa: Duplicate + inp, fix_missing_init_shape=True + ) + # Skip if there is no shape and skip squeezing 0d or 1d tensors + if shape is None or len(shape) <= 1: + continue + # Get the axes to be squeezed, i.e., dimensions of size 1 + axes = [dim for dim, size in enumerate(shape) if size == 1] + # Te be compatible with ONNX opset >= 13, the axes to + # unsqueeze/squeeze need to be provided as an input + axes_input = model.make_new_valueinfo_name() + # Set the axes as an initializer list + model.set_initializer(axes_input, np.asarray(axes)) + # Instantiate the squeeze operator + squeeze = oh.make_node( + # Squeeze ONNX operators + "Squeeze", + # Inherit the input from the global input and add axes to be + # squeezed to the input list + inputs=[inp, axes_input], + # Create a new output connecting to the graph + outputs=[model.make_new_valueinfo_name()], + # Specify the axes to squeeze + axes=axes, + ) + # Connect the new squeeze operator to all consumers of this + # global input + for consumer in model.find_consumers(inp): + # Find the inputs of the consumer which are the global input + for i, c_inp in enumerate(consumer.input): + # Note: This might happen multiple times? 
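+ # Illustrative example (hypothetical shapes): a global input of shape
+ # (1, 128, 384) gets a Squeeze over axes=[0] inserted, so every consumer
+ # that previously read the (1, 128, 384) input now reads the squeezed
+ # (128, 384) tensor produced by the new Squeeze node instead.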
+ if c_inp == inp: + # Rewire consumer's input directly to the output of + # the squeeze operation + consumer.input[i] = squeeze.output[0] + # Insert the squeeze operator into the model graph + model.graph.node.insert(0, squeeze) + + # Get the names of all global output tensors to insert an Unsqueeze + # operation afterward + global_outputs = [out.name for out in model.graph.output] + # Insert Unsqueeze operators at each global output + for out in global_outputs: + # Get the shape of the tensor to seek for dimensions of size 1 + shape = model.get_tensor_shape( # noqa: Duplicate + out, fix_missing_init_shape=True + ) + # Skip if there is no shape and skip squeezing 0d or 1d tensors + if shape is None or len(shape) <= 1: + continue + # Get the axes to be squeezed, i.e., dimensions of size 1 + axes = [dim for dim, size in enumerate(shape) if size == 1] + # Te be compatible with ONNX opset >= 13, the axes to + # unsqueeze/squeeze need to be provided as an input + axes_input = model.make_new_valueinfo_name() + # Set the axes as an initializer list + model.set_initializer(axes_input, np.asarray(axes)) + # Instantiate the unsqueeze operator + unsqueeze = oh.make_node( + # Unsqueeze ONNX operators + "Unsqueeze", + # Connect to a new intermediate tensor + inputs=[model.make_new_valueinfo_name(), axes_input], + # Connect tho the global output + outputs=[out], + # Specify the axes to unsqueeze + axes=axes, + ) + # Connect the new unsqueeze operator to the producer of this global + # output + producer = model.find_producer(out) + # Find the output of the producer which is the global output + for i, p_out in enumerate(producer.output): + # Note: This might happen multiple times? + if p_out == out: + # Rewire producer's output directly to the input of + # the unsqueeze operation + producer.output[i] = unsqueeze.input[0] + # Insert the unsqueeze operator into the model graph + model.graph.node.insert(0, unsqueeze) + + # Iterate all tensors in the graph keeping track of the index + for index, name in enumerate(model.get_all_tensor_names()): + # Skip the global inputs and outputs + if name in [*global_inputs, *global_outputs]: + # Skip without warning, these are handled by explicit + # Squeeze/Unsqueeze operations + continue + # Skip initializer tensors: Shape inference should actually restore + # these shapes, but for some reason it does not work... + if (init := model.get_initializer(name)) is not None: + # If any of the consumers of this initializer is a + # multi-threshold function, it should not be squeezed as the + # thresholding is quite sensitive to data layouts and does not + # handle broadcasting. + # Note: Not sue whether there can actually be cases wih multiple + # consumers of a threshold tensor, but this should be perfectly + # legal according to standard ONNX. + if any(is_threshold(op) for op in model.find_consumers(name)): + # Skip without warning + continue + # First squeeze the actual data of the initializer tensors + model.set_initializer(name, np.squeeze(init)) + # Now also annotate the squeezed shape, otherwise the following + # shape inference might fail or break the graph + # Note: Deleting the annotation is not sufficient here, it is + # not recovered properly from the tensor data for some reason... 
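+ # Illustrative example (hypothetical shapes): an initializer of shape
+ # (1, 1, 384) is stored back with shape (384,), i.e.
+ #   np.squeeze(np.zeros((1, 1, 384))).shape == (384,)
+ # and the tensor shape annotation below is updated to match.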
+ model.set_tensor_shape(name, np.squeeze(init).shape) + # Continue with the next tensor, skipping the default case below + continue + # Just delete all existing shape annotations to redo them later + model.set_tensor_shape(name, None) + # Re-do shape and data type annotations after potential changes to the + # model graph + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether this transformation + # needs to be repeated + # Note: Never repeat this transformation as it might break when + # inserting multiple Squeeze operators + return model, False diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 902249371..53d665267 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -1938,7 +1938,10 @@ def apply(self, model: ModelWrapper): # noqa if (value := model.get_initializer(a)) is not None: # Do not transpose scalar or effectively scalar # initializers - if not (value.shape is None or all(x == 1 for x in value.shape)): + # fmt: off + if not (value.shape is None or all( + x == 1 for x in value.shape)): + # fmt: on # Transpose the initializer and re-insert into the # model # fmt: off diff --git a/src/finn/transformation/util.py b/src/finn/transformation/util.py index 1e9ae1817..28371ef15 100644 --- a/src/finn/transformation/util.py +++ b/src/finn/transformation/util.py @@ -4,7 +4,6 @@ # Protobuf onnx graph node type from onnx import NodeProto - # QONNX wrapper of ONNX model graphs from qonnx.core.modelwrapper import ModelWrapper diff --git a/tests/fpgadataflow/test_fpgadataflow_attention.py b/tests/fpgadataflow/test_fpgadataflow_attention.py new file mode 100644 index 000000000..a5c7ac0d1 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_attention.py @@ -0,0 +1,537 @@ +# fmt: off +# Disable formatter. This is deliberately formatted to stay within 80 characters +# per line. Black, however, formats some lines going beyond this. + +# Testing framework +import pytest +# Use numpy for python execution / computing the ground truth expected values +import numpy as np + +# Automatically generate init, repr, ... 
for classes containing a lot of +# attributes +from dataclasses import dataclass + +# Utility types and function for creating onnx nodes and graphs +from onnx import TensorProto, helper + +# QONNX datatypes +from qonnx.core.datatype import BaseDataType, DataType, FloatType, IntType +# Wrapper around ONNX model with some graph manipulation utility +from qonnx.core.modelwrapper import ModelWrapper +# Execute onnx model graphs +from qonnx.core.onnx_exec import execute_onnx +# Multithreshold activations +from qonnx.custom_op.general.multithreshold import multithreshold +# Registry of all QONNX CustomOps +from qonnx.custom_op.registry import getCustomOp +# Graph transformation giving unique names to each node in a QONNX model graph +from qonnx.transformation.general import GiveUniqueNodeNames +# QONNX utility for generating random input data for testing and for creating +# models +from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model + +# Softmax function on numpy arrays with overflow handling matching the HLS +# operator +from finn.custom_op.fpgadataflow.attention import softmax +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim + +# FINN graph transformations for preparing simulation (cppsim or rtlsim) +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers + + +# Python/Numpy model of the scaled dot-product attention operator as it is (will +# be...) 
implemented in the attention-hlslib +@dataclass +class MockScaledDotProductAttention: + # Embedding dimension of queries and keys + QKDim: int + # Length of the query sequence + QLen: int + # Embedding dimension of the values + VDim: int + # Length of the key and value sequence + KVLen: int + + # Folding along the embedding dimensions + EmbFold: int + # Folding along the sequence dimensions + SeqFold: int + + # Datatype of query matrix elements + QType: IntType + # Datatype of key matrix elements + KType: IntType + # Datatype of value matrix elements + VType: IntType + # Datatype of mask matrix elements + MType: IntType + # Datatype of attention weights elements + AType: IntType + # Datatype of output elements + OType: IntType + + # Datatype of accumulator elements of the Query x Key multiplication + AccQKMatMul: IntType = DataType["UINT4"] + # Datatype of output elements of the Query x Key multiplication + OutQKMatMul: IntType = DataType["UINT4"] + # Activation function type of the Query x Key multiplication + ActQKMatMul: str = "thresholds" + # Output bias to be applied to the thresholding activation following + # the Query x Key multiplication + BiasActQKMatMul: float = 0.0 + + # Datatype of accumulator elements of the Attention x Value + # multiplication + AccAVMatMul: IntType = DataType["UINT4"] + # Datatype of output elements of the Attention x Value + # multiplication + OutAVMatMul: IntType = DataType["UINT4"] + # Activation function type of the Attention x Value multiplication + ActAVMatMul: str = "thresholds" + # Output bias to be applied to the thresholding activation following + # the Attention x Value multiplication + BiasActAVMatMul: float = 0.0 + + # Scale factor preceding the softmax normalization to dequantize the + # input + DequantSoftmax: float = 1.0 + # Datatype of softmax normalization before applying activation or + # type cast. THis is called Acc to stick to the naming scheme of the + # MatMul operators before. 
+ # Note: Currently this is ALWAYS floats + AccASoftmax: FloatType = DataType["FLOAT32"] + # Activation function type of the softmax normalization of the + # attention weights + ActASoftmax: str = "thresholds" + # Output bias to be applied to the thresholding activation following + # the softmax normalization of the attention weights + BiasActASoftmax: float = 0.0 + + # Initializes those parameters which depend on the initial configuration, + # which is set by the generated __init__ + def __post_init__(self): + # The last matmul output type must match with the specified output type + assert self.OType == self.OutAVMatMul + + # Converts QONNX datatypes to their name (as a string) + def maybe_name(value): + # All QONNX datatypes are instances of the BaseDataType + if isinstance(value, BaseDataType): + # Convert to the name by referring to the datatypes name + # attribute + return value.name + # Everything else is just assumed to be in the right format + return value + + # Convert all node attributes which are registered so far to a + # dictionary matching the CustomOp format, where DataTypes are converted + # to string representations of their names + self.node_attrs = { + key: maybe_name(value) for key, value in self.__dict__.items() + } + + # Dummy float type to use the threshold generator with flot inputs + @dataclass + class DummyFloat32: + # Minimum and maximum of the represented float range + _min: float + _max: float + + # Getter for minimum of the represented range + def min(self): + return self._min + + # Getter for maximum of the represented range + def max(self): + return self._max + + # Generates thresholds representing a quantized identity function + # mapping input datatype (idt) to output datatype (odt) + def make_identity_thresholds(idt, odt, repeat=1): + # The number of thresholds is determined by the range of the output + # datatype + steps = odt.get_num_possible_values() - 1 + # The scale, or step size, is determined by the ratio between input + # and output range + scale = (idt.max() - idt.min()) / (odt.max() - odt.min()) + # Generate step thresholds covering the input range and repeat for + # multiple matrix rows/cols + return np.array( + repeat * [[scale * i + idt.min() for i in range(steps)]] + ).astype(dtype=np.float32) + + # Generate identity function thresholds mapping the query-key matmul + # accumulator type to the specified output type + self.qk_thresholds = np.round(make_identity_thresholds( + # Note: Repeat for all KVLen cols of the attention weights + self.AccQKMatMul, self.OutQKMatMul, self.KVLen + )) + + # Generate identity function thresholds mapping the float attention + # weights to the specified integer type + self.a_thresholds = make_identity_thresholds( + # Note: Repeat for all KVLen cols of the attention weights + DummyFloat32(0.0, 1.0), self.AType, self.KVLen + ) + + # Generate identity function thresholds mapping the attention-value + # matmul accumulator type to the specified output type + self.av_thresholds = np.round(make_identity_thresholds( + # Note: Repeat for all VDim cols of the output + self.AccAVMatMul, self.OutAVMatMul, self.VDim + )) + + # Computes the query-key matmul with activation function simulating + # quantization via thresholding + def qk_matmul(self, query, key): + return multithreshold(query @ key.T, self.qk_thresholds) + + # Computes the softmax normalization of attention weights with activation + # function simulating quantization via thresholding + def softmax(self, attention): + # Input and output scale factors for float 
<-> int conversion + iscale = self.DequantSoftmax + # Scale the inputs, normalize using softmax and activate via thresholds + return multithreshold( + softmax(iscale * attention, axis=1), self.a_thresholds + ) + + # Computes the attention-value matmul with activation function simulating + # quantization via thresholding + def av_matmul(self, attention, value): + return multithreshold(attention @ value, self.av_thresholds) + + # Computes scaled dot-product attention + def __call__(self, query, key, value): + return self.av_matmul(self.softmax(self.qk_matmul(query, key)), value) + + # Generates random sample inputs + def make_rand_input(self): + # Sample random query, key and value matrices with types and shapes + # configured as attributes + query = gen_finn_dt_tensor(self.QType, (self.QLen, self.QKDim)) + key = gen_finn_dt_tensor(self.KType, (self.KVLen, self.QKDim)) + value = gen_finn_dt_tensor(self.VType, (self.KVLen, self.VDim)) + # Return query, key, value tensors with integers represented as floats + return query, key, value + + # Creates a QONNX ModelWrapper matching the attention configuration + def make_modelwrapper(self): + # Named threshold inputs + # Note: Order matters... + thresholds = [ + "thresholds_qk_matmul", + "thresholds_a_softmax", + "thresholds_av_matmul", + ] + # Build up the node attribute dictionary + kwargs = { + # Refer to this operator type by its name + "op_type": "ScaledDotProductAttention", + # Execution will try to look up the implementation in the package + # referred to by the domain + "domain": "finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from HLSCustomOp + "backend": "fpgadataflow", + # Named inputs and activation thresholds + # TODO: Currently no masking support + "inputs": ["Q", "K", "V", *thresholds], + # Named model output + "outputs": ["O"], + # TODO: Currently no masking support + "mask_mode": "none" + } + + # Insert attributes into a new ONNX graph node + node = helper.make_node(**kwargs, **self.node_attrs) + + # Create random sample inputs for shape inference + q, k, v = self.make_rand_input() + # Infer the output shape from the input shapes + o_shape = (q.shape[0], v.shape[1]) + # Create onnx value info of all inputs and outputs assuming float + # datatypes + q_info = helper.make_tensor_value_info("Q", TensorProto.FLOAT, q.shape) + k_info = helper.make_tensor_value_info("K", TensorProto.FLOAT, k.shape) + v_info = helper.make_tensor_value_info("V", TensorProto.FLOAT, v.shape) + o_info = helper.make_tensor_value_info("O", TensorProto.FLOAT, o_shape) + # Collect input and output nodes in order + inputs, outputs = [q_info, k_info, v_info], [o_info] + + # Create a graph connecting the scaled dot-product attention node to the + # input and output nodes + graph = helper.make_graph( + [node], inputs=inputs, outputs=outputs, name='attention_graph' + ) + # Wrap the ONNX graph in QONNX model wrapper + model = ModelWrapper(qonnx_make_model( + graph, producer_name='attention-model' + )) + + # Add datatype annotations to all input tensors + for tensor_name in kwargs["inputs"]: + # Only annotate if a datatype is specified + if f"{tensor_name}Type" in kwargs: + # Update the datatype annotation + model.set_tensor_datatype( + tensor_name, DataType[kwargs[f"{tensor_name}Type"]] + ) + + # Add datatype annotations to all output tensors + for tensor_name in kwargs["outputs"]: + # Only annotate if a datatype is specified + if f"{tensor_name}Type" in kwargs: + # Update the datatype annotation + model.set_tensor_datatype( + 
tensor_name, DataType[kwargs[f"{tensor_name}Type"]] + ) + + # Set the threshold tensors as model initializer attributes of the + # appropriate type + # TODO: Uses the actual input type to the multithreshold function as + # datatype. Somehow the mvau tests always use INT32, why? + model.set_tensor_datatype("thresholds_qk_matmul", self.AccQKMatMul) + model.set_initializer("thresholds_qk_matmul", self.qk_thresholds) + + model.set_tensor_datatype("thresholds_a_softmax", DataType["FLOAT32"]) + model.set_initializer("thresholds_a_softmax", self.a_thresholds) + + model.set_tensor_datatype("thresholds_av_matmul", self.AccAVMatMul) + model.set_initializer("thresholds_av_matmul", self.av_thresholds) + + # Return the constructed qonnx model wrapper + return model + + +# Size of query and key embedding dimension +@pytest.mark.parametrize("QKDim", [4, 8, 16]) # noqa: Duplicated code fragment +# Size of value embedding dimension +@pytest.mark.parametrize("VDim", [4, 8, 16]) +# Length of key and value sequences +@pytest.mark.parametrize("KVLen", [16, 24]) +# Length of query sequence +@pytest.mark.parametrize("QLen", [16, 24]) +# Folding along the embedding dimensions +@pytest.mark.parametrize("EmbFold", [2]) +# Folding along the sequence dimensions +@pytest.mark.parametrize("SeqFold", [8]) +# Datatypes of queries, keys and values, mask and output +@pytest.mark.parametrize("QType", [DataType["UINT8"]]) +@pytest.mark.parametrize("KType", [DataType["UINT8"]]) +@pytest.mark.parametrize("VType", [DataType["UINT8"]]) +@pytest.mark.parametrize("MType", [DataType["UINT8"]]) +@pytest.mark.parametrize("AType", [DataType["UINT8"]]) +@pytest.mark.parametrize("OType", [DataType["UINT8"]]) +# Different modes to provide a mask +@pytest.mark.parametrize("mask", ["none"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +# Tests cpp simulation of single scaled dot-product attention head +def test_attention_cppsim( + # Shape configuration + QKDim, # noqa: "Argument should be lowercase" + VDim, # noqa + KVLen, # noqa + QLen, # noqa + # Folding configuration + EmbFold, # noqa + SeqFold, # noqa + # Type configuration + QType, # noqa + KType, # noqa + VType, # noqa + MType, # noqa + AType, # noqa + OType, # noqa + # Type of mask to use: either 'none', 'input', or 'causal' + mask +): + # Attention instance simulating in python and generating a matching + # QONNX configuration + attention = MockScaledDotProductAttention( # noqa: Duplicated code fragment + # Shape configuration + QKDim=QKDim, + QLen=QLen, + VDim=VDim, + KVLen=KVLen, + # Folding configuration + EmbFold=EmbFold, + SeqFold=SeqFold, + # Type configuration + QType=QType, + KType=KType, + VType=VType, + MType=MType, + AType=AType, + OType=OType, + # Accumulator type configuration + AccQKMatMul=DataType["UINT32"], + OutQKMatMul=DataType["UINT8"], + AccAVMatMul=DataType["UINT32"], + OutAVMatMul=OType, + # Dequantizer scale, factor to convert the whole UINT8 range to floats + # in range 0.0 to 1.0 + DequantSoftmax=1.0 / (DataType["UINT8"].get_num_possible_values() - 1) + ) + + # Create a QONNX model wrapper for testing + model = attention.make_modelwrapper() + # Sample some random inputs + q, k, v = attention.make_rand_input() + # Prepare execution context + context = { + "Q": q, "K": k, "V": v, "mask": mask + } + + # Mark all nodes to be specialized as HLS backend implementations + for node in model.graph.node: + # Get the CustomOp instance of the node to get access to the node + 
# attributes + inst = getCustomOp(node) + # Note: only HLS-based layers execute C++ Simulation + inst.set_nodeattr("preferred_impl_style", "hls") + # Turn all HWCustomOp layers into HLS specializations + model = model.transform(SpecializeLayers("xczu7ev-ffvc1156-2-e")) + + # Set model execution mode to C++ simulation + model = model.transform(SetExecMode("cppsim")) + # Generates the C++ source and compiles the C++ simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + + # Compute ground-truth output in software + o_expected = attention(q, k, v) # noqa: Duplicated code fragment + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["O"] + + # Log outputs for debugging + print(f"{o_expected}\n", file=open('o_expected_cppsim.txt', 'w')) + print(f"{o_produced}\n", file=open('o_produced_cppsim.txt', 'w')) + # Save the ONNX model graph for debugging + model.save("attention-cppsim.onnx") + + # Test whether the expectation and the onnx model output match + assert np.allclose(o_produced, o_expected), "cppsim exec failed" + + +# Size of query and key embedding dimension +@pytest.mark.parametrize("QKDim", [4]) # noqa: Duplicated code fragment +# Size of value embedding dimension +@pytest.mark.parametrize("VDim", [4]) +# Length of key and value sequences +@pytest.mark.parametrize("KVLen", [16]) +# Length of query sequence +@pytest.mark.parametrize("QLen", [16]) +# Folding along the embedding dimensions +@pytest.mark.parametrize("EmbFold", [2]) +# Folding along the sequence dimensions +@pytest.mark.parametrize("SeqFold", [8]) +# Datatypes of queries, keys and values, mask and output +@pytest.mark.parametrize("QType", [DataType["UINT8"]]) +@pytest.mark.parametrize("KType", [DataType["UINT8"]]) +@pytest.mark.parametrize("VType", [DataType["UINT8"]]) +@pytest.mark.parametrize("MType", [DataType["UINT8"]]) +@pytest.mark.parametrize("AType", [DataType["UINT8"]]) +@pytest.mark.parametrize("OType", [DataType["UINT8"]]) +# Different modes to provide a mask +@pytest.mark.parametrize("mask", ["none"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +# Tests rtl simulation of single scaled dot-product attention head +def test_attention_rtlsim( + # Shape configuration + QKDim, # noqa: "Argument should be lowercase" + VDim, # noqa + KVLen, # noqa + QLen, # noqa + # Folding configuration + EmbFold, # noqa + SeqFold, # noqa + # Type configuration + QType, # noqa + KType, # noqa + VType, # noqa + MType, # noqa + AType, # noqa + OType, # noqa + # Type of mask to use: either 'none', 'input', or 'causal' + mask +): + # Attention instance simulating in python and generating a matching + # QONNX configuration + attention = MockScaledDotProductAttention( # noqa: Duplicated code fragment + # Shape configuration + QKDim=QKDim, + QLen=QLen, + VDim=VDim, + KVLen=KVLen, + # Folding configuration + EmbFold=EmbFold, + SeqFold=SeqFold, + # Type configuration + QType=QType, + KType=KType, + VType=VType, + MType=MType, + AType=AType, + OType=OType, + # Accumulator type configuration + AccQKMatMul=DataType["UINT32"], + OutQKMatMul=DataType["UINT8"], + AccAVMatMul=DataType["UINT32"], + OutAVMatMul=OType, + # Dequantizer scale, factor to convert the whole UINT8 range to floats + # in range 0.0 to 1.0 + DequantSoftmax=1.0 / (DataType["UINT8"].get_num_possible_values() - 1) + ) + + # Create a QONNX model wrapper for 
testing + model = attention.make_modelwrapper() + # Sample some random inputs + q, k, v = attention.make_rand_input() + # Prepare execution context + context = { + "Q": q, "K": k, "V": v, "mask": mask + } + + # Mark all nodes to be specialized as HLS backend implementations + for node in model.graph.node: + # Get the CustomOp instance of the node to get access to the node + # attributes + inst = getCustomOp(node) + # Note: only HLS-based layers execute C++ Simulation + inst.set_nodeattr("preferred_impl_style", "hls") + # Turn all HWCustomOp layers into HLS specializations + model = model.transform(SpecializeLayers("xczu7ev-ffvc1156-2-e")) + + # Set model execution mode to RTL simulation + model = model.transform(SetExecMode("rtlsim")) + # Generates the C++ source and compiles the RTL simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xczu7ev-ffvc1156-2-e", 10)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + + # Compute ground-truth output in software + o_expected = attention(q, k, v) # noqa: Duplicated code fragment + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["O"] + + # Log outputs for debugging + print(f"{o_expected}\n", file=open('o_expected_rtlsim.txt', 'w')) + print(f"{o_produced}\n", file=open('o_produced_rtlsim.txt', 'w')) + # Save the ONNX model graph for debugging + model.save("attention-rtlsim.onnx") + + # Test whether the expectation and the onnx model output match + assert np.allclose(o_produced, o_expected), "rtlsim exec failed" diff --git a/tests/fpgadataflow/test_fpgadataflow_attention_heads.py b/tests/fpgadataflow/test_fpgadataflow_attention_heads.py new file mode 100644 index 000000000..097d4a63f --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_attention_heads.py @@ -0,0 +1,412 @@ +# Testing framework +import pytest + +# Use numpy for python execution / computing the ground truth expected values +import numpy as np + +# Protobuf onnx graph node type +from onnx import TensorProto +# Helper for creating ONNX nodes +from onnx import helper as oh + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper +# Execute onnx model graphs +from qonnx.core.onnx_exec import execute_onnx +# Registry of all QONNX CustomOps +from qonnx.custom_op.registry import getCustomOp +# Utility for wrapping onnx graphs and generating tensor of FINN datatypes +from qonnx.util.basic import qonnx_make_model, gen_finn_dt_tensor + +# Graph transformation giving unique names to each node in a QONNX model graph +from qonnx.transformation.general import GiveUniqueNodeNames + +# FINN graph transformations for preparing simulation (cppsim or rtlsim) +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers + + +# Specializes all nodes to be implemented as HLS backend +def specialize_hls(model: ModelWrapper): + # Mark all nodes to be specialized as HLS backend implementations + for node in model.graph.node: # noqa: Duplicate test 
setup code + # Get the CustomOp instance of the node to get access to the node + # attributes + inst = getCustomOp(node) + # Note: only HLS-based layers execute C++ Simulation + inst.set_nodeattr("preferred_impl_style", "hls") + # Turn all HWCustomOp layers into HLS specializations + return model.transform(SpecializeLayers("xczu7ev-ffvc1156-2-e")) + + +# Creates a model executing multi-head splitting +def mock_split_multi_heads(seq, dim, heads, dtype): + # Create a node representing the attention heads splitting operation + node = oh.make_node( + # Operator type from the name of the fpgadataflow hlscustomop + op_type="SplitMultiHeads", + # Specify the domain, i.e., the package to look for the custom operator + # implementation + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from HLSCustomOp + backend="fpgadataflow", + # Just one input + inputs=["inp"], + # Enumerate the outputs + outputs=[f"out{i}" for i in range(heads)], + # Number of attention heads to split the input into + heads=heads, + # Packed output is not supported for now + packed=False, + # Datatype of inputs and outputs + dtype=dtype, + # Number of input elements, i.e., embedding dimension + num_elems=dim, + # Number of embeddings in the whole input sequence / feature map + num_inputs=[seq] + ) + # Construct the input tensor value info + inp = oh.make_tensor_value_info("inp", TensorProto.FLOAT, [seq, dim]) + # Construct output tensor value infos + out = [oh.make_tensor_value_info( + f"out{i}", TensorProto.FLOAT, [seq, dim // heads]) for i in range(heads) + ] + # Create a graph connecting the node to the inputs and outputs + graph = oh.make_graph([node], inputs=[inp], outputs=out, name="split") + # Wrap the ONNX graph in QONNX model wrapper + model = ModelWrapper(qonnx_make_model(graph, producer_name='split')) + + # Add datatype annotation to the value info of input tensor + model.set_tensor_datatype("inp", DataType[dtype]) + # Add datatype annotation to the value info of each output tensor + for out in (f"out{i}" for i in range(heads)): + model.set_tensor_datatype(out, DataType[dtype]) + + # Return the wrapped onnx model + return model + + +# Creates a model executing multi-head merging +def mock_merge_multi_heads(seq, dim, heads, dtype): + # Create a node representing the attention heads merging operation + node = oh.make_node( + # Operator type from the name of the fpgadataflow hlscustomop + op_type="MergeMultiHeads", + # Specify the domain, i.e., the package to look for the custom operator + # implementation + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from HLSCustomOp + backend="fpgadataflow", + # Enumerate the inputs + inputs=[f"inp{i}" for i in range(heads)], + # Just one output + outputs=["out"], + # Number of attention heads to merge the inputs from + heads=heads, + # Packed output is not supported for now + packed=False, + # Datatype of inputs and outputs + dtype=dtype, + # Number of input elements, i.e., embedding dimension + num_elems=dim // heads, + # Number of embeddings in the whole input sequence / feature map + num_inputs=[seq], + # Assume squeezed output by default + squeezed=True + ) + # Construct input tensor value infos + inp = [oh.make_tensor_value_info( + f"inp{i}", TensorProto.FLOAT, [seq, dim // heads]) for i in range(heads) + ] + # Construct the output tensor value info + out = oh.make_tensor_value_info("out", TensorProto.FLOAT, [seq, dim]) + # Create a graph connecting the node to the inputs and outputs + graph =
oh.make_graph([node], inputs=inp, outputs=[out], name="merge") + # Wrap the ONNX graph in QONNX model wrapper + model = ModelWrapper(qonnx_make_model(graph, producer_name='merge')) + + # Add datatype annotation to the value info of each input tensor + for inp in (f"inp{i}" for i in range(heads)): + model.set_tensor_datatype(inp, DataType[dtype]) + # Add datatype annotation to the value info of output tensor + model.set_tensor_datatype("out", DataType[dtype]) + + # Return the wrapped onnx model + return model + + +# Sequence length to simulate, i.e., number of individual inputs to be split +@pytest.mark.parametrize("seq", [64]) +# Number of input elements to be split, i.e., size of embedding dimension +@pytest.mark.parametrize("dim", [32]) +# Number of heads to split the input into +@pytest.mark.parametrize("heads", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["UINT8"]) +# This is a fpgadataflow type of test which does not require vivado +@pytest.mark.fpgadataflow +# Tests splitting of tensors to multiple attention heads using python mode +# execution +# Note: No actual attention operation is performed +def test_attention_heads_split_python(seq, dim, heads, dtype): + # Make dummy model for testing + model = mock_split_multi_heads(seq, dim, heads, dtype) + + # Prepare the execution context + context = {"inp": gen_finn_dt_tensor(DataType[dtype], (seq, dim))} + + # Set model execution mode to python simulation + model = model.transform(SetExecMode("python")) + model = model.transform(GiveUniqueNodeNames()) + + # Compute ground-truth output in software + o_expected = np.split(context["inp"], heads, axis=-1) # noqa: Duplicate + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context) + + # Validate each output separately + for i, out in enumerate((f"out{i}" for i in range(heads))): + # Compare expected (retrieved by index) to produced (retrieve by key) + assert (o_produced[out] == o_expected[i]).all() # noqa: "all" warning + + +# Sequence length to simulate, i.e., number of individual inputs to be split +@pytest.mark.parametrize("seq", [64]) +# Number of input elements to be split, i.e., size of embedding dimension +@pytest.mark.parametrize("dim", [32]) +# Number of heads to split the input into +@pytest.mark.parametrize("heads", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["UINT8"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +# Tests splitting of tensors to multiple attention heads using C++ mode +# execution +# Note: No actual attention operation is performed +def test_attention_heads_split_cppsim(seq, dim, heads, dtype): + # Make dummy model for testing + model = mock_split_multi_heads(seq, dim, heads, dtype) + + # Prepare the execution context + context = {"inp": gen_finn_dt_tensor(DataType[dtype], (seq, dim))} + + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + # Set model execution mode to C++ simulation + model = model.transform(SetExecMode("cppsim")) + # Generates the C++ source and compiles the C++ simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + + # Compute ground-truth output in software + o_expected = np.split(context["inp"], heads, axis=-1) # noqa: Duplicate + # Execute the onnx model to collect the result + o_produced = execute_onnx(model,
context) + + # Validate each output separately + for i, out in enumerate((f"out{i}" for i in range(heads))): + # Compare expected (retrieved by index) to produced (retrieve by key) + assert (o_produced[out] == o_expected[i]).all() # noqa: "all" warning + + +# Sequence length to simulate, i.e., number of individual inputs to be split +@pytest.mark.parametrize("seq", [64]) +# Number of input elements to be split, i.e., size of embedding dimension +@pytest.mark.parametrize("dim", [32]) +# Number of heads to split the input into +@pytest.mark.parametrize("heads", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["UINT8"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +# Tests splitting of tensors to multiple attention heads using RTL mode +# execution +# Note: No actual attention operation is performed +def test_attention_heads_split_rtlsim(seq, dim, heads, dtype): + # Make dummy model for testing + model = mock_split_multi_heads(seq, dim, heads, dtype) + + # Prepare the execution context + context = {"inp": gen_finn_dt_tensor(DataType[dtype], (seq, dim))} + + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + # Set model execution mode to RTL simulation + model = model.transform(SetExecMode("rtlsim")) + # Generates the HLS IP, runs HLS synthesis and prepares the RTL simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xczu7ev-ffvc1156-2-e", 10)) # noqa + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + + # Compute ground-truth output in software + o_expected = np.split(context["inp"], heads, axis=-1) # noqa: Duplicate + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context) + + # Validate each output separately + for i, out in enumerate((f"out{i}" for i in range(heads))): + # Compare expected (retrieved by index) to produced (retrieve by key) + assert (o_produced[out] == o_expected[i]).all() # noqa: "all" warning + + +# Sequence length to simulate, i.e., number of individual inputs to be split +@pytest.mark.parametrize("seq", [64]) # noqa: Duplicate, test setup +# Number of input elements to be split, i.e., size of embedding dimension +@pytest.mark.parametrize("dim", [32]) +# Number of heads to split the input into +@pytest.mark.parametrize("heads", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["UINT8"]) +# This is a fpgadataflow type of test which does not require vivado +@pytest.mark.fpgadataflow +# Tests merging of tensors from multiple attention heads using python mode +# execution +# Note: No actual attention operation is performed +def test_attention_heads_merge_python(seq, dim, heads, dtype): + # Make dummy model for testing + model = mock_merge_multi_heads(seq, dim, heads, dtype) + + # Create a random input tensor of shape and datatype + def make_inp_tensor(): + return gen_finn_dt_tensor(DataType[dtype], (seq, dim // heads)) + + # Prepare the execution context + context = { + f"inp{i}": make_inp_tensor() for i in range(heads) + } + + # Set model execution mode to Python simulation + model = model.transform(SetExecMode("python")) + model = model.transform(GiveUniqueNodeNames()) + + # Compute ground-truth output in software + o_expected = np.concatenate( +
[context[f"inp{i}"] for i in range(heads)], axis=-1 + ) + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare expected to produced output + assert (o_produced == o_expected).all() # noqa: Unresolved "all" warning + + +# Sequence length to simulate, i.e., number of individual inputs to be split +@pytest.mark.parametrize("seq", [64]) # noqa: Duplicate, test setup +# Number of input elements to be split, i.e., size of embedding dimension +@pytest.mark.parametrize("dim", [32]) +# Number of heads to split the input into +@pytest.mark.parametrize("heads", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["UINT8"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +# Tests merging of tensors from multiple attention heads using C++ mode +# execution +# Note: No actual attention operation is performed +def test_attention_heads_merge_cppsim(seq, dim, heads, dtype): + # Make dummy model for testing + model = mock_merge_multi_heads(seq, dim, heads, dtype) + + # Create a random input tensor of shape and datatype + def make_inp_tensor(): + return gen_finn_dt_tensor(DataType[dtype], (seq, dim // heads)) + + # Prepare the execution context + context = { + f"inp{i}": make_inp_tensor() for i in range(heads) + } + + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + # Set model execution mode to C++ simulation + model = model.transform(SetExecMode("cppsim")) + # Generates the C++ source and compiles the C++ simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + + # Compute ground-truth output in software + o_expected = np.concatenate( + [context[f"inp{i}"] for i in range(heads)], axis=-1 + ) + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare expected to produced output + assert (o_produced == o_expected).all() # noqa: Unresolved "all" warning + + +# Sequence length to simulate, i.e., number of individual inputs to be split +@pytest.mark.parametrize("seq", [64]) # noqa: Duplicate, test setup +# Number of input elements to be split, i.e., size of embedding dimension +@pytest.mark.parametrize("dim", [32]) +# Number of heads to split the input into +@pytest.mark.parametrize("heads", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["UINT8"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +# Tests merging of tensors from multiple attention heads using RTL mode +# execution +# Note: No actual attention operation is performed +def test_attention_heads_merge_rtlsim(seq, dim, heads, dtype): + # Make dummy model for testing + model = mock_merge_multi_heads(seq, dim, heads, dtype) + + # Create a random input tensor of shape and datatype + def make_inp_tensor(): + return gen_finn_dt_tensor(DataType[dtype], (seq, dim // heads)) + + # Prepare the execution context + context = { + f"inp{i}": make_inp_tensor() for i in range(heads) + } + + # Specializes all
nodes to be implemented as HLS backend + model = specialize_hls(model) + # Set model execution mode to RTL simulation + model = model.transform(SetExecMode("rtlsim")) + # Generates the C++ source and compiles the RTL simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xczu7ev-ffvc1156-2-e", 10)) # noqa + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + + # Compute ground-truth output in software + o_expected = np.concatenate( + [context[f"inp{i}"] for i in range(heads)], axis=-1 + ) + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context)["out"] + + # Compare expected to produced output + assert (o_produced == o_expected).all() # noqa: Unresolved "all" warning diff --git a/tests/fpgadataflow/test_fpgadataflow_replicate_stream.py b/tests/fpgadataflow/test_fpgadataflow_replicate_stream.py new file mode 100644 index 000000000..643a46461 --- /dev/null +++ b/tests/fpgadataflow/test_fpgadataflow_replicate_stream.py @@ -0,0 +1,227 @@ +# Testing framework +import pytest + +# Protobuf onnx graph node type +from onnx import TensorProto +# Helper for creating ONNX nodes +from onnx import helper as oh + +# QONNX/FINN datatypes +from qonnx.core.datatype import DataType +# QONNX wrapper to ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper +# Execute onnx model graphs +from qonnx.core.onnx_exec import execute_onnx +# Registry of all QONNX CustomOps +from qonnx.custom_op.registry import getCustomOp +# Utility for wrapping onnx graphs and generating tensor of FINN datatypes +from qonnx.util.basic import qonnx_make_model, gen_finn_dt_tensor + +# Graph transformation giving unique names to each node in a QONNX model graph +from qonnx.transformation.general import GiveUniqueNodeNames + +# FINN graph transformations for preparing simulation (cppsim or rtlsim) +from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim +from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim +from finn.transformation.fpgadataflow.prepare_ip import PrepareIP +from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP +from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim +from finn.transformation.fpgadataflow.specialize_layers import SpecializeLayers + + +# Specializes all nodes to be implemented as HLS backend +def specialize_hls(model: ModelWrapper): + # Mark all nodes to be specialized as HLS backend implementations + for node in model.graph.node: # noqa: Duplicate test setup code + # Get the CustomOp instance of the node to get access to the node + # attributes + inst = getCustomOp(node) + # Note: only HLS-based layers execute C++ Simulation + inst.set_nodeattr("preferred_impl_style", "hls") + # Turn all HWCustomOp layers into HLS specializations + return model.transform(SpecializeLayers("xczu7ev-ffvc1156-2-e")) + + +# Creates a model executing stream replication +def mock_replicate_streams(num_inputs, num_elems, pe, num, dtype): + # Create a node representing the stream replication operation + node = oh.make_node( + # Operator type from the name of the fpgadataflow hlscustomop + op_type="ReplicateStream", + # Specify the domain, i.e., the package to look for the custom operator + # implementation + domain="finn.custom_op.fpgadataflow", + # Execution backend: Required attribute inherited from HLSCustomOp + backend="fpgadataflow", + # Just one input + 
inputs=["inp"], + # Enumerate the outputs + outputs=[f"out{i}" for i in range(num)], + # Number of replicas to produce + num=num, + # Datatype of inputs and outputs + dtype=dtype, + # Number of input elements in the last dimension + num_elems=num_elems, + # Number of elements to process in parallel + PE=pe, + # Number of inputs to be processed sequentially + num_inputs=num_inputs + ) + # Shape of the input and each output + shape = [*num_inputs, num_elems] + # Construct the input tensor value info + inp = oh.make_tensor_value_info("inp", TensorProto.FLOAT, shape) + # Construct output tensor value infos + out = [oh.make_tensor_value_info( + f"out{i}", TensorProto.FLOAT, shape) for i in range(num) + ] + # Create a graph connecting the node to the inputs and outputs + graph = oh.make_graph([node], inputs=[inp], outputs=out, name="replicate") + # Wrap the ONNX graph in QONNX model wrapper + model = ModelWrapper(qonnx_make_model(graph, producer_name='replicate')) + + # Add datatype annotation to the value info of input tensor + model.set_tensor_datatype("inp", DataType[dtype]) + # Add datatype annotation to the value infor of each output tensor + for out in (f"out{i}" for i in range(num)): + model.set_tensor_datatype(out, DataType[dtype]) + + # Return the wrapped onnx model + return model + + +# Number of inputs to be processed sequentially +@pytest.mark.parametrize( # noqa Duplicate + "num_inputs", [[8], [1, 8], [2, 8], [2, 2, 8]] +) +# Number of input elements in the last dimension +@pytest.mark.parametrize("num_elems", [32]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4, 8]) +# Number of replicas to produce +@pytest.mark.parametrize("num", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["FLOAT32", "UINT8", "INT4"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +# Tests replicating of tensors/streams to multiple outputs using python mode +# execution +def test_replicate_stream_python(num_inputs, num_elems, pe, num, dtype): + # Make dummy model for testing + model = mock_replicate_streams(num_inputs, num_elems, pe, num, dtype) + + # Prepare the execution context + context = { + "inp": gen_finn_dt_tensor(DataType[dtype], (*num_inputs, num_elems)) + } + + # Set model execution mode to python simulation + model = model.transform(SetExecMode("python")) + model = model.transform(GiveUniqueNodeNames()) + + # Compute ground-truth output in software + o_expected = [context["inp"] for _ in range(num)] # noqa: Duplicate + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context) + + # Validate each output separately + for i, out in enumerate((f"out{i}" for i in range(num))): + # Compare expected (retrieved by index) to produced (retrieve by key) + assert (o_produced[out] == o_expected[i]).all() # noqa: "all" warning + + +# Number of inputs to be processed sequentially +@pytest.mark.parametrize( # noqa Duplicate + "num_inputs", [[8], [1, 8], [2, 8], [2, 2, 8]] +) +# Number of input elements in the last dimension +@pytest.mark.parametrize("num_elems", [32]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4, 8]) +# Number of replicas to produce +@pytest.mark.parametrize("num", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["FLOAT32", "UINT8", "INT4"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado 
+# Tests replicating of tensors/streams to multiple outputs using C++ mode +# execution +def test_replicate_stream_cppsim(num_inputs, num_elems, pe, num, dtype): + # Make dummy model for testing + model = mock_replicate_streams(num_inputs, num_elems, pe, num, dtype) + + # Prepare the execution context + context = { + "inp": gen_finn_dt_tensor(DataType[dtype], (*num_inputs, num_elems)) + } + + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + # Set model execution mode to C++ simulation + model = model.transform(SetExecMode("cppsim")) + # Generates the C++ source and compiles the C++ simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + + # Compute ground-truth output in software + o_expected = [context["inp"] for _ in range(num)] # noqa: Duplicate + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context) + + # Validate each output separately + for i, out in enumerate((f"out{i}" for i in range(num))): + # Compare expected (retrieved by index) to produced (retrieve by key) + assert (o_produced[out] == o_expected[i]).all() # noqa: "all" warning + + +# Number of inputs to be processed sequentially +@pytest.mark.parametrize( # noqa Duplicate + "num_inputs", [[8], [1, 8], [2, 8], [2, 2, 8]] +) +# Number of input elements in the last dimension +@pytest.mark.parametrize("num_elems", [32]) +# Number of elements to process in parallel +@pytest.mark.parametrize("pe", [1, 2, 4, 8]) +# Number of replicas to produce +@pytest.mark.parametrize("num", [1, 2, 4, 8]) +# Datatypes to simulate +@pytest.mark.parametrize("dtype", ["FLOAT32", "UINT8", "INT4"]) +# This is a slow running fpgadataflow type of test which requires vivado +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +# Tests replicating of tensors/streams to multiple outputs using RTL mode +# execution +def test_replicate_stream_rtlsim(num_inputs, num_elems, pe, num, dtype): + # Make dummy model for testing + model = mock_replicate_streams(num_inputs, num_elems, pe, num, dtype) + + # Prepare the execution context + context = { + "inp": gen_finn_dt_tensor(DataType[dtype], (*num_inputs, num_elems)) + } + + # Specializes all nodes to be implemented as HLS backend + model = specialize_hls(model) + # Set model execution mode to RTL simulation + model = model.transform(SetExecMode("rtlsim")) + # Generates the C++ source and compiles the RTL simulation + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(PrepareIP("xczu7ev-ffvc1156-2-e", 10)) # noqa + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + + # Compute ground-truth output in software + o_expected = [context["inp"] for _ in range(num)] # noqa: Duplicate + # Execute the onnx model to collect the result + o_produced = execute_onnx(model, context) + + # Validate each output separately + for i, out in enumerate((f"out{i}" for i in range(num))): + # Compare expected (retrieved by index) to produced (retrieve by key) + assert (o_produced[out] == o_expected[i]).all() # noqa: "all" warning
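A note on exercising the new test suites: the cppsim and rtlsim variants above are gated behind the fpgadataflow, slow and vivado pytest markers, so they only run when those markers are selected and a Vivado/Vitis HLS installation is available, while the python-mode variants need neither. A minimal sketch of how one might invoke them locally, assuming a working FINN development environment; the marker expressions are inferred from the decorators in this diff and may need adjusting to the project's own pytest configuration:

# Quick python-mode smoke test of the new head split/merge and stream
# replication operators; slow/vivado-marked tests are excluded
pytest -m "fpgadataflow and not (slow or vivado)" \
    tests/fpgadataflow/test_fpgadataflow_attention_heads.py \
    tests/fpgadataflow/test_fpgadataflow_replicate_stream.py

# Full sweep including the cppsim and rtlsim variants; requires Vivado/Vitis
# HLS and is slow due to the large parametrize grids
pytest -m fpgadataflow \
    tests/fpgadataflow/test_fpgadataflow_attention_heads.py \
    tests/fpgadataflow/test_fpgadataflow_replicate_stream.py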