From f8d7d03f1c242b578c28f32db428256c61a0a808 Mon Sep 17 00:00:00 2001 From: fanliwen Date: Fri, 24 May 2019 15:54:35 +0800 Subject: [PATCH 1/4] fix cuBERT_LOGITS multi label num_labels!=1 bug --- src/cuBERT/Bert.cpp | 5 +++-- src/cuBERT/Bert.h | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cuBERT/Bert.cpp b/src/cuBERT/Bert.cpp index bdaa843..fd9b6ec 100644 --- a/src/cuBERT/Bert.cpp +++ b/src/cuBERT/Bert.cpp @@ -18,6 +18,7 @@ namespace cuBERT { this->max_batch_size = max_batch_size; this->seq_length = seq_length; this->hidden_size = hidden_size; + this->num_labels = num_labels; this->stream = cuBERT::cuda_stream_create(); this->cublas = cuBERT::blas_create(); @@ -48,7 +49,7 @@ namespace cuBERT { this->_embedding_output = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size * seq_length * hidden_size)); this->_pooled_output = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size * hidden_size)); - this->_logits = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size)); + this->_logits = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size * num_labels)); this->input_ids_buf = static_cast(cuBERT::malloc(sizeof(int) * max_batch_size * seq_length)); this->input_mask_buf = static_cast(cuBERT::malloc(sizeof(int8_t) * max_batch_size * seq_length)); @@ -131,7 +132,7 @@ namespace cuBERT { } void *streamId = cuBERT::blas_get_stream(cublas); - cuBERT::memcpyAsync(logits, _logits, sizeof(T) * batch_size, 2, streamId); + cuBERT::memcpyAsync(logits, _logits, sizeof(T) * batch_size * num_labels, 2, streamId); cuBERT::cuda_stream_synchronize(streamId); if (!buffer_filled) { diff --git a/src/cuBERT/Bert.h b/src/cuBERT/Bert.h index b92280b..e14b0b5 100644 --- a/src/cuBERT/Bert.h +++ b/src/cuBERT/Bert.h @@ -41,6 +41,7 @@ namespace cuBERT { size_t max_batch_size; size_t seq_length; size_t hidden_size; + size_t num_labels; BertEmbeddings *bert_embeddings; Transformer *transformer; From 8db30a40df61a808e3b03cf19751a0e6070a6987 Mon Sep 17 00:00:00 2001 From: fanliwen Date: Mon, 3 Jun 2019 18:08:43 +0800 Subject: [PATCH 2/4] output probs for classification --- README.md | 4 +++- .../java/com/zhihu/cubert/OutputType.java | 1 + python/_cuBERT.pxd | 2 +- python/cuBERT.pyx | 1 + src/cuBERT.h | 1 + src/cuBERT/Bert.cpp | 13 +++++++--- src/cuBERT/Bert.h | 3 ++- src/cuBERT/BertM.cpp | 5 +++- src/cuBERT/op/Softmax.cpp | 18 +++++++++----- src/cuBERT/op/Softmax.cu | 24 +++++++++++-------- src/cuBERT/op/Softmax.h | 7 ++++-- src/cuBERT/op_out/AdditionalOutputLayer.cpp | 19 +++++++++------ src/cuBERT/op_out/AdditionalOutputLayer.h | 9 ++++--- test/cuBERT/BertTest.cpp | 2 +- .../op_out/AdditionalOutputLayerTest.cpp | 2 +- 15 files changed, 74 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 4d9e7a3..f93705d 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,8 @@ Following outputs are supported: |cuBERT_OutputType |python code | |--- |--- | -|cuBERT_LOGITS |[`model.get_pooled_output()` * output_weights + output_bias](https://github.com/google-research/bert/blob/d66a146741588fb208450bde15aa7db143baaa69/run_classifier.py#L607)| +|cuBERT_LOGITS |[`model.get_pooled_output() * output_weights + output_bias`](https://github.com/google-research/bert/blob/d66a146741588fb208450bde15aa7db143baaa69/run_classifier.py#L607)| +|cuBERT_PROBS |`probs = tf.nn.softmax(logits, axis=-1)`| |cuBERT_POOLED_OUTPUT |`model.get_pooled_output()` | |cuBERT_SEQUENCE_OUTPUT |`model.get_sequence_output()` | |cuBERT_EMBEDDING_OUTPUT|`model.get_embedding_output()`| @@ -248,3 +249,4 @@ to achieve the best trade-off. Good luck! * fanliwen * wangruixin * fangkuan +* sunxian diff --git a/java/src/main/java/com/zhihu/cubert/OutputType.java b/java/src/main/java/com/zhihu/cubert/OutputType.java index dd6605e..7bf7d93 100644 --- a/java/src/main/java/com/zhihu/cubert/OutputType.java +++ b/java/src/main/java/com/zhihu/cubert/OutputType.java @@ -5,4 +5,5 @@ public enum OutputType { POOLED_OUTPUT, SEQUENCE_OUTPUT, EMBEDDING_OUTPUT, + PROBS, } diff --git a/python/_cuBERT.pxd b/python/_cuBERT.pxd index 50b6424..26d3002 100644 --- a/python/_cuBERT.pxd +++ b/python/_cuBERT.pxd @@ -3,7 +3,7 @@ cdef extern from "../src/cuBERT.h": cuBERT_COMPUTE_FLOAT, cuBERT_COMPUTE_HALF cdef enum cuBERT_OutputType: - cuBERT_LOGITS, cuBERT_POOLED_OUTPUT, cuBERT_SEQUENCE_OUTPUT, cuBERT_EMBEDDING_OUTPUT + cuBERT_LOGITS, cuBERT_POOLED_OUTPUT, cuBERT_SEQUENCE_OUTPUT, cuBERT_EMBEDDING_OUTPUT, cuBERT_PROBS void cuBERT_initialize() except +; void cuBERT_finalize() except +; diff --git a/python/cuBERT.pyx b/python/cuBERT.pyx index 1f0d939..bc7daca 100644 --- a/python/cuBERT.pyx +++ b/python/cuBERT.pyx @@ -33,6 +33,7 @@ class OutputType: POOLED_OUTPUT = _cuBERT.cuBERT_OutputType.cuBERT_POOLED_OUTPUT SEQUENCE_OUTPUT = _cuBERT.cuBERT_OutputType.cuBERT_SEQUENCE_OUTPUT EMBEDDING_OUTPUT = _cuBERT.cuBERT_OutputType.cuBERT_EMBEDDING_OUTPUT + PROBS = _cuBERT.cuBERT_OutputType.cuBERT_PROBS cdef class Model: cdef void* _c_model diff --git a/src/cuBERT.h b/src/cuBERT.h index cf46871..0739393 100644 --- a/src/cuBERT.h +++ b/src/cuBERT.h @@ -19,6 +19,7 @@ enum cuBERT_OutputType { cuBERT_POOLED_OUTPUT = 1, cuBERT_SEQUENCE_OUTPUT = 2, cuBERT_EMBEDDING_OUTPUT = 3, + cuBERT_PROBS = 4, }; void cuBERT_initialize(); diff --git a/src/cuBERT/Bert.cpp b/src/cuBERT/Bert.cpp index fd9b6ec..f36c5f5 100644 --- a/src/cuBERT/Bert.cpp +++ b/src/cuBERT/Bert.cpp @@ -50,6 +50,7 @@ namespace cuBERT { this->_embedding_output = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size * seq_length * hidden_size)); this->_pooled_output = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size * hidden_size)); this->_logits = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size * num_labels)); + this->_probs = static_cast(cuBERT::malloc(sizeof(T) * max_batch_size * num_labels)); this->input_ids_buf = static_cast(cuBERT::malloc(sizeof(int) * max_batch_size * seq_length)); this->input_mask_buf = static_cast(cuBERT::malloc(sizeof(int8_t) * max_batch_size * seq_length)); @@ -66,6 +67,7 @@ namespace cuBERT { cuBERT::free(input_mask_buf); cuBERT::free(input_ids_buf); + cuBERT::free(_probs); cuBERT::free(_logits); cuBERT::free(_pooled_output); cuBERT::free(_embedding_output); @@ -118,7 +120,7 @@ namespace cuBERT { bert_pooler->compute(batch_size, _sequence_output, _pooled_output); if (additional_output_layer != nullptr) { - additional_output_layer->_in_compute(batch_size, _pooled_output, _logits); + additional_output_layer->_in_compute(batch_size, _pooled_output, _logits, _probs); } // buffers should be re-computed in the next request @@ -126,13 +128,18 @@ namespace cuBERT { } template - void Bert::logits(size_t batch_size, T *logits) { + void Bert::logits(size_t batch_size, T *logits, T *probs) { if (additional_output_layer == nullptr) { std::cerr << "model does not have additional_output_layer, the output logits is wrong." << std::endl; } void *streamId = cuBERT::blas_get_stream(cublas); - cuBERT::memcpyAsync(logits, _logits, sizeof(T) * batch_size * num_labels, 2, streamId); + if (logits != nullptr) { + cuBERT::memcpyAsync(logits, _logits, sizeof(T) * batch_size * num_labels, 2, streamId); + } + if (probs != nullptr) { + cuBERT::memcpyAsync(probs, _probs, sizeof(T) * batch_size * num_labels, 2, streamId); + } cuBERT::cuda_stream_synchronize(streamId); if (!buffer_filled) { diff --git a/src/cuBERT/Bert.h b/src/cuBERT/Bert.h index e14b0b5..765ab16 100644 --- a/src/cuBERT/Bert.h +++ b/src/cuBERT/Bert.h @@ -29,7 +29,7 @@ namespace cuBERT { void compute(size_t batch_size, int *input_ids, int8_t *input_mask, int8_t *segment_ids); // ouput methods, cpu/gpu outputs - void logits(size_t batch_size, T *logits); + void logits(size_t batch_size, T *logits, T *probs); void pooled_output(size_t batch_size, T *pooled_output); void sequence_output(size_t batch_size, T *sequence_output); void embedding_output(size_t batch_size, T *embedding_output); @@ -58,6 +58,7 @@ namespace cuBERT { T *_sequence_output; T *_pooled_output; T *_logits; + T *_probs; // for pre-compute // FIXME: _sequence_output will be flushed diff --git a/src/cuBERT/BertM.cpp b/src/cuBERT/BertM.cpp index a6b935b..67978ff 100644 --- a/src/cuBERT/BertM.cpp +++ b/src/cuBERT/BertM.cpp @@ -77,7 +77,10 @@ namespace cuBERT { bert_instance->compute(batch_size, input_ids, input_mask, segment_ids); switch (output_type) { case cuBERT_LOGITS: - bert_instance->logits(batch_size, output); + bert_instance->logits(batch_size, output, nullptr); + break; + case cuBERT_PROBS: + bert_instance->logits(batch_size, nullptr, output); break; case cuBERT_POOLED_OUTPUT: bert_instance->pooled_output(batch_size, output); diff --git a/src/cuBERT/op/Softmax.cpp b/src/cuBERT/op/Softmax.cpp index 3c6a408..7805cf8 100644 --- a/src/cuBERT/op/Softmax.cpp +++ b/src/cuBERT/op/Softmax.cpp @@ -8,7 +8,8 @@ namespace cuBERT { #ifdef HAVE_MKL template<> - void softmax_(float *inout, + void softmax_(float *in, + float *out, const int batch_size, const int channel, float *sum_gpu, @@ -17,12 +18,12 @@ namespace cuBERT { for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { float sum = 0; for (int i = batch_idx * channel; i < (batch_idx + 1) * channel; ++i) { - inout[i] = expf(inout[i]); - sum += inout[i]; + out[i] = expf(in[i]); + sum += out[i]; } for (int i = batch_idx * channel; i < (batch_idx + 1) * channel; ++i) { - inout[i] = inout[i] / sum; + out[i] = out[i] / sum; } } } @@ -40,8 +41,13 @@ namespace cuBERT { } template - void Softmax::compute_(size_t batch_size, T *inout_gpu, void* stream) { - softmax_(inout_gpu, batch_size, channel, sum_gpu, stream); + void Softmax::compute_(size_t batch_size, T *inout, void* stream) { + softmax_(inout, inout, batch_size, channel, sum_gpu, stream); + } + + template + void Softmax::compute_(size_t batch_size, T *in, T *out, void* stream) { + softmax_(in, out, batch_size, channel, sum_gpu, stream); } template class Softmax; diff --git a/src/cuBERT/op/Softmax.cu b/src/cuBERT/op/Softmax.cu index 713aebd..48afdf9 100644 --- a/src/cuBERT/op/Softmax.cu +++ b/src/cuBERT/op/Softmax.cu @@ -40,14 +40,18 @@ namespace cuBERT { } template - __global__ void kernel_substract_(T *inout, const int batch_size, const int channel, T *max_in) { + __global__ void kernel_substract(const T *__restrict__ in, + T *out, + const int batch_size, + const int channel, + const T *__restrict__ max_in) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= batch_size * channel) { return; } int batch_idx = idx / channel; - inout[idx] = (float) __ldg(inout + idx) - (float) __ldg(max_in + batch_idx); + out[idx] = (float) __ldg(in + idx) - (float) __ldg(max_in + batch_idx); } template @@ -81,22 +85,22 @@ namespace cuBERT { } template - __host__ void softmax_(T *inout, const int batch_size, const int channel, T *sum_gpu, void* stream) { + __host__ void softmax_(T *in, T *out, const int batch_size, const int channel, T *sum_gpu, void* stream) { const int all_blocks = (batch_size * channel + 127) / 128; - kernel_max_cub <<>> (inout, batch_size, channel, sum_gpu); - kernel_substract_ <<>> (inout, batch_size, channel, sum_gpu); + kernel_max_cub <<>> (in, batch_size, channel, sum_gpu); + kernel_substract <<>> (in, out, batch_size, channel, sum_gpu); - thrust::device_ptr dev_ptr(inout); + thrust::device_ptr dev_ptr(out); thrust::transform(thrust::cuda::par.on((cudaStream_t) stream), dev_ptr, dev_ptr + batch_size * channel, dev_ptr, exp_functor()); - kernel_sum_cub <<>> (inout, batch_size, channel, sum_gpu); - kernel_scale_ <<>> (inout, batch_size, channel, sum_gpu); + kernel_sum_cub <<>> (out, batch_size, channel, sum_gpu); + kernel_scale_ <<>> (out, batch_size, channel, sum_gpu); } template - __host__ void softmax_(float *inout, const int batch_size, const int channel, float *sum_gpu, void *stream); + __host__ void softmax_(float *in, float *out, const int batch_size, const int channel, float *sum_gpu, void *stream); template - __host__ void softmax_(half *inout, const int batch_size, const int channel, half *sum_gpu, void *stream); + __host__ void softmax_(half *in, half *out, const int batch_size, const int channel, half *sum_gpu, void *stream); } diff --git a/src/cuBERT/op/Softmax.h b/src/cuBERT/op/Softmax.h index 8d9cfcb..fd44c40 100644 --- a/src/cuBERT/op/Softmax.h +++ b/src/cuBERT/op/Softmax.h @@ -6,7 +6,8 @@ namespace cuBERT { template - void softmax_(T *inout, + void softmax_(T *in, + T *out, const int batch_size, const int channel, T *sum_gpu, @@ -19,7 +20,9 @@ namespace cuBERT { virtual ~Softmax(); - void compute_(size_t batch_size, T *inout_gpu, void* stream); + void compute_(size_t batch_size, T *inout, void* stream); + + void compute_(size_t batch_size, T *in, T *out, void* stream); private: size_t channel; diff --git a/src/cuBERT/op_out/AdditionalOutputLayer.cpp b/src/cuBERT/op_out/AdditionalOutputLayer.cpp index 3d4705d..a6731b7 100644 --- a/src/cuBERT/op_out/AdditionalOutputLayer.cpp +++ b/src/cuBERT/op_out/AdditionalOutputLayer.cpp @@ -13,8 +13,6 @@ namespace cuBERT { this->handle = handle; this->hidden_size = hidden_size; this->num_labels = num_labels; - - std::cerr << "hidden_size:" << this->hidden_size << " num_labels:" << this->num_labels << std::endl; this->output_weights = static_cast(cuBERT::malloc(sizeof(T) * hidden_size * this->num_labels)); cuBERT::memcpy(this->output_weights, output_weights, sizeof(T) * hidden_size * this->num_labels, 1); @@ -28,10 +26,13 @@ namespace cuBERT { this->output_bias = nullptr; } + this->softmax = new Softmax(max_batch_size, num_labels); } template ClassifierOutputLayer::~ClassifierOutputLayer() { + delete softmax; + if (output_bias != nullptr) { cuBERT::free(output_bias); } @@ -47,7 +48,7 @@ namespace cuBERT { } template - void ClassifierOutputLayer::_in_compute(size_t batch_size, T *input, T *output) { + void ClassifierOutputLayer::_in_compute(size_t batch_size, T *input, T *output_logits, T *output_probs) { float beta = output_bias == nullptr ? 0.f : 1.f; cuBERT::blas_gemm(handle, true, false, num_labels, batch_size, hidden_size, @@ -55,13 +56,17 @@ namespace cuBERT { output_weights, hidden_size, input, hidden_size, beta, - output, num_labels); + output_logits, num_labels); + if (output_probs != nullptr) { + void* streamId = blas_get_stream(handle); + softmax->compute_(batch_size, output_logits, output_probs, streamId); + } } template - void ClassifierOutputLayer::compute(size_t batch_size, T *input, T *output) { - _pre_compute(batch_size, output); - _in_compute(batch_size, input, output); + void ClassifierOutputLayer::compute(size_t batch_size, T *input, T *output_logits, T *output_probs) { + _pre_compute(batch_size, output_logits); + _in_compute(batch_size, input, output_logits, output_probs); } template class ClassifierOutputLayer; diff --git a/src/cuBERT/op_out/AdditionalOutputLayer.h b/src/cuBERT/op_out/AdditionalOutputLayer.h index b77cc73..7f271e7 100644 --- a/src/cuBERT/op_out/AdditionalOutputLayer.h +++ b/src/cuBERT/op_out/AdditionalOutputLayer.h @@ -4,7 +4,6 @@ #include - #include "cuBERT/op/Softmax.h" namespace cuBERT { @@ -21,6 +20,8 @@ namespace cuBERT { * * logits = tf.matmul(output_layer, output_weights, transpose_b=True) * logits = tf.nn.bias_add(logits, output_bias) + * + * probabilities = tf.nn.softmax(logits, axis=-1) */ template class ClassifierOutputLayer { @@ -36,9 +37,9 @@ namespace cuBERT { void _pre_compute(size_t batch_size, T *output); - void _in_compute(size_t batch_size, T *input, T *output); + void _in_compute(size_t batch_size, T *input, T *output_logits, T *output_probs); - void compute(size_t batch_size, T *in_gpu, T *out_gpu); + void compute(size_t batch_size, T *in_gpu, T *output_logits, T *output_probs); private: void* handle; @@ -49,6 +50,8 @@ namespace cuBERT { // cpu/gpu buffer T *output_weights; T *output_bias; + + Softmax *softmax; }; } diff --git a/test/cuBERT/BertTest.cpp b/test/cuBERT/BertTest.cpp index 2c6bb12..dc4509e 100644 --- a/test/cuBERT/BertTest.cpp +++ b/test/cuBERT/BertTest.cpp @@ -49,7 +49,7 @@ TEST_F(BertTest, compute) { EXPECT_FLOAT_EQ(embedding_output[49151], 0.33240327); float logits[2]; - bert.logits(2, logits); + bert.logits(2, logits, nullptr); EXPECT_NEAR(logits[0], -2.9427543, 1e-5); EXPECT_NEAR(logits[1], -1.4876306, 1e-5); } diff --git a/test/cuBERT/op_out/AdditionalOutputLayerTest.cpp b/test/cuBERT/op_out/AdditionalOutputLayerTest.cpp index 151a18e..f5ff92b 100644 --- a/test/cuBERT/op_out/AdditionalOutputLayerTest.cpp +++ b/test/cuBERT/op_out/AdditionalOutputLayerTest.cpp @@ -22,7 +22,7 @@ TEST_F(CommonTest, additional_output) { cuBERT::memcpy(in_gpu, in, sizeof(float) * 6, 1); - aol.compute(2, in_gpu, out_gpu); + aol.compute(2, in_gpu, out_gpu, nullptr); cuBERT::memcpy(out, out_gpu, sizeof(float) * 2, 2); cuBERT::free(in_gpu); From a87e5b0c65d7000cc7653237f9752d86dc0dd82b Mon Sep 17 00:00:00 2001 From: fanliwen Date: Mon, 3 Jun 2019 18:29:05 +0800 Subject: [PATCH 3/4] fix test on GPU --- README.md | 20 ++++++++++---------- test/cuBERT/BertTest.cpp | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index f93705d..5e103f6 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,9 @@ Fast implementation of BERT inference directly on NVIDIA (CUDA, CUBLAS) and Inte [![Build Status](https://travis-ci.org/zhihu/cuBERT.svg?branch=master)](https://travis-ci.org/zhihu/cuBERT) Highly customized and optimized BERT inference directly on NVIDIA (CUDA, -CUBLAS) or Intel MKL, without tensorflow and its framework overhead. +CUBLAS) or Intel MKL, *without* tensorflow and its framework overhead. -ONLY BERT (Transformer) is supported. +**ONLY** BERT (Transformer) is supported. # Benchmark @@ -23,17 +23,17 @@ ONLY BERT (Transformer) is supported. ### GPU (cuBERT) -|batch size|128 (ms)|32 (ms)| -|--- |--- |--- | -|tensorflow|255.2 |70.0 | -|cuBERT |184.6 |54.5 | +|batch size|128 (ms) |32 (ms) | +|--- |--- |--- | +|tensorflow|255.2 |70.0 | +|cuBERT |**184.6**|**54.5**| ### CPU (mklBERT) -|batch size|128 (ms)|1 (ms)| -|--- |--- |--- | -|tensorflow|1504.0 |69.9 | -|mklBERT |984.9 |24.0 | +|batch size|128 (ms) |1 (ms) | +|--- |--- |--- | +|tensorflow|1504.0 |69.9 | +|mklBERT |**984.9**|**24.0**| Note: MKL should be run under `OMP_NUM_THREADS=?` to control its thread number. Other environment variables and their possible values includes: diff --git a/test/cuBERT/BertTest.cpp b/test/cuBERT/BertTest.cpp index dc4509e..8f4805d 100644 --- a/test/cuBERT/BertTest.cpp +++ b/test/cuBERT/BertTest.cpp @@ -104,7 +104,7 @@ TEST_F(BertHalfTest, compute) { float logits[2]; half logits_half[2]; - bert.logits(2, logits_half); + bert.logits(2, logits_half, nullptr); half2float(logits_half, logits, 2); EXPECT_NEAR(logits[0], -2.9427543, 0.01); From d9470efe60816af4bce75d53c77987fc1763f7b0 Mon Sep 17 00:00:00 2001 From: fanliwen Date: Tue, 4 Jun 2019 16:45:17 +0800 Subject: [PATCH 4/4] fix bug caused by additional_output_layer _pre_compute --- src/cuBERT/Bert.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/cuBERT/Bert.cpp b/src/cuBERT/Bert.cpp index f36c5f5..c60070d 100644 --- a/src/cuBERT/Bert.cpp +++ b/src/cuBERT/Bert.cpp @@ -58,6 +58,9 @@ namespace cuBERT { // pre-compute buffers transformer->_pre_compute(max_batch_size); + if (additional_output_layer != nullptr) { + additional_output_layer->_pre_compute(max_batch_size, _logits); + } this->buffer_filled = true; } @@ -144,6 +147,9 @@ namespace cuBERT { if (!buffer_filled) { transformer->_pre_compute(batch_size); + if (additional_output_layer != nullptr) { + additional_output_layer->_pre_compute(batch_size, _logits); + } buffer_filled = true; } } @@ -156,6 +162,9 @@ namespace cuBERT { if (!buffer_filled) { transformer->_pre_compute(batch_size); + if (additional_output_layer != nullptr) { + additional_output_layer->_pre_compute(batch_size, _logits); + } buffer_filled = true; } } @@ -169,6 +178,9 @@ namespace cuBERT { if (!buffer_filled) { transformer->_pre_compute(batch_size); + if (additional_output_layer != nullptr) { + additional_output_layer->_pre_compute(batch_size, _logits); + } buffer_filled = true; } } @@ -182,6 +194,9 @@ namespace cuBERT { if (!buffer_filled) { transformer->_pre_compute(batch_size); + if (additional_output_layer != nullptr) { + additional_output_layer->_pre_compute(batch_size, _logits); + } buffer_filled = true; } }