
Dataset iterator can't be wrapped in the hybridBackend scope #84

Open

fuhailin (Contributor) opened this issue Nov 3, 2022 · 0 comments

Current behavior

I am using HybridBackend for data parallelism. I create a dataset and make an iterator from it; when I wrap the whole pipeline in the hybridBackend scope, an exception is raised at the iterator's get_next() call. Here is the error log:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_util.py", line 324, in _AssertCompatible
    fn(values)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_util.py", line 276, in _check_not_tensor
    _ = [_check_failed(v) for v in nest.flatten(values)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_util.py", line 277, in <listcomp>
    if isinstance(v, ops.Tensor)]
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_util.py", line 248, in _check_failed
    raise ValueError(v)
ValueError: Tensor("Iterator_1/Identity:0", shape=(?,), dtype=int64, device=/job:chief/task:0/device:GPU:0)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "demo.py", line 332, in <module>
    app.run(runner)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "demo.py", line 213, in runner
    features, labels = datasource.iter.get_next()
  File "/usr/local/lib/python3.6/dist-packages/hybridbackend/tensorflow/data/iterators.py", line 120, in get_next
    DataSyncRewriting.accept(should_stop)
  File "/usr/local/lib/python3.6/dist-packages/hybridbackend/tensorflow/data/iterators.py", line 169, in accept
    should_stop = math_ops.cast(should_stop, dtypes.int32)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/math_ops.py", line 702, in cast
    x = ops.convert_to_tensor(x, name="x")
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1184, in convert_to_tensor
    return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1242, in convert_to_tensor_v2
    as_ref=False)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1297, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 286, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 227, in constant
    allow_broadcast=True)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 265, in _constant_impl
    allow_broadcast=allow_broadcast))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_util.py", line 449, in make_tensor_proto
    _AssertCompatible(values, dtype)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_util.py", line 328, in _AssertCompatible
    raise TypeError("List of Tensors when single Tensor expected")
TypeError: List of Tensors when single Tensor expected
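For reference, this TypeError can be reproduced in isolation: when tf.cast receives a nested structure (for example a dict inside a list) instead of a single tensor, TF 1.x autopacking gives up and conversion falls through to the constant converter, which raises exactly this error chain. A minimal sketch (my own illustration for triage, not HybridBackend code):

import tensorflow as tf  # TF 1.x graph mode (DeepRec)

t = tf.placeholder(tf.int64, shape=[None])
# Autopacking cannot handle a dict nested inside a list, so conversion
# falls through to the constant converter and reproduces the chain above:
#   ValueError: Tensor("Placeholder:0", shape=(?,), dtype=int64)
#   TypeError: List of Tensors when single Tensor expected
tf.cast([{"uid": t}], tf.int32)

This matches the traceback: should_stop apparently reaches DataSyncRewriting.accept() as a structure of tensors rather than as a single scalar tensor.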

Expected behavior

The iterator should be created and consumed inside the hybridBackend scope without errors, the same as when the pipeline is not wrapped.

System information

  • GPU model and memory: Tesla P100
  • OS Platform: Ubuntu 18.04
  • Docker version: Docker Engine - Community Version: 20.10.14
  • GCC/CUDA/cuDNN version:
  • Python/conda version: Python 3.6.9
  • TensorFlow/PyTorch version: TensorFlow (DeepRec 2208)

Code to reproduce

import numpy as np
import pandas as pd

# Build a small Parquet file with the three columns used by the pipeline below.
new_dtypes = {"uid": np.int64, "packagename": np.int64, "label_play": np.float64}

train_df = pd.DataFrame(np.random.randint(0, 100, (5, 3)), columns=['uid', 'packagename', 'label_play'])
train_df = train_df.astype(new_dtypes)
train_df.to_parquet('train.parquet')

import tensorflow as tf
import hybridbackend.tensorflow as hb
from hybridbackend.tensorflow.data import ParquetDataset
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.experimental.ops.dataframe import to_sparse



def parquet_map(record):
    # Flatten each column to a dense 1-D tensor and split out the label.
    for key in record:
        record[key] = tf.reshape(record[key], [-1])
    label = record.pop("label_play")
    return record, label


# Create model
def neural_net(features):
    with tf.device("/CPU:0"):
        var = tf.get_embedding_variable(
            "var_0",
            embedding_dim=3,
            initializer=tf.ones_initializer(tf.float32),
            partitioner=tf.fixed_size_partitioner(num_shards=4),
        )

    emb = tf.nn.embedding_lookup(var, features["uid"])
    fun = tf.multiply(emb, 2.0, name="multiply")
    loss = tf.reduce_sum(fun, name="reduce_sum")
    opt = tf.train.AdagradOptimizer(0.1)

    g_v = opt.compute_gradients(loss)
    train_op = opt.apply_gradients(g_v)
    return train_op, loss


# Wrap the entire pipeline (data, model, session) in the HybridBackend scope;
# this configuration triggers the error above.
with hb.scope():
    with tf.device("/cpu:0"):
        dataset = tf.data.Dataset.list_files(["train.parquet"])
        dataset = dataset.apply(
            tf.data.experimental.parallel_interleave(
                lambda tmp_file: ParquetDataset(
                    tmp_file,
                    drop_remainder=True,
                    batch_size=2,
                    num_parallel_reads=1,
                    fields=[
                        hb.data.DataFrame.Field("uid", tf.int64, ragged_rank=0),
                        hb.data.DataFrame.Field("packagename", tf.int64, ragged_rank=0),
                        hb.data.DataFrame.Field("label_play", tf.float64, ragged_rank=0),
                    ],
                ).apply(
                    to_sparse()
                ),
                cycle_length=1,
                block_length=1,
            )
        )
        dataset = dataset.batch(2, drop_remainder=True).map(
            map_func=parquet_map,
            num_parallel_calls=dataset_ops.AUTOTUNE,
        )
    
    iterator = dataset.make_one_shot_iterator()
    # iterator = tf.data.make_one_shot_iterator(dataset)
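    # get_next() below is where the TypeError in the log above is raised
    # (hybridbackend .../data/iterators.py, DataSyncRewriting.accept).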
    features, labels = iterator.get_next()

    train_op, loss = neural_net(features)

    scaffold = tf.train.Scaffold(
        init_op=tf.group(
            tf.global_variables_initializer(),
        ),
    )

    with tf.train.MonitoredTrainingSession(
        master="", scaffold=scaffold) as mon_sess:
        while not mon_sess.should_stop():
            _, ev = mon_sess.run([train_op, loss])
            print(ev)
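
As a sanity check (my own addition, not part of the original script), the generated train.parquet can be inspected independently of the TensorFlow pipeline to rule out a malformed input file:

import pandas as pd

df = pd.read_parquet('train.parquet')
print(df.dtypes)  # uid: int64, packagename: int64, label_play: float64
print(df.shape)   # (5, 3)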

Willing to contribute

Yes
