[DATA] Implement zero-copy string dtype and accelerate shuffle.
1. Implement a zero-copy approach for reading string data from Arrow into TF.
2. Accelerate the shuffle of string-typed data in ParquetDataset.

Preliminary benchmark results:
- 300 columns, `batch_size`=1000
- `Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz` with 128 logical cores

| Dataset            | list type | shuffling | throughput (samples/s) | speedup over TFRecord |
| ---                | ---       | ---       | ---                    | ---                   |
| TFRecord           | N         | N         | 1404.23                | 1.0                   |
| HbParquet          | N         | N         | 41137.53               | 29.3                  |
| HbParquet-ZeroCopy | N         | N         | 51335.40               | 36.56                 |
| TFRecord           | N         | Y         | 1343.10                | 1.0                   |
| HbParquet          | N         | Y         | 6629.60                | 4.9                   |
| HbParquet-ZeroCopy | N         | Y         | 10941.25               | 8.1                   |
| TFRecord           | Y         | N         | 1352.05                | 1.0                   |
| HbParquet          | Y         | N         | 2307.33                | 1.71                  |
| HbParquet-ZeroCopy | Y         | N         | 2869.98                | 2.12                  |
| TFRecord           | Y         | Y         | 1367.96                | 1.0                   |
| HbParquet          | Y         | Y         | 1080.03                | 0.79                  |
| HbParquet-ZeroCopy | Y         | Y         | 1454.02                | 1.06                  |
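
Note that the zero-copy string path defers materialization to a rebatch stage, so string columns read through `hb.data.ParquetDataset` should be followed by `hb.data.rebatch(batch_size)` (the `size()` guard in `arrow.cc` below logs an error otherwise). A minimal usage sketch — the exact `ParquetDataset`/`rebatch` call shapes here are assumptions inferred from that error message, not verified API:

```python
# Minimal sketch: reading string columns with zero-copy strings enabled
# (the default). Dataset/rebatch argument shapes are assumptions.
import os
import tensorflow as tf
import hybridbackend.tensorflow as hb

# Env var name taken from ZeroCopyStringForRebatchDisabled() in arrow.cc;
# set it to '1' to fall back to the copying path.
os.environ.setdefault('HB_ZERO_COPY_STRING_FOR_REBATCH_DISABLED', '0')

ds = hb.data.ParquetDataset('benchmark.parquet', batch_size=1000)
ds = ds.apply(hb.data.rebatch(1000))  # required by the zero-copy string path
batch = tf.data.make_one_shot_iterator(ds).get_next()
```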

Signed-off-by: langshi.cls <[email protected]>
francktcheng committed May 29, 2023 · 1 parent 0545159 · commit c477fae
Showing 10 changed files with 639 additions and 167 deletions.
@@ -169,7 +169,7 @@ ENV HYBRIDBACKEND_WITH_CUDA=ON \
HYBRIDBACKEND_WITH_NCCL=OFF \
HYBRIDBACKEND_WITH_ARROW_ZEROCOPY=ON \
HYBRIDBACKEND_WITH_TENSORFLOW_HALF=OFF \
-    HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=1015 \
+    HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=77661015 \
HYBRIDBACKEND_USE_CXX11_ABI=0 \
HYBRIDBACKEND_WHEEL_ALIAS=-tf115-cu100 \
HYBRIDBACKEND_WHEEL_REQUIRES="tensorflow_gpu>=1.15,<2.0"
@@ -121,9 +121,9 @@ COPY --from=devel_tools /opt/tools /usr/local
ENV HYBRIDBACKEND_WITH_CUDA=ON \
HYBRIDBACKEND_WITH_NCCL=ON \
HYBRIDBACKEND_WITH_ARROW_ZEROCOPY=ON \
-    HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=1015 \
+    HYBRIDBACKEND_WITH_TENSORFLOW_DISTRO=77661015 \
HYBRIDBACKEND_USE_CXX11_ABI=0 \
HYBRIDBACKEND_USE_RUFF=1 \
HYBRIDBACKEND_WHEEL_ALIAS=-tf115-cu121 \
TENSORFLOW_INCLUDE=/opt/tensorflow/tensorflow-source \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/cuda/lib64
45 changes: 38 additions & 7 deletions hybridbackend/tensorflow/benchmarks/data_benchmark_parquet.py
@@ -40,12 +40,31 @@ def benchmark(params):
tf.logging.info('Started generating mock file ...')
workspace = tempfile.mkdtemp()
params.filenames = [os.path.join(workspace, 'benchmark.parquet')]
-df = pd.DataFrame(
-    np.random.randint(
-        0, 100,
-        size=(params.batch_size * 100, len(params.fields)),
-        dtype=np.int64),
-    columns=params.fields)
+if params.use_string_data:
+  df = pd.DataFrame(
+      np.array([
+          [
+              *[
+                  np.array(list(map(str, np.random.randint(
+                      0, 9,
+                      size=(np.random.randint(10, 30),),
+                      dtype=np.int64))))
+                  for _ in xrange(len(params.fields))]]
+          for _ in xrange(params.batch_size * 100)], dtype=object),
+      columns=params.fields)
+elif params.use_fixed_len_string_data:
+  df = pd.DataFrame(
+      np.array([
+          ['abcdefghijklmnoprstu' for _ in xrange(len(params.fields))]
+          for _ in xrange(params.batch_size * 100)], dtype=np.str),
+      columns=params.fields)
+else:
+  df = pd.DataFrame(
+      np.random.randint(
+          0, 100,
+          size=(params.batch_size * 100, len(params.fields)),
+          dtype=np.int64),
+      columns=params.fields)
df.to_parquet(params.filenames[0])
tf.logging.info(f'Mock file {params.filenames[0]} generated.')
with tf.Graph().as_default():
@@ -66,7 +85,14 @@ def benchmark(params):
ds = ds.batch(params.batch_size, drop_remainder=True)
batch = tf.data.make_one_shot_iterator(ds).get_next()
train_op = tf.group(list(batch.values()) + [step.assign_add(1)])
-with tf.train.MonitoredTrainingSession('') as sess:
+chief_only_hooks = []
+if params.profile_every_n_iter is not None:
+  chief_only_hooks.append(
+      tf.train.ProfilerHook(
+          save_steps=params.profile_every_n_iter,
+          output_dir=params.output_dir))
+with tf.train.MonitoredTrainingSession(
+    '', chief_only_hooks=chief_only_hooks) as sess:
count = 0
prev_ts = time.time()
try:
@@ -100,8 +126,13 @@ def benchmark(params):
parser = argparse.ArgumentParser()
parser.add_argument('--baseline', default=False, action='store_true')
parser.add_argument('--shuffle', default=False, action='store_true')
+parser.add_argument('--use-string-data', default=False, action='store_true')
+parser.add_argument(
+    '--use-fixed-len-string-data', default=False, action='store_true')
parser.add_argument('--batch-size', type=int, default=64000)
parser.add_argument('--num-steps', type=int, default=None)
+parser.add_argument('--output-dir', default='./outputs')
+parser.add_argument('--profile-every-n-iter', type=int, default=None)
parser.add_argument(
'--fields', nargs='+', default=[f'f{c}' for c in xrange(200)])
parser.add_argument('filenames', nargs='*')
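
The nested comprehension above that builds the variable-length string column is dense; the following is an equivalent restatement (my own sketch, not code from this commit) of what each mock cell contains:

```python
# Sketch: each cell of the mock DataFrame holds an array of 10-30 one-digit
# strings, so every field behaves like a variable-length string list.
import numpy as np
import pandas as pd

def make_string_cell():
  length = np.random.randint(10, 30)  # list length varies per cell
  digits = np.random.randint(0, 9, size=(length,), dtype=np.int64)
  return np.array([str(d) for d in digits])

fields = [f'f{c}' for c in range(300)]
rows = [[make_string_cell() for _ in fields] for _ in range(100)]
df = pd.DataFrame(np.array(rows, dtype=object), columns=fields)
```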
97 changes: 77 additions & 20 deletions hybridbackend/tensorflow/benchmarks/data_benchmark_tfrecord.py
@@ -38,39 +38,89 @@ def benchmark(params):
tf.logging.info('Started generating mock file ...')
workspace = tempfile.mkdtemp()
params.filenames = [os.path.join(workspace, 'benchmark.tfrecord')]
-df = pd.DataFrame(
-    np.random.randint(
-        0, 100,
-        size=(params.batch_size * 100, len(params.fields)),
-        dtype=np.int64),
-    columns=params.fields)
+if params.use_string_data:
+  df = pd.DataFrame(
+      np.array([
+          [
+              *[
+                  np.array(list(map(str, np.random.randint(
+                      0, 9,
+                      size=(np.random.randint(10, 30),),
+                      dtype=np.int64))))
+                  for _ in xrange(len(params.fields))]]
+          for _ in xrange(params.batch_size * 100)], dtype=object),
+      columns=params.fields)
+elif params.use_fixed_len_string_data:
+  df = pd.DataFrame(
+      np.array([
+          ['abcdefghijklmnoprstu' for _ in xrange(len(params.fields))]
+          for _ in xrange(params.batch_size * 100)], dtype=np.str),
+      columns=params.fields)
+else:
+  df = pd.DataFrame(
+      np.random.randint(
+          0, 100,
+          size=(params.batch_size * 100, len(params.fields)),
+          dtype=np.int64),
+      columns=params.fields)
writer = tf.python_io.TFRecordWriter(params.filenames[0])
-for row in tq(range(params.samples)):
-  feats = tf.train.Features(
-      feature={
-          f: tf.train.Feature(
-              int64_list=tf.train.Int64List(value=[df[f][row]]))
-          for f in params.fields})
+for row in tq(range(params.batch_size * 100)):
+  if params.use_string_data or params.use_fixed_len_string_data:
+    feats = tf.train.Features(
+        feature={
+            f: tf.train.Feature(
+                bytes_list=tf.train.BytesList(
+                    value=[bytes(val, 'utf-8') for val in df[f][row]]))
+            for f in params.fields})
+  else:
+    feats = tf.train.Features(
+        feature={
+            f: tf.train.Feature(
+                int64_list=tf.train.Int64List(value=[df[f][row]]))
+            for f in params.fields})
example = tf.train.Example(features=feats)
writer.write(example.SerializeToString())
writer.close()
tf.logging.info(f'Mock file {params.filenames[0]} generated.')
with tf.Graph().as_default():
step = tf.train.get_or_create_global_step()
ds = tf.data.TFRecordDataset(params.filenames)
if params.shuffle:
ds = ds.shuffle(params.batch_size * 10)
ds = ds.batch(params.batch_size, drop_remainder=True)
-ds = ds.map(
-    lambda line: tf.parse_example(
-        line, {f: tf.FixedLenFeature([1], tf.int64) for f in params.fields}))
+if params.use_string_data or params.use_fixed_len_string_data:
+  ds = ds.map(
+      lambda line: tf.parse_example(
+          line, {f: tf.VarLenFeature(tf.string) for f in params.fields}))
+else:
+  ds = ds.map(
+      lambda line: tf.parse_example(
+          line, {f: tf.FixedLenFeature([1], tf.int64) for f in params.fields}))
batch = tf.data.make_one_shot_iterator(ds).get_next()
-train_op = tf.group(batch + [step.assign_add(1)])
-with tf.train.MonitoredTrainingSession('') as sess:
+train_op = tf.group(list(batch.values()) + [step.assign_add(1)])
+chief_only_hooks = []
+if params.profile_every_n_iter is not None:
+  chief_only_hooks.append(
+      tf.train.ProfilerHook(
+          save_steps=params.profile_every_n_iter,
+          output_dir=params.output_dir))
+with tf.train.MonitoredTrainingSession(
+    '', chief_only_hooks=chief_only_hooks) as sess:
count = 0
prev_ts = time.time()
try:
-while not sess.should_stop():
-  sess.run(train_op)
-  count += 1
+with tq() as pbar:
+  should_stop = False
+  while not sess.should_stop() and not should_stop:
+    prev_sess_run = time.time()
+    sess.run(train_op)
+    sess_run_duration = time.time() - prev_sess_run
+    pbar.set_description(
+        f'{params.batch_size / sess_run_duration:6.2f} samples/sec')
+    pbar.update(1)
+    count += 1
+    if params.num_steps is not None:
+      should_stop = count >= params.num_steps
except tf.errors.OutOfRangeError:
pass
duration = time.time() - prev_ts
@@ -87,7 +137,14 @@ def benchmark(params):
os.environ['CUDA_VISIBLE_DEVICES'] = ''
tf.logging.set_verbosity(tf.logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument('--shuffle', default=False, action='store_true')
+parser.add_argument('--use-string-data', default=False, action='store_true')
+parser.add_argument(
+    '--use-fixed-len-string-data', default=False, action='store_true')
+parser.add_argument('--batch-size', type=int, default=64000)
+parser.add_argument('--num-steps', type=int, default=None)
+parser.add_argument('--output-dir', default='./outputs')
+parser.add_argument('--profile-every-n-iter', type=int, default=None)
parser.add_argument(
'--fields', nargs='+', default=[f'f{c}' for c in xrange(200)])
parser.add_argument('filenames', nargs='*')
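
Because string fields are parsed with `tf.VarLenFeature`, each parsed field comes back as a `SparseTensor` rather than a dense tensor. A small standalone sketch of the round-trip used above (TF 1.x API; the feature name `f0` is illustrative):

```python
# Sketch: serialize one BytesList feature, then parse it back with
# tf.VarLenFeature, which yields a SparseTensor of strings (TF 1.x).
import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    'f0': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'12', b'345']))}))

parsed = tf.parse_example(
    tf.constant([example.SerializeToString()]),
    {'f0': tf.VarLenFeature(tf.string)})
with tf.Session() as sess:
  print(sess.run(parsed['f0']).values)  # -> [b'12' b'345']
```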
102 changes: 93 additions & 9 deletions hybridbackend/tensorflow/common/arrow.cc
@@ -24,15 +24,80 @@ limitations under the License.
#include <arrow/util/thread_pool.h>
#include <tensorflow/core/framework/allocation_description.pb.h>

#include "hybridbackend/common/env.h"
#include "hybridbackend/tensorflow/common/arrow.h"
#include "hybridbackend/tensorflow/common/eigen.h"
#endif

namespace tensorflow {
namespace hybridbackend {

+namespace {
+inline bool ZeroCopyStringForRebatchDisabled() {
+  static const bool kZeroCopyStringForRebatchDisabled =
+      ::hybridbackend::EnvVarGetBool("HB_ZERO_COPY_STRING_FOR_REBATCH_DISABLED",
+                                     false);
+  return kZeroCopyStringForRebatchDisabled;
+}
+}  // namespace

#if HYBRIDBACKEND_ARROW

+#if HYBRIDBACKEND_ARROW_ZEROCOPY
+#if (TF_MAJOR_VERSION * 1000L + TF_MINOR_VERSION) < 1014L
+ArrowStringTensorBuffer::ArrowStringTensorBuffer(
+    const std::shared_ptr<arrow::Buffer>& value_data_buf,
+    const std::shared_ptr<arrow::Buffer>& value_offsets_buf,
+    const uint8_t* raw_data, const int32_t* raw_value_offsets)
+    : value_data_buf_(value_data_buf),
+      value_offsets_buf_(value_offsets_buf),
+      raw_data_(raw_data),
+      raw_value_offsets_(raw_value_offsets) {}
+
+void* ArrowStringTensorBuffer::data() const {
+  return const_cast<ArrowStringTensorBuffer*>(this);
+}
+
+#else
+ArrowStringTensorBuffer::ArrowStringTensorBuffer(
+    const std::shared_ptr<arrow::Buffer>& value_data_buf,
+    const std::shared_ptr<arrow::Buffer>& value_offsets_buf,
+    const uint8_t* raw_data, const int32_t* raw_value_offsets)
+    : TensorBuffer(this),
+      value_data_buf_(value_data_buf),
+      value_offsets_buf_(value_offsets_buf),
+      raw_data_(raw_data),
+      raw_value_offsets_(raw_value_offsets) {}
+#endif
+
+size_t ArrowStringTensorBuffer::size() const {
+  LOG(ERROR) << "When using zero-copy strings for rebatch, please add a "
+                "hb.data.rebatch(batch_size) following hb.data.ParquetDataset";
+  return 0;
+}
+
+TensorBuffer* ArrowStringTensorBuffer::root_buffer() { return this; }
+
+void ArrowStringTensorBuffer::FillAllocationDescription(
+    AllocationDescription* proto) const {
+  proto->set_requested_bytes(sizeof(tstring));
+  proto->set_allocator_name("ZerocopyArrowStringTensorBuffer");
+#if HYBRIDBACKEND_TENSORFLOW_DISTRO == 1015
+  // NOTE: vanilla TensorFlow from the community has no `data()` method on
+  // class `Tensor`, so we leverage the FillAllocationDescription API to
+  // obtain the underlying ArrowStringTensorBuffer.
+  proto->set_ptr(reinterpret_cast<uint64>(this));
+#endif
+}
+
+bool ArrowStringTensorBuffer::OwnsMemory() const { return false; }
+
+const uint8_t* ArrowStringTensorBuffer::GetValue(int64_t i,
+                                                 int32_t* out_length) {
+  const int32_t pos = raw_value_offsets_[i];
+  *out_length = raw_value_offsets_[i + 1] - pos;
+  return raw_data_ + pos;
+}
+#endif

namespace {
#if HYBRIDBACKEND_ARROW_ZEROCOPY
class ArrowPrimitiveTensorBuffer : public TensorBuffer {
@@ -143,15 +208,34 @@ ::arrow::Status MakeStringTensorFromArrowArray(
&actual_shape))) {
return ::arrow::Status::Invalid("Field shape is not fully defined");
}

-  *tensor = Tensor(DT_STRING, actual_shape);
-  auto tensor_vec = tensor->vec<std::string>();
-
-  for (auto i = 0; i < total_num_elems; ++i) {
-    int string_size;
-    auto string_data = array.GetValue(i, &string_size);
-    tensor_vec(i).assign(reinterpret_cast<const char*>(string_data),
-                         string_size);
+  if (ZeroCopyStringForRebatchDisabled()) {
+    *tensor = Tensor(DT_STRING, actual_shape);
+    auto tensor_vec = tensor->vec<std::string>();
+
+    for (auto i = 0; i < total_num_elems; ++i) {
+      int string_size;
+      auto string_data = array.GetValue(i, &string_size);
+      tensor_vec(i).assign(reinterpret_cast<const char*>(string_data),
+                           string_size);
+    }
+  } else {
+#if HYBRIDBACKEND_ARROW_ZEROCOPY
+    ArrowStringTensorBuffer* tensor_buffer = new ArrowStringTensorBuffer(
+        array.value_data(), array.value_offsets(), array.raw_data(),
+        array.raw_value_offsets());
+    core::ScopedUnref unref(tensor_buffer);
+    *tensor = Tensor(DT_STRING, actual_shape, tensor_buffer);
+#else
+    *tensor = Tensor(DT_STRING, actual_shape);
+    auto tensor_vec = tensor->vec<std::string>();
+
+    for (auto i = 0; i < total_num_elems; ++i) {
+      int string_size;
+      auto string_data = array.GetValue(i, &string_size);
+      tensor_vec(i).assign(reinterpret_cast<const char*>(string_data),
+                           string_size);
+    }
+#endif
+  }
return ::arrow::Status::OK();
}
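
Since `size()` is deliberately unusable outside the rebatch path, pipelines that cannot insert `hb.data.rebatch` can disable zero-copy strings through the environment variable read by `ZeroCopyStringForRebatchDisabled()`. A sketch; note the flag is cached in a function-local static, so it must be set before the first Parquet read:

```python
# Fall back to the copying branch of MakeStringTensorFromArrowArray; the
# variable name matches ZeroCopyStringForRebatchDisabled() in arrow.cc.
# Set this before the first read: the value is cached in a static.
import os
os.environ['HB_ZERO_COPY_STRING_FOR_REBATCH_DISABLED'] = '1'
```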
27 changes: 27 additions & 0 deletions hybridbackend/tensorflow/common/arrow.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <deque>

#if HYBRIDBACKEND_ARROW
+#include <arrow/array.h>
#include <arrow/dataset/api.h>
#include <arrow/filesystem/path_util.h>
#include <arrow/record_batch.h>
@@ -34,6 +35,7 @@ limitations under the License.

#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/lib/core/errors.h>
+#include <tensorflow/core/public/version.h>

#define TF_RETURN_IF_ARROW_ERROR(...) \
do { \
@@ -89,6 +91,31 @@ MATCH_TYPE_AND_ARROW_ENUM(float, ::arrow::Type::FLOAT);
MATCH_TYPE_AND_ARROW_ENUM(double, ::arrow::Type::DOUBLE);
MATCH_TYPE_AND_ARROW_ENUM(string, ::arrow::Type::STRING);

+#if HYBRIDBACKEND_ARROW_ZEROCOPY
+class ArrowStringTensorBuffer : public TensorBuffer {
+ public:
+  ArrowStringTensorBuffer() = delete;
+  explicit ArrowStringTensorBuffer(
+      const std::shared_ptr<arrow::Buffer>& value_data_buf,
+      const std::shared_ptr<arrow::Buffer>& value_offsets_buf,
+      const uint8_t* raw_data, const int32_t* raw_value_offsets);
+#if (TF_MAJOR_VERSION * 1000L + TF_MINOR_VERSION) < 1014L
+  void* data() const override;
+#endif
+  const uint8_t* GetValue(int64_t i, int32_t* out_length);
+  size_t size() const override;
+  TensorBuffer* root_buffer() override;
+  void FillAllocationDescription(AllocationDescription* proto) const override;
+  bool OwnsMemory() const override;
+
+ private:
+  std::shared_ptr<::arrow::Buffer> value_data_buf_;
+  std::shared_ptr<::arrow::Buffer> value_offsets_buf_;
+  const uint8_t* raw_data_;
+  const int32_t* raw_value_offsets_;
+};
+#endif

Status MakeDataTypeAndRaggedRankFromArrowDataType(
const std::shared_ptr<::arrow::DataType>& arrow_dtype, DataType* dtype,
int32* ragged_rank);