diff --git a/docs/Fused-Embedding.md b/docs/Fused-Embedding.md
index f1a2a9eacd5..7bbb6f95fab 100644
--- a/docs/Fused-Embedding.md
+++ b/docs/Fused-Embedding.md
@@ -2,32 +2,39 @@
## Introduction
-The embedding lookup APIs native to DeepRec and TensorFlow, such as safe_embedding_lookup_sparse, create a fairly large number of ops, so execution on GPU easily becomes kernel-launch bound. The Embedding Subgraph Fusion feature therefore provides a set of interfaces together with a set of fusion ops; the fused ops reduce the number of kernels that have to be launched and come with high-performance implementations, which speeds up execution on GPU.
+The embedding lookup APIs native to DeepRec and TensorFlow, such as safe_embedding_lookup_sparse, create a fairly large number of ops, so execution on GPU easily becomes kernel-launch bound; in addition, some of these ops only have CPU implementations and are comparatively slow. The Embedding Subgraph Fusion feature therefore provides a set of interfaces together with a set of fusion ops; the fused ops reduce the number of kernels that have to be launched and come with high-performance implementations, which speeds up execution.
+
+
## FeatureColumn Interface
+
FeatureColumn is used as the user-facing interface. embedding_column returns an instance of an EmbeddingColumn class; the commonly used EmbeddingColumn classes are:
-1. `EmbeddingColumn` in `tensorflow/python/feature_column/feature_column_v2.py`
-1. `_EmbeddingColumn` in `tensorflow/contrib/layers/python/layers/feature_column.py`
+1. `EmbeddingColumn` in `tensorflow/python/feature_column/feature_column_v2.py`
+2. `_EmbeddingColumn` in `tensorflow/contrib/layers/python/layers/feature_column.py`

The instance is then usually passed to a high-level interface such as `tf.feature_column.input_layer` or `tf.feature_column_ops.input_from_feature_columns`, which builds the lookup-related computation graph.

-Therefore, the Embedding Subgraph Fusion feature adds a `do_fusion` attribute to the `EmbeddingColumn` classes above. It defaults to `False`; users can explicitly set it to `True` to make the embedding lookup use the fused ops.
+Therefore, the Embedding Subgraph Fusion feature adds a `do_fusion` attribute to the `EmbeddingColumn` classes above. It defaults to None; users can explicitly set it to a fusion version such as `'v1'` or `'v2'` to make the embedding lookup use the fused ops.

For example:
+
+a. tf.feature_column.embedding_column
+
```python
import tensorflow as tf
from tensorflow.python.framework import ops

-columns = tf.feature_column.categorical_column_with_embedding("col_emb", dtype=tf.dtypes.int64)
-W = tf.feature_column.embedding_column(categorical_column=columns,
+column = tf.feature_column.categorical_column_with_embedding("col_emb", dtype=tf.dtypes.int64)
+W = tf.feature_column.embedding_column(
+    categorical_column=column,
                                       dimension=3,
                                       initializer=tf.ones_initializer(tf.dtypes.float32),
-                                       do_fusion=True)
+                                       do_fusion='v2')

ids={}
ids["col_emb"] = tf.SparseTensor(indices=[[0,0],[1,1],[2,2],[3,3],[4,4]], values=tf.cast([1,2,3,4,5], tf.dtypes.int64), dense_shape=[5, 4])

-# Pass in the EmbeddingColumn instance that has use_fused_lookup set
+# Pass in the EmbeddingColumn instance that has do_fusion set
emb = tf.feature_column.input_layer(ids, [W])
fun = tf.multiply(emb, 2.0, name='multiply')
loss = tf.reduce_sum(fun, name='reduce_sum')
@@ -43,6 +50,9 @@ with tf.Session() as sess:
  print(sess.run([emb, train_op,loss]))
  print(sess.run([emb, train_op,loss]))
```
+
+b. tf.contrib.layers.python.layers.feature_column.embedding_column
+
```python
import tensorflow as tf
from tensorflow.python.framework import ops
@@ -54,7 +64,7 @@ columns = feature_column.sparse_column_with_embedding(column_name="col_emb", dty
W = feature_column.embedding_column(sparse_id_column=columns,
                                    dimension=3,
                                    initializer=tf.ones_initializer(tf.dtypes.float32),
-                                    do_fusion=True)
+                                    do_fusion='v2')

ids={}

@@ -87,12 +97,20 @@ def fused_safe_embedding_lookup_sparse(embedding_weights,
                                       name=None,
                                       partition_strategy="div",
                                       max_norm=None,
-                                       prune=True):
+                                       prune=True,
+                                       blocknums=None,
+                                       fusion_version='v2'):
```
This interface is functionally identical to DeepRec's `safe_embedding_lookup_sparse`, so its parameters are not described again here; see the related documentation.
+
+
## fused_embedding_lookup_sparse Interface
+
+### Using the v1 version
+
Via `nn.fused_embedding_lookup_sparse`
```python
+@tf_export(v1=["nn.fused_embedding_lookup_sparse"])
def fused_embedding_lookup_sparse(params,
                                  sp_ids,
                                  sparse_weights=None,
@@ -102,73 +120,79 @@ def fused_embedding_lookup_sparse(params,
                                  max_norm=None,
                                  default_id=None,
                                  prune_invalid_ids=False,
+                                  fill_empty_row=True,
                                  blocknums=None):
```
+### Using the v2 version
+
+Via `nn.fused_embedding_lookup_sparse_v2`
+```python
+@tf_export(v1=["nn.fused_embedding_lookup_sparse_v2"])
+def fused_embedding_lookup_sparse_v2(params,
+                                     sp_ids,
+                                     sparse_weights=None,
+                                     partition_strategy=None,
+                                     name=None,
+                                     combiner=None,
+                                     max_norm=None,
+                                     default_id=None,
+                                     prune=False,
+                                     fill_empty_row=True,
+                                     blocknums=None):
+```
+
+### Parameters
+
- `params`: a list that can hold a single embedding tensor or several partitioned embedding tensors. The embedding tensors must all have rank 2.
- `sp_ids`: a SparseTensor whose values are the ids to look up. Its indices must have rank 2, and its dense_shape must have rank 1 with 2 elements.
-- `sparse_weights`: weights for the values of sparse_ids.
+- `sparse_weights`: weights for the values of sparse_ids. Not supported yet.
- `partition_strategy`: the partition strategy of the embedding tensors.
- `name`: the name of this operation.
- `combiner`: the strategy used to combine the entries of each row.
- `max_norm`: if not None, the l2 norm of every embedding vector is computed, and vectors whose norm exceeds max_norm are normalized.
-- `default_id`: empty rows are filled with default_id. If default_id is None, 0 is filled in by default.
-- `prune_invalid_ids`: whether to remove invalid values (id < 0) from sparse_ids.
+- `default_id`: if `fill_empty_row=True`, empty rows are filled with default_id. If default_id is None, 0 is filled in by default.
+- `fill_empty_row`: whether to fill the empty rows of sparse_ids; used together with `default_id`.
+- `prune_invalid_ids` or `prune`: whether to remove invalid values (id < 0) from sparse_ids.
- `blocknums`: parameter used by DynamicEmbeddingVariable.
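+
+Below is a minimal usage sketch of the low-level interface. It is illustrative only: it assumes a DeepRec build that exports `tf.nn.fused_embedding_lookup_sparse_v2` as shown above, the table shape and id values are made up, the result is assumed to be the combined embedding tensor of shape [batch_size, dimension], and (per the notes below) the v2 ops need to run on a GPU device:
+
+```python
+import tensorflow as tf
+
+# A single, unpartitioned embedding table (rank-2, as required for `params`).
+emb_table = tf.get_variable("emb_table", shape=[16, 8],
+                            initializer=tf.ones_initializer(tf.dtypes.float32))
+
+# Batch of 3 rows; row 2 has no ids, and one id is negative (invalid).
+sp_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
+                         values=tf.cast([3, -1, 7], tf.dtypes.int64),
+                         dense_shape=[3, 4])
+
+# prune=True drops the invalid id; fill_empty_row/default_id pad the empty row.
+emb = tf.nn.fused_embedding_lookup_sparse_v2(
+    [emb_table], sp_ids,
+    combiner="mean", max_norm=None,
+    default_id=0, prune=True, fill_empty_row=True)
+
+with tf.Session() as sess:
+    sess.run(tf.global_variables_initializer())
+    print(sess.run(emb))   # assumed result: combined embeddings of shape [3, 8]
+```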
+
+
## Notes
+1. `v2` currently only has a GPU implementation.
+2. `v2` currently supports the `sparse_weights` feature; `v1` does not support it yet.
+3. Dynamic-dimension embedding, Multi-Hash Variable and AdaptiveEmbedding are not supported yet; support will be added gradually.
+4. When using GPU fusion, consider `export TF_GPU_THREAD_MODE="gpu_private"` and `export TF_GPU_THREAD_COUNT=1`. Tests show that when the number of features is large, having the GPU launch kernels from a single thread incurs less overhead and helps speed things up further.
+
-1. Embedding Subgraph Fusion currently supports execution on Nvidia GPUs. The corresponding `tf.Variable`, `EmbeddingVariable` and other ops can be placed on CPU. A CPU version of the Embedding Subgraph Fusion feature is under development.
-1. Setting the weights `sparse_weights` is not supported yet.
-1. partition_strategy currently only supports div, and the embedding tensor is split along axis = 0. If the embedding tensor is an EmbeddingVariable, it can currently only be a single complete ev; lookup over a partitioned ev is not supported yet.
-1. Dynamic-dimension embedding, Multi-Hash Variable and AdaptiveEmbedding are not supported yet; support will be added gradually.
-## Op Introduction and Computation Graph
-The following Fused Embedding ops have been added:
+## Op Introduction
+
+### Fused Embedding V1 ops:
1. FusedEmbeddingSparsePreLookUp
2. FusedEmbeddingSparsePostLookUp
3. FusedEmbeddingSparsePostLookUpGrad
+FusedEmbeddingSparsePreLookUp is mainly responsible for filling empty rows, pruning invalid ids, and partitioning the values and indices of sp_ids according to partition_strategy.
+tf.Gather lives on the same device as the EmbeddingVariable or tf.Variable; when the variable is partitioned there may be several copies on different devices (distributed). It receives the values and indices partitioned by FusedEmbeddingSparsePreLookUp and performs the actual embedding vector lookup.
+FusedEmbeddingSparsePostLookUp then collects the embedding vectors back from every partition and applies the combiner, max_norm and related operations.
+FusedEmbeddingSparsePostLookUpGrad computes the backward gradients for FusedEmbeddingSparsePostLookUp.
+### Fused Embedding V2 ops:
-Taking the low-level interface `fused_embedding_lookup_sparse` as an example, calling it creates the following computation graph:
-![img_1.png](Fused-Embedding/img_1.png)
-
-1. **FusedEmbeddingSparsePreLookUp** is mainly responsible for filling empty rows, pruning invalid ids, and partitioning the values and indices of sp_ids according to partition_strategy.
+1. PruneInvalidAndFillEmptyRows
+2. UniqueWithCountsV3
+3. PartitionWithPermutation
+4. FusedEmbeddingSparsePostLookUpV2
+5. FusedEmbeddingSparsePostLookUpV2Grad
-2. **tf.Gather** lives on the same device as the **EmbeddingVariable** or **tf.Variable**; when the variable is partitioned there may be several copies on different devices (distributed). It receives the values and indices partitioned by FusedEmbeddingSparsePreLookUp and performs the actual embedding vector lookup.
+Calling `fused_embedding_lookup_sparse_v2` creates the computation graph in the following order:
-3. **FusedEmbeddingSparsePostLookUp** then collects the embedding vectors back from every partition and applies the combiner, max_norm and related operations.
-
-4. **FusedEmbeddingSparsePostLookUpGrad** computes the backward gradients for FusedEmbeddingSparsePostLookUp.
+1. PruneInvalidAndFillEmptyRows removes invalid values and fills empty rows
+2. UniqueWithCountsV3 performs a unique operation on sparse_ids, which reduces communication volume in multi-node, multi-GPU setups
+3. PartitionWithPermutation is created when sparse_ids needs to be partitioned; it partitions the ids according to the chosen strategy
+4. **tf.Gather** lives on the same device as the **EmbeddingVariable** or **tf.Variable**; when the variable is partitioned there may be several copies on different devices (distributed). It performs the actual embedding vector lookup.
+5. **FusedEmbeddingSparsePostLookUpV2** then collects the embedding vectors back from every partition and applies the combiner, max_norm and related operations.
+6. **FusedEmbeddingSparsePostLookUpV2Grad** computes the backward gradients for FusedEmbeddingSparsePostLookUpV2.
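+
+For intuition, the following is a NumPy-style sketch of the semantics this pipeline implements for a single, unpartitioned shard. It is a reference description only, not the actual GPU kernels; the function name and helper logic are hypothetical, and max_norm clipping is omitted:
+
+```python
+import numpy as np
+
+def fused_lookup_v2_reference(emb_table, ids, row_ids, batch_size,
+                              combiner="mean", default_id=0):
+    """What the fused v2 graph computes, written out in plain NumPy."""
+    ids, row_ids = np.asarray(ids), np.asarray(row_ids)
+    # 1. PruneInvalidAndFillEmptyRows: drop ids < 0, then give empty rows default_id.
+    keep = ids >= 0
+    ids, row_ids = ids[keep], row_ids[keep]
+    for row in range(batch_size):
+        if not np.any(row_ids == row):
+            ids = np.append(ids, default_id)
+            row_ids = np.append(row_ids, row)
+    # 2. UniqueWithCountsV3: each distinct id is looked up only once.
+    uniq, inverse = np.unique(ids, return_inverse=True)
+    # 3. PartitionWithPermutation is skipped: there is only one partition here.
+    # 4. Gather: fetch the embedding rows for the unique ids.
+    vectors = emb_table[uniq][inverse]          # one row back per (pruned/filled) id
+    # 5. FusedEmbeddingSparsePostLookUpV2: combine the rows of each batch entry.
+    out = np.zeros((batch_size, emb_table.shape[1]), dtype=emb_table.dtype)
+    for row in range(batch_size):
+        rows = vectors[row_ids == row]
+        if combiner == "sum":
+            out[row] = rows.sum(axis=0)
+        elif combiner == "sqrtn":
+            out[row] = rows.sum(axis=0) / np.sqrt(len(rows))
+        else:  # "mean"
+            out[row] = rows.mean(axis=0)
+    return out
+```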
## Performance Comparison
-In the modelzoo, several models were compared with unfused vs. fused embedding to measure the performance gain (average over 5000 iterations)
-
-Machine:
-8 cores AMD EPYC 7232P CPU @ 3.20GHz.
-
-A100-80GB-PCIE GPU
-
-DLRM Model:
-
-| | Avg Time per Iteration |
-| ------- | ---------------------- |
-| Unfused | 20.78 ms |
-| Fused | 17.41 ms |
-| SpeedUp | 1.19x |
-
-DeepFM Model:
-
-| | Avg Time per Iteration |
-| ------- | ---------------------- |
-| Unfused | 37.24 ms |
-| Fused | 30.98 ms |
-| SpeedUp | 1.20x |
-
-WDL Model:
-
-| | Avg Time per Iteration |
-| ------- | ---------------------- |
-| Unfused | 36.38 ms |
-| Fused | 34.52 ms |
-| SpeedUp | 1.05x |
+For the GPU-related v2 ops, see the benchmark data under `modelzoo/features/GPUFusedEmbedding`
\ No newline at end of file
diff --git a/docs/Fused-Embedding/img_1.png b/docs/Fused-Embedding/img_1.png
deleted file mode 100644
index c2e39187dd2..00000000000
Binary files a/docs/Fused-Embedding/img_1.png and /dev/null differ
diff --git a/modelzoo/features/gpu_fused_embedding/.gitignore b/modelzoo/features/gpu_fused_embedding/.gitignore
index 69ebd9d97be..8efe9c0774e 100644
--- a/modelzoo/features/gpu_fused_embedding/.gitignore
+++ b/modelzoo/features/gpu_fused_embedding/.gitignore
@@ -2,4 +2,5 @@
*/result/model_*
record.py
*.sh
-*.nsys-rep
\ No newline at end of file
+*.nsys-rep
+*.gz
\ No newline at end of file
diff --git a/modelzoo/features/gpu_fused_embedding/deepfm/README.md b/modelzoo/features/gpu_fused_embedding/deepfm/README.md
index edb3dc54d6a..f43e3271fc0 100644
--- a/modelzoo/features/gpu_fused_embedding/deepfm/README.md
+++ b/modelzoo/features/gpu_fused_embedding/deepfm/README.md
@@ -7,15 +7,21 @@ The only difference is that this model use GPU Fused Embedding to acclerate the
```python
categorical_embedding_column = tf.feature_column.embedding_column(
    categorical_column, dimension=16, combiner='mean',
-    do_fusion=True)
+    do_fusion='v2')
```

## Benchmark

On A100-80GB-PCIE GPU, with 8 cores AMD EPYC 7232P CPU @ 3.20GHz. Average of 5000 iterations. The perf boost:

-| | Avg Time per Iteration |
-| ------- | ---------------------- |
-| Unfused | 37.24 ms |
-| Fused | 30.98 ms |
-| SpeedUp | 1.20x |
+Let tensorflow use private single thread for GPU kernels:
+
+```bash
+export TF_GPU_THREAD_MODE="gpu_private"
+export TF_GPU_THREAD_COUNT=1
+```
+
+| | Unfused | Fused | Speedup |
+| ---------------------------- | ------- | ------- | ------- |
+| Step Time, Batch Size = 512 | 31.2ms | 24.1ms | 1.29x |
+| Step Time, Batch Size = 4096 | 57.1ms | 44.0ms | 1.29x |
diff --git a/modelzoo/features/gpu_fused_embedding/deepfm/train.py b/modelzoo/features/gpu_fused_embedding/deepfm/train.py
index 043021524c3..1d14f664366 100644
--- a/modelzoo/features/gpu_fused_embedding/deepfm/train.py
+++ b/modelzoo/features/gpu_fused_embedding/deepfm/train.py
@@ -95,7 +95,7 @@ def build_feature_cols():
            categorical_embedding_column = tf.feature_column.embedding_column(
                categorical_column, dimension=16, combiner='mean',
-                do_fusion=True)
+                do_fusion='v2')

            wide_column.append(categorical_embedding_column)
            deep_column.append(categorical_embedding_column)
diff --git a/modelzoo/features/gpu_fused_embedding/dlrm/README.md b/modelzoo/features/gpu_fused_embedding/dlrm/README.md
index 39c79aab534..c65b0930e88 100644
--- a/modelzoo/features/gpu_fused_embedding/dlrm/README.md
+++ b/modelzoo/features/gpu_fused_embedding/dlrm/README.md
@@ -7,15 +7,22 @@ The only difference is that this model use GPU Fused Embedding to acclerate the
```python
categorical_embedding_column = tf.feature_column.embedding_column(
    categorical_column, dimension=16, combiner='mean',
-    do_fusion=True)
+    do_fusion='v2')
```

## Benchmark

-On A100-80GB-PCIE GPU, with 8 cores AMD EPYC 7232P CPU @ 3.20GHz. Average of 5000 iterations. The perf boost:
+On A100-80GB-PCIE GPU, with 8 cores AMD EPYC 7232P CPU @ 3.20GHz.
+Let tensorflow use private single thread for GPU kernels:
-| | Avg Time per Iteration |
-| ------- | ---------------------- |
-| Unfused | 20.78 ms |
-| Fused | 17.41 ms |
-| SpeedUp | 1.19x |
+```bash
+export TF_GPU_THREAD_MODE="gpu_private"
+export TF_GPU_THREAD_COUNT=1
+```
+
+The perf boost:
+
+| | Unfused | Fused | Speedup |
+| ---------------------------- | ------- | ------- | ------- |
+| Step Time, Batch Size = 512 | 19.98ms | 14.81ms | 1.34x |
+| Step Time, Batch Size = 4096 | 37.82ms | 28.82ms | 1.31x |
\ No newline at end of file
diff --git a/modelzoo/features/gpu_fused_embedding/dlrm/train.py b/modelzoo/features/gpu_fused_embedding/dlrm/train.py
index b78b9a8c911..51a196d6763 100644
--- a/modelzoo/features/gpu_fused_embedding/dlrm/train.py
+++ b/modelzoo/features/gpu_fused_embedding/dlrm/train.py
@@ -96,7 +96,7 @@ def build_feature_cols():
                tf.feature_column.embedding_column(categorical_column,
                                                   dimension=16,
                                                   combiner='mean',
-                                                   do_fusion=True))
+                                                   do_fusion='v2'))
            else:
                column = tf.feature_column.numeric_column(column_name, shape=(1, ))
                dense_column.append(column)
@@ -288,7 +288,7 @@ def optimizer(self):
        tf.summary.scalar('loss', loss)

        self.global_step = tf.train.get_or_create_global_step()
-        optimizer = tf.train.GradientDescentOptimizer(
+        optimizer = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate)

        train_op = optimizer.minimize(loss, global_step=self.global_step)
@@ -619,4 +619,4 @@ def main(tf_config=None, server=None):
                 server=server)
    else:
        print("Task type or index error.")
-        sys.exit()
\ No newline at end of file
+        sys.exit()
diff --git a/modelzoo/features/gpu_fused_embedding/wide_and_deep/README.md b/modelzoo/features/gpu_fused_embedding/wide_and_deep/README.md
index df91b6533ad..ee0b9af763a 100755
--- a/modelzoo/features/gpu_fused_embedding/wide_and_deep/README.md
+++ b/modelzoo/features/gpu_fused_embedding/wide_and_deep/README.md
@@ -8,15 +8,21 @@ The only difference is that this model use GPU Fused Embedding to acclerate the
```python
deep_columns.append(tf.feature_column.embedding_column(
    categorical_column, dimension=EMBEDDING_DIMENSIONS[column_name],
-    combiner='mean', do_fusion=True))
+    combiner='mean', do_fusion='v2'))
```

## Benchmark

On A100-80GB-PCIE GPU, with 8 cores AMD EPYC 7232P CPU @ 3.20GHz. Average of 5000 iterations. The perf boost:
-| | Avg Time per Iteration |
-| ------- | ---------------------- |
-| Unfused | 36.38 ms |
-| Fused | 34.52 ms |
-| SpeedUp | 1.05x |
\ No newline at end of file
+Let tensorflow use private single thread for GPU kernels:
+
+```bash
+export TF_GPU_THREAD_MODE="gpu_private"
+export TF_GPU_THREAD_COUNT=1
+```
+
+| | Unfused | Fused | Speedup |
+| ---------------------------- | ------- | ------- | ------- |
+| Step Time, Batch Size = 512 | 41.3ms | 38.4ms | 1.07x |
+| Step Time, Batch Size = 4096 | 75.1ms | 66.5ms | 1.12x |
\ No newline at end of file
diff --git a/modelzoo/features/gpu_fused_embedding/wide_and_deep/train.py b/modelzoo/features/gpu_fused_embedding/wide_and_deep/train.py
index b84e48ac10d..9d9c9bfac64 100644
--- a/modelzoo/features/gpu_fused_embedding/wide_and_deep/train.py
+++ b/modelzoo/features/gpu_fused_embedding/wide_and_deep/train.py
@@ -163,7 +163,7 @@ def minmaxscaler(col):
            deep_columns.append(tf.feature_column.embedding_column(
                categorical_column, dimension=EMBEDDING_DIMENSIONS[column_name],
-                combiner='mean', do_fusion=True))
+                combiner='mean', do_fusion='v2'))
        else:
            normalizer_fn = None
            i = CONTINUOUS_COLUMNS.index(column_name)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 34ea1401952..af14ab2f414 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -39,6 +39,7 @@
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.ops import fused_embedding_ops
+from tensorflow.python.ops import fused_embedding_ops_v2

__all__ = [
    "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
@@ -196,11 +197,12 @@ def fused_safe_embedding_lookup_sparse(embedding_weights,
                                       name=None,
                                       partition_strategy="div",
                                       max_norm=None,
-                                       blocknums=None):
+                                       blocknums=None,
+                                       fusion_version='v2'):
  """Functionally the same as safe_embedding_lookup_sparse but using fused
  embedding lookup ops in this method.
""" - logging.info("Is using fused embedding lookup for this scope {}".format(name)) + logging.info("Is using fused embedding lookup {} for this scope {}".format(fusion_version, name)) if combiner is None: logging.warn("The default value of combiner will change from \"mean\" " @@ -239,18 +241,33 @@ def fused_safe_embedding_lookup_sparse(embedding_weights, array_ops.gather(original_shape, original_rank - 1) ]) - result = fused_embedding_ops.fused_embedding_lookup_sparse( - embedding_weights, - sparse_ids, - sparse_weights=sparse_weights, - partition_strategy=partition_strategy, - name=name, - combiner=combiner, - max_norm=max_norm, - default_id=default_id, - prune_invalid_ids=True, - blocknums=blocknums - ) + assert(fusion_version in ['v1', 'v2']) + if fusion_version == 'v1': + result = fused_embedding_ops.fused_embedding_lookup_sparse( + embedding_weights, + sparse_ids, + sparse_weights=sparse_weights, + partition_strategy=partition_strategy, + name=None if default_id is None else scope, + combiner=combiner, + max_norm=max_norm, + default_id=default_id, + prune_invalid_ids=True, + blocknums=blocknums + ) + else: + result = fused_embedding_ops_v2.fused_embedding_lookup_sparse_v2( + embedding_weights, + sparse_ids, + sparse_weights=sparse_weights, + partition_strategy=partition_strategy, + name=None if default_id is None else scope, + combiner=combiner, + max_norm=max_norm, + default_id=default_id, + prune=True, + blocknums=blocknums + ) # Reshape back from linear ids back into higher-dimensional dense result. final_result = array_ops.reshape( diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 650fdcbeb2f..1417f9bf9f2 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -221,7 +221,7 @@ def __new__(cls, *args, **kwargs): if 'ev_option' not in kwargs: kwargs['ev_option'] = variables.EmbeddingVariableOption() if 'do_fusion' not in kwargs: - kwargs['do_fusion'] = False + kwargs['do_fusion'] = None return super(_DeepEmbeddingLookupArguments, cls).__new__( cls, *args, **kwargs) @@ -1189,7 +1189,7 @@ def __new__(cls, shared_vocab_size=None, max_norm=None, trainable=True, - do_fusion=False): + do_fusion=None): if initializer is not None and not callable(initializer): raise ValueError("initializer must be callable if specified. " "Embedding of column_name: {}".format( @@ -1333,7 +1333,7 @@ def _embeddings_from_arguments(column, # This option is only enabled for scattered_embedding_column. if args.hash_key: if args.do_fusion: - raise ValueError("Both do_fusion and hash_key is set. Not support yet.") + raise ValueError("Both do_fusion and hash_key is set. Embedding fusion not support hash_key yet.") embeddings = contrib_variables.model_variable( name="weights", @@ -1439,8 +1439,9 @@ def _embeddings_from_arguments(column, input_tensor, sparse_weights=weight_tensor, combiner=args.combiner, - name=column.name + "weights", - max_norm=args.max_norm + name=column.name + "_weights", + max_norm=args.max_norm, + fusion_version=args.do_fusion ) else: return embedding_ops.safe_embedding_lookup_sparse( @@ -1483,7 +1484,7 @@ def embedding_column(sparse_id_column, tensor_name_in_ckpt=None, max_norm=None, trainable=True, - do_fusion=False): + do_fusion=None): """Creates an `_EmbeddingColumn` for feeding sparse data into a DNN. 
Args: diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index e87df25f52b..5314d9ae9f1 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -132,9 +132,6 @@ def _input_from_feature_columns(columns_to_tensors, arguments = column._deep_embedding_lookup_arguments( transformed_tensor) if isinstance(column, fc._DynamicDimensionEmbeddingColumn): - if arguments.do_fusion: - raise ValueError("do_fusion is set but feature column is a _DynamicDimensionEmbeddingColumn." - "Not support yet.") output = fc._dynamic_dimension_embeddings_from_arguments( # pylint: disable=protected-access column, diff --git a/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingLocalSparseLookUp.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingLocalSparseLookUp.pbtxt deleted file mode 100644 index 109f271e105..00000000000 --- a/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingLocalSparseLookUp.pbtxt +++ /dev/null @@ -1,3 +0,0 @@ -op { - graph_op_name: "FusedEmbeddingLocalSparseLookUp" -} diff --git a/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingLocalSparseLookUpGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingLocalSparseLookUpGrad.pbtxt deleted file mode 100644 index 58e275b0a55..00000000000 --- a/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingLocalSparseLookUpGrad.pbtxt +++ /dev/null @@ -1,3 +0,0 @@ -op { - graph_op_name: "FusedEmbeddingLocalSparseLookUpGrad" -} diff --git a/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingSparsePostLookUpV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingSparsePostLookUpV2.pbtxt new file mode 100644 index 00000000000..27125f32ec8 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingSparsePostLookUpV2.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "FusedEmbeddingSparsePostLookUpV2" +} diff --git a/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingSparsePostLookUpV2Grad.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingSparsePostLookUpV2Grad.pbtxt new file mode 100644 index 00000000000..99caf561ce0 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_FusedEmbeddingSparsePostLookUpV2Grad.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "FusedEmbeddingSparsePostLookUpV2Grad" +} diff --git a/tensorflow/core/api_def/base_api/api_def_PartitionWithPermutation.pbtxt b/tensorflow/core/api_def/base_api/api_def_PartitionWithPermutation.pbtxt new file mode 100644 index 00000000000..c28c7e4a1ce --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_PartitionWithPermutation.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "PartitionWithPermutation" +} diff --git a/tensorflow/core/api_def/base_api/api_def_PruneInvalidAndFillEmptyRows.pbtxt b/tensorflow/core/api_def/base_api/api_def_PruneInvalidAndFillEmptyRows.pbtxt new file mode 100644 index 00000000000..b541397ffbc --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_PruneInvalidAndFillEmptyRows.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "PruneInvalidAndFillEmptyRows" +} diff --git a/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV3.pbtxt new file mode 100644 index 00000000000..1061545ad7a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UniqueWithCountsV3.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "UniqueWithCountsV3" +} diff --git 
a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 1ce4d7db9a6..6c671c13c8c 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1950,12 +1950,13 @@ tf_cuda_cc_test( ) tf_cuda_cc_test( - name = "fused_embedding_ops_test", + name = "fused_embedding_ops_gpu_test", size = "small", - srcs = ["fused_embedding/fused_embedding_local_ops_test.cc", - "fused_embedding/fused_embedding_pre_ops_test.cc", - "fused_embedding/fused_embedding_post_ops_test.cc", - "fused_embedding/fused_embedding_post_grad_ops_test.cc"], + srcs = ["fused_embedding/gpu/tests/prune_invalid_and_fill_empty_rows_ops_test.cc", + "fused_embedding/gpu/tests/unique_with_count_v3_ops_test.cc", + "fused_embedding/gpu/tests/partition_with_permutation_ops_test.cc", + "fused_embedding/gpu/tests/fused_embedding_post_v2_ops_test.cc", + "fused_embedding/gpu/tests/fused_embedding_post_v2_grad_ops_test.cc"], tags = tf_cuda_tests_tags(), deps = [ ":fused_embedding_ops", @@ -1964,10 +1965,8 @@ tf_cuda_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -5325,8 +5324,11 @@ tf_kernel_library( ) tf_cuda_library( - name = "fused_embedding_common_cuh", - hdrs = ["fused_embedding/fused_embedding_common.cu.h"], + name = "fused_embedding_cuh", + hdrs = ["fused_embedding/gpu/common.cu.h", + "fused_embedding/gpu/functions/hash_functions.cu.h", + "fused_embedding/gpu/functions/kernels.cu.h", + "fused_embedding/gpu/functions/partition_select.cu.h"], ) tf_kernel_library( @@ -5338,12 +5340,18 @@ tf_kernel_library( ], hdrs = ["fused_embedding/embedding_lookup_sparse_op.h"], gpu_srcs = [ - "fused_embedding/fused_embedding_local_ops_gpu.cu.cc", - "fused_embedding/fused_embedding_pre_ops_gpus.cu.cc", - "fused_embedding/fused_embedding_post_ops_gpus.cu.cc" - ], - deps = ["//third_party/eigen3"] + DYNAMIC_DEPS + mkl_deps() + - if_cuda(["@cub_archive//:cub", ":fused_embedding_common_cuh"]), + "fused_embedding/gpu/fused_embedding_post_v2_ops_gpus.cu.cc", + "fused_embedding/gpu/partition_with_permutation_ops.cu.cc", + "fused_embedding/gpu/prune_invalid_and_fill_empty_rows_ops.cu.cc", + "fused_embedding/gpu/unique_with_count_v3_ops.cu.cc", + "fused_embedding/gpu/functions/kernels.cu.cc", + "fused_embedding/gpu/functions/partition_select.cu.cc", + ], + deps = ["//third_party/eigen3", + "//tensorflow/core/profiler:nvtx_utils"] + + DYNAMIC_DEPS + mkl_deps() + + if_cuda(["@cub_archive//:cub", + "fused_embedding_cuh"]), ) tf_cc_test( diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h b/tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h deleted file mode 100644 index eff9e7c0782..00000000000 --- a/tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h +++ /dev/null @@ -1,80 +0,0 @@ -#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_ -#define TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_ - -#if GOOGLE_CUDA - -#define CK_CUDA_THROW_(x) \ - do { \ - cudaError_t retval = (x); \ - if (retval != cudaSuccess) { \ - throw std::runtime_error(std::string("Runtime error: ") + \ - (cudaGetErrorString(retval)) + " " + __FILE__ + \ - ":" + std::to_string(__LINE__) + " \n"); \ - } \ - } while (0) - -namespace tensorflow { - -namespace { - -inline int 
CalcBlocksLinearMapping(const int problem_size, const int threads) { - return problem_size % threads == 0 ? (problem_size / threads) - : (problem_size / threads + 1); -} - -struct IndicePair { - int64_t row_in_batch; - int64_t entry_in_column; -}; - -enum Combiner { Mean, Sum, Sqrtn }; - -template -__forceinline__ __device__ float Combine(const float in, const int feature_num); - -template <> -__forceinline__ __device__ float Combine(const float in, - const int feature_num) { - return in / sqrtf(feature_num); -} - -template <> -__forceinline__ __device__ float Combine(const float in, - const int feature_num) { - return in / feature_num; -} - -template <> -__forceinline__ __device__ float Combine(const float in, - const int feature_num) { - return in; -} - -template -__forceinline__ __device__ float CombineGrad(const float grad, - const int feature_num); - -template <> -__forceinline__ __device__ float CombineGrad(const float grad, - const int feature_num) { - return grad / sqrtf(feature_num); -} - -template <> -__forceinline__ __device__ float CombineGrad(const float grad, - const int feature_num) { - return grad / feature_num; -} - -template <> -__forceinline__ __device__ float CombineGrad(const float grad, - const int feature_num) { - return grad; -} -} // namespace - -} // namespace tensorflow - -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_FUSED_EMBEDDING_COMMON_CU_H_ \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_local_ops_gpu.cu.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_local_ops_gpu.cu.cc deleted file mode 100644 index 266b960a32c..00000000000 --- a/tensorflow/core/kernels/fused_embedding/fused_embedding_local_ops_gpu.cu.cc +++ /dev/null @@ -1,315 +0,0 @@ -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" - -#if GOOGLE_CUDA - -#define EIGEN_USE_GPU - -#include "tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h" -#include "tensorflow/core/util/gpu_kernel_helper.h" - -namespace tensorflow { -using GPUDevice = Eigen::GpuDevice; - -namespace { - -__global__ void SetToIntMaxSTG128(int* values_offset, const int batch_size) { - const int thread_offset = 4 * (blockIdx.x * blockDim.x + threadIdx.x); - const int int_max = 0x7fffffff; - if (thread_offset + 4 < batch_size) { - int4 four = make_int4(int_max, int_max, int_max, int_max); - *((int4*)(values_offset + thread_offset)) = four; - } else if (thread_offset < batch_size) { - for (int i = thread_offset; i < batch_size; i++) { - values_offset[i] = int_max; - } - } -} - -__global__ void CalcPerElementRowInBatchValuesOffset(const int64_t* indices, - int* values_offset, - const int64_t nnz) { - const int thread_offset = blockIdx.x * blockDim.x + threadIdx.x; - if (thread_offset < int(nnz)) { - const int64_t element_row = indices[2 * thread_offset]; - atomicMin(values_offset + int(element_row), thread_offset); - } -} - -template -__global__ void EmbeddingLookUp(const float* emb_variable, - const int64_t* values, const int* values_offset, - float* embedding_vector, const float max_norm, - const int emb_vec_size, - const int64_t batch_size, const int64_t nnz) { - __shared__ float l2_sum[1]; - - int value_offset = values_offset[blockIdx.x]; - int feature_num; - if (blockIdx.x == int(batch_size) - 1) { - feature_num = int(nnz) - value_offset; - } else { - feature_num = values_offset[blockIdx.x + 1] - value_offset; - } - float out = 0.0f; - for (int i = 0; i < feature_num; i++) { - float emb_element = - 
emb_variable[int(values[value_offset + i]) * emb_vec_size + - threadIdx.x]; - if (max_norm >= 0.0f) { - // calc l2 norm of this emb row(per block) and compare with max_norm. - // if greater than max_norm, then clip every element with factor - // max_norm / l2norm - if (threadIdx.x == 0) { - l2_sum[0] = 0.0f; - } - __syncthreads(); - atomicAdd(l2_sum, emb_element * emb_element); - __syncthreads(); - float l2_norm = sqrtf(l2_sum[0]); - if (l2_norm > max_norm) { - emb_element *= max_norm / l2_norm; - } - } - out += emb_element; - } - - // combine - out = Combine(out, feature_num); - - // store the embedding vector - embedding_vector[blockIdx.x * emb_vec_size + threadIdx.x] = out; -} - -template -__global__ void DoEmbeddingGrad(const float* top_grad, - const float* emb_variable, - const int64_t* values, const int* values_offset, - float* grad_values, const float max_norm, - const int emb_vec_size, - const int64_t batch_size, const int64_t nnz) { - __shared__ float l2_sum[1]; - const int value_offset = values_offset[blockIdx.x]; - int feature_num; - if (blockIdx.x == int(batch_size) - 1) { - feature_num = int(nnz) - value_offset; - } else { - feature_num = values_offset[blockIdx.x + 1] - value_offset; - } - float grad = top_grad[blockIdx.x * emb_vec_size + threadIdx.x]; - grad = CombineGrad(grad, feature_num); - for (int i = 0; i < feature_num; i++) { - float grad_i = grad; - if (max_norm > 0.0f) { - float emb_element = - emb_variable[int(values[value_offset + i]) * emb_vec_size + - threadIdx.x]; - if (threadIdx.x == 0) { - l2_sum[0] = 0.0f; - } - __syncthreads(); - atomicAdd(l2_sum, emb_element * emb_element); - __syncthreads(); - float l2_norm = sqrtf(l2_sum[0]); - if (l2_norm > max_norm) { - grad_i *= max_norm / l2_norm; - } - } - grad_values[(value_offset + i) * emb_vec_size + threadIdx.x] = grad_i; - } -} - -} // namespace - -class FusedEmbeddingLocalSparseLookUpGPU : public OpKernel { - public: - explicit FusedEmbeddingLocalSparseLookUpGPU(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner", &combiner_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_norm", &max_norm_)); - } - - void Compute(OpKernelContext* ctx) override { - auto stream = ctx->eigen_device().stream(); - - Tensor const* values_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_values", &values_tensor)); - Tensor const* indices_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_indices", &indices_tensor)); - Tensor const* dense_shape_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_dense_shape", &dense_shape_tensor)); - Tensor const* emb_variable_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("emb_variable", &emb_variable_tensor)); - - auto dense_shape = dense_shape_tensor->flat().data(); - const size_t batch_size = dense_shape[0]; - const int64 nnz = indices_tensor->shape().dim_size(0); - const int64 emb_vec_size = emb_variable_tensor->shape().dim_size(1); - - TensorShape emb_vectors_tensor_shape; - - emb_vectors_tensor_shape = TensorShape( - std::vector({static_cast(batch_size), emb_vec_size})); - Tensor* emb_vectors_tensor = nullptr; - // allocate output - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, emb_vectors_tensor_shape, - &emb_vectors_tensor)); - - // allocate offset tensor - TensorShape values_offset_tensor_shape = - TensorShape(std::vector({static_cast(batch_size)})); - - Tensor* values_offset_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(1, values_offset_tensor_shape, - &values_offset_tensor)); - - { - const int threads = 1024; - int blocks = 
batch_size / threads; - blocks = batch_size % threads == 0 ? blocks : blocks + 1; - SetToIntMaxSTG128<<>>( - values_offset_tensor->flat().data(), int(batch_size)); - } - { - const int threads = 1024; - int blocks = nnz % threads == 0 ? (nnz / threads) : (nnz / threads + 1); - - // calculate values offset - CalcPerElementRowInBatchValuesOffset<<>>( - reinterpret_cast( - indices_tensor->flat().data()), - values_offset_tensor->flat().data(), nnz); - } - { - const int blocks = int(batch_size); - const int threads = int(emb_vec_size); - if (combiner_ == "sqrtn") { - EmbeddingLookUp<<>>( - reinterpret_cast( - emb_variable_tensor->flat().data()), - reinterpret_cast( - values_tensor->flat().data()), - values_offset_tensor->flat().data(), - reinterpret_cast(emb_vectors_tensor->flat().data()), - max_norm_, int(emb_vec_size), batch_size, nnz); - } else if (combiner_ == "mean") { - EmbeddingLookUp<<>>( - reinterpret_cast( - emb_variable_tensor->flat().data()), - reinterpret_cast( - values_tensor->flat().data()), - values_offset_tensor->flat().data(), - reinterpret_cast(emb_vectors_tensor->flat().data()), - max_norm_, int(emb_vec_size), batch_size, nnz); - } else { - EmbeddingLookUp<<>>( - reinterpret_cast( - emb_variable_tensor->flat().data()), - reinterpret_cast( - values_tensor->flat().data()), - values_offset_tensor->flat().data(), - reinterpret_cast(emb_vectors_tensor->flat().data()), - max_norm_, int(emb_vec_size), batch_size, nnz); - } - } - } - - private: - std::string combiner_; - float max_norm_; -}; - -REGISTER_KERNEL_BUILDER(Name("FusedEmbeddingLocalSparseLookUp") - .Device(DEVICE_GPU) - .HostMemory("sp_dense_shape"), - FusedEmbeddingLocalSparseLookUpGPU); - -class FusedEmbeddingLocalSparseLookUpGradGPU : public OpKernel { - public: - explicit FusedEmbeddingLocalSparseLookUpGradGPU(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner", &combiner_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_norm", &max_norm_)); - } - - void Compute(OpKernelContext* ctx) override { - auto stream = ctx->eigen_device().stream(); - - Tensor const* top_grad_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("top_grad", &top_grad_tensor)); - - Tensor const* emb_variable_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("emb_variable", &emb_variable_tensor)); - Tensor const* values_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_values", &values_tensor)); - Tensor const* values_offset_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_values_offset", &values_offset_tensor)); - - const int64 emb_vec_size = top_grad_tensor->shape().dim_size(1); - const int64 batch_size = top_grad_tensor->shape().dim_size(0); - const int64 nnz = values_tensor->shape().dim_size(0); - - Tensor* grad_emb_weight_sp_values_tensor; - TensorShape grad_emb_weight_sp_values_tensor_shape = - TensorShape(std::vector({nnz, emb_vec_size})); - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, grad_emb_weight_sp_values_tensor_shape, - &grad_emb_weight_sp_values_tensor)); - - { - const int blocks = int(batch_size); - const int threads = int(emb_vec_size); - - if (combiner_ == "sqrtn") { - DoEmbeddingGrad<<>>( - reinterpret_cast( - top_grad_tensor->flat().data()), - reinterpret_cast( - emb_variable_tensor->flat().data()), - reinterpret_cast( - values_tensor->flat().data()), - values_offset_tensor->flat().data(), - reinterpret_cast( - grad_emb_weight_sp_values_tensor->flat().data()), - max_norm_, emb_vec_size, batch_size, nnz); - } else if (combiner_ == "mean") { - DoEmbeddingGrad<<>>( - reinterpret_cast( 
- top_grad_tensor->flat().data()), - reinterpret_cast( - emb_variable_tensor->flat().data()), - reinterpret_cast( - values_tensor->flat().data()), - values_offset_tensor->flat().data(), - reinterpret_cast( - grad_emb_weight_sp_values_tensor->flat().data()), - max_norm_, emb_vec_size, batch_size, nnz); - } else { - DoEmbeddingGrad<<>>( - reinterpret_cast( - top_grad_tensor->flat().data()), - reinterpret_cast( - emb_variable_tensor->flat().data()), - reinterpret_cast( - values_tensor->flat().data()), - values_offset_tensor->flat().data(), - reinterpret_cast( - grad_emb_weight_sp_values_tensor->flat().data()), - max_norm_, emb_vec_size, batch_size, nnz); - } - } - } - - private: - float max_norm_; - std::string combiner_; -}; - -REGISTER_KERNEL_BUILDER( - Name("FusedEmbeddingLocalSparseLookUpGrad").Device(DEVICE_GPU), - FusedEmbeddingLocalSparseLookUpGradGPU); - -} // namespace tensorflow -#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_local_ops_test.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_local_ops_test.cc deleted file mode 100644 index f22a7d5f8af..00000000000 --- a/tensorflow/core/kernels/fused_embedding/fused_embedding_local_ops_test.cc +++ /dev/null @@ -1,410 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/image_ops.h" -#include "tensorflow/cc/ops/nn_ops.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/conv_ops_gpu.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { -namespace { - -enum class Device { CPU, GPU }; - -enum TestCase { Sqrtn, Mean, Sum, SqrtnAndMaxNorm200, MeanAndMaxNorm100 }; - -template -void get_node_attr_from_test_case(string& combiner_str, float& max_norm) { - if (test_case == Sqrtn) { - combiner_str = "sqrtn"; - max_norm = -1.0f; - } else if (test_case == Mean) { - combiner_str = "mean"; - max_norm = -1.0f; - } else if (test_case == Sum) { - combiner_str = "sum"; - max_norm = -1.0f; - } else if (test_case == SqrtnAndMaxNorm200) { - combiner_str = "sqrtn"; - max_norm = 200.0f; - } else if (test_case == MeanAndMaxNorm100) { - combiner_str = "mean"; - max_norm = 100.0f; - } -} - -template -void fill_emb_vector_expected(Tensor* expected); - -template <> -void fill_emb_vector_expected(Tensor* expected) { - test::FillValues( - expected, {22.627416610717773, 24.0416316986084, 25.45584487915039, - 26.870058059692383, 28.284271240234375, 29.698484420776367, - 31.112699508666992, 32.526912689208984, 73.90083312988281, - 75.63288879394531, 77.36493682861328, 79.09698486328125, - 80.82904052734375, 82.56108856201172, 84.29314422607422, - 86.02519226074219, 124.70765686035156, 126.43971252441406, - 128.17176818847656, 129.90380859375, 131.6358642578125, - 133.367919921875, 135.09996032714844, 136.83201599121094, - 107.48023223876953, 108.89444732666016, 110.30866241455078, - 111.72286987304688, 113.1370849609375, 114.55130004882812, - 115.96551513671875, 117.37973022460938}); -} - -template <> -void fill_emb_vector_expected(Tensor* expected) { - test::FillValues( - expected, {16.00000000000000, 17.00000000000000, 18.00000000000000, - 19.00000000000000, 20.00000000000000, 21.00000000000000, - 22.00000000000000, 23.00000000000000, 42.66666793823242, - 43.66666793823242, 44.66666793823242, 45.66666793823242, - 46.66666793823242, 47.66666793823242, 48.66666793823242, - 49.66666793823242, 72.00000000000000, 73.00000000000000, - 74.00000000000000, 75.00000000000000, 76.00000000000000, - 77.00000000000000, 78.00000000000000, 79.00000000000000, - 76.00000000000000, 77.00000000000000, 78.00000000000000, - 79.00000000000000, 80.00000000000000, 81.00000000000000, - 82.00000000000000, 83.00000000000000}); -} - -template <> -void fill_emb_vector_expected(Tensor* expected) { - test::FillValues( - expected, {32.0, 34.0, 36.0, 38.0, 40.0, 42.0, 44.0, 46.0, - 128.0, 131.0, 134.0, 137.0, 140.0, 143.0, 146.0, 149.0, - 216.0, 219.0, 222.0, 225.0, 228.0, 231.0, 234.0, 237.0, - 152.0, 154.0, 156.0, 158.0, 160.0, 162.0, 164.0, 166.0}); -} - -template <> -void fill_emb_vector_expected(Tensor* expected) { - test::FillValues( - expected, - {22.62741661, 24.04163170, 25.45584488, 26.87005806, 28.28427124, - 29.69848442, 31.11269951, 32.52691269, 
73.90083313, 75.63288879, - 77.36493683, 79.09698486, 80.82904053, 82.56108856, 84.29314423, - 86.02519226, 92.61308289, 94.01081848, 95.40855408, 96.80628204, - 98.20401764, 99.60175323, 100.99948120, 102.39721680, 71.20205688, - 72.31395721, 73.42584991, 74.53774261, 75.64963531, 76.76153564, - 77.87342834, 78.98532867}); -} - -class FusedEmbeddingLocalSparseLookUpOpTest : public OpsTestBase { - protected: - template - void Run(Device device) { - if (device == Device::GPU) { - SetDevice(DEVICE_GPU, - std::unique_ptr(DeviceFactory::NewDevice( - "GPU", {}, "/job:a/replica:0/task:0"))); - } - DataType dtype = DataTypeToEnum::value; - std::string combiner_str; - float max_norm; - - get_node_attr_from_test_case(combiner_str, max_norm); - - TF_EXPECT_OK(NodeDefBuilder("fused_embedding_local_sparse_look_up", - "FusedEmbeddingLocalSparseLookUp") - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(dtype)) - .Attr("T", dtype) - .Attr("combiner", combiner_str) - .Attr("max_norm", max_norm) - .Finalize(node_def())); - TF_EXPECT_OK(InitOp()); - - const int nnz = 10; - const int batch_size = 4; - const int emb_vector_dim = 8; - const int entries = 8; - const int bucket_size = 16; - - Tensor sp_values(DT_INT64, {nnz}); - Tensor sp_indices(DT_INT64, {nnz, 2}); - Tensor sp_dense_shape(DT_INT64, {2}); - Tensor emb_variable(dtype, {bucket_size, emb_vector_dim}); - - test::FillValues(&sp_values, {3, 1, 4, 5, 7, 3, 12, 12, 15, 4}); - test::FillValues(&sp_indices, {0, 1, 0, 5, 1, 2, 1, 1, 1, 7, - 2, 1, 2, 4, 2, 7, 3, 0, 3, 6}); - test::FillValues(&sp_dense_shape, {batch_size, entries}); - test::FillValues( - &emb_variable, - {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, - 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, - 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, - 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, - 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, - 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, - 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, - 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, - 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, - 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, - 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, - 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, - 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0}); - - AddInputFromArray(sp_values.shape(), sp_values.flat()); - AddInputFromArray(sp_indices.shape(), sp_indices.flat()); - AddInputFromArray(sp_dense_shape.shape(), - sp_dense_shape.flat()); - AddInputFromArray(emb_variable.shape(), emb_variable.flat()); - - TF_ASSERT_OK(RunOpKernel()); - - Tensor emb_vector_expected(dtype, {batch_size, emb_vector_dim}); - Tensor sp_values_offset_expected(DT_INT32, {batch_size}); - fill_emb_vector_expected(&emb_vector_expected); - test::FillValues(&sp_values_offset_expected, {0, 2, 5, 8}); - - const Tensor& emb_vector = *GetOutput(0); - const Tensor& values_offset = *GetOutput(1); - TF_EXPECT_OK(device_->Sync()); - - test::ExpectTensorNear(emb_vector_expected, emb_vector, 1e-4); - test::ExpectTensorEqual(sp_values_offset_expected, values_offset); - } -}; - -template -void fill_grad_expected(Tensor* expected); - -template <> -void fill_grad_expected(Tensor* expected) { - test::FillValues( - expected, {0.000000000000000, 0.7071067690849304, 1.4142135381698608, - 2.1213204860687256, 2.8284270763397217, 
3.535533905029297, - 4.242640972137451, 4.949747562408447, 0.000000000000000, - 0.7071067690849304, 1.4142135381698608, 2.1213204860687256, - 2.8284270763397217, 3.535533905029297, 4.242640972137451, - 4.949747562408447, 4.618802070617676, 5.196152687072754, - 5.773502826690674, 6.350852966308594, 6.928203582763672, - 7.505553722381592, 8.082903861999512, 8.66025447845459, - 4.618802070617676, 5.196152687072754, 5.773502826690674, - 6.350852966308594, 6.928203582763672, 7.505553722381592, - 8.082903861999512, 8.66025447845459, 4.618802070617676, - 5.196152687072754, 5.773502826690674, 6.350852966308594, - 6.928203582763672, 7.505553722381592, 8.082903861999512, - 8.66025447845459, 9.237604141235352, 9.81495475769043, - 10.392305374145508, 10.96965503692627, 11.547005653381348, - 12.124356269836426, 12.701705932617188, 13.279056549072266, - 9.237604141235352, 9.81495475769043, 10.392305374145508, - 10.96965503692627, 11.547005653381348, 12.124356269836426, - 12.701705932617188, 13.279056549072266, 9.237604141235352, - 9.81495475769043, 10.392305374145508, 10.96965503692627, - 11.547005653381348, 12.124356269836426, 12.701705932617188, - 13.279056549072266, 16.970563888549805, 17.677669525146484, - 18.384777069091797, 19.091882705688477, 19.79899024963379, - 20.5060977935791, 21.21320343017578, 21.920310974121094, - 16.970563888549805, 17.677669525146484, 18.384777069091797, - 19.091882705688477, 19.79899024963379, 20.5060977935791, - 21.21320343017578, 21.920310974121094}); -} - -template <> -void fill_grad_expected(Tensor* expected) { - test::FillValues( - expected, {0.000000000000000, 0.500000000000000, 1.000000000000000, - 1.500000000000000, 2.000000000000000, 2.500000000000000, - 3.000000000000000, 3.500000000000000, 0.000000000000000, - 0.500000000000000, 1.000000000000000, 1.500000000000000, - 2.000000000000000, 2.500000000000000, 3.000000000000000, - 3.500000000000000, 2.6666667461395264, 3.000000000000000, - 3.3333332538604736, 3.6666667461395264, 4.000000000000000, - 4.333333492279053, 4.666666507720947, 5.000000000000000, - 2.6666667461395264, 3.000000000000000, 3.3333332538604736, - 3.6666667461395264, 4.000000000000000, 4.333333492279053, - 4.666666507720947, 5.000000000000000, 2.6666667461395264, - 3.000000000000000, 3.3333332538604736, 3.6666667461395264, - 4.000000000000000, 4.333333492279053, 4.666666507720947, - 5.000000000000000, 5.333333492279053, 5.666666507720947, - 6.000000000000000, 6.333333492279053, 6.666666507720947, - 7.000000000000000, 7.333333492279053, 7.666666507720947, - 5.333333492279053, 5.666666507720947, 6.000000000000000, - 6.333333492279053, 6.666666507720947, 7.000000000000000, - 7.333333492279053, 7.666666507720947, 5.333333492279053, - 5.666666507720947, 6.000000000000000, 6.333333492279053, - 6.666666507720947, 7.000000000000000, 7.333333492279053, - 7.666666507720947, 12.000000000000000, 12.500000000000000, - 13.000000000000000, 13.500000000000000, 14.000000000000000, - 14.500000000000000, 15.000000000000000, 15.500000000000000, - 12.000000000000000, 12.500000000000000, 13.000000000000000, - 13.500000000000000, 14.000000000000000, 14.500000000000000, - 15.000000000000000, 15.500000000000000}); -} - -template <> -void fill_grad_expected(Tensor* expected) { - test::FillValues( - expected, - {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, - 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 8.0, 9.0, 10.0, 11.0, - 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 
22.0, 23.0, - 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 16.0, 17.0, 18.0, 19.0, - 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, - 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0}); -} - -template <> -void fill_grad_expected(Tensor* expected) { - test::FillValues( - expected, - {0.00000000, 0.50000000, 1.00000000, 1.50000000, 2.00000000, 2.50000000, 3.00000000, 3.50000000, - 0.00000000, 0.50000000, 1.00000000, 1.50000000, 2.00000000, 2.50000000, 3.00000000, 3.50000000, - 2.65028572, 2.98157120, 3.31285667, 3.64414287, 3.97542834, 4.30671406, 4.63799953, 4.96928549, - 2.16437674, 2.43492365, 2.70547056, 2.97601795, 3.24656487, 3.51711202, 3.78765893, 4.05820608, - 1.58337951, 1.78130186, 1.97922409, 2.17714667, 2.37506914, 2.57299161, 2.77091384, 2.96883631, - 5.33333349, 5.66666651, 6.00000000, 6.33333349, 6.66666651, 7.00000000, 7.33333349, 7.66666651, - 1.89459133, 2.01300311, 2.13141513, 2.24982715, 2.36823893, 2.48665094, 2.60506320, 2.72347474, - 1.89459133, 2.01300311, 2.13141513, 2.24982715, 2.36823893, 2.48665094, 2.60506320, 2.72347474, - 3.43474555, 3.57786012, 3.72097445, 3.86408877, 4.00720310, 4.15031767, 4.29343224, 4.43654633, - 11.92628479, 12.42321396, 12.92014217, 13.41707039, 13.91399956, 14.41092777, 14.90785599, 15.40478516}); -} - -class FusedEmbeddingLocalSparseLookUpGradOpTest : public OpsTestBase { - protected: - template - void Run(Device device) { - if (device == Device::GPU) { - SetDevice(DEVICE_GPU, - std::unique_ptr(DeviceFactory::NewDevice( - "GPU", {}, "/job:a/replica:0/task:0"))); - } - DataType dtype = DataTypeToEnum::value; - std::string combiner_str; - float max_norm; - get_node_attr_from_test_case(combiner_str, max_norm); - - TF_EXPECT_OK(NodeDefBuilder("fused_embedding_local_sparse_look_up_grad", - "FusedEmbeddingLocalSparseLookUpGrad") - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(DT_INT32)) - .Attr("T", dtype) - .Attr("combiner", combiner_str) - .Attr("max_norm", max_norm) - .Finalize(node_def())); - TF_EXPECT_OK(InitOp()); - - const int nnz = 10; - const int batch_size = 4; - const int emb_vector_dim = 8; - const int bucket_size = 16; - - Tensor top_grad(dtype, {batch_size, emb_vector_dim}); - Tensor emb_variable(dtype, {bucket_size, emb_vector_dim}); - Tensor sp_values(DT_INT64, {nnz}); - Tensor sp_values_offset(DT_INT32, {batch_size}); - - test::FillValues( - &top_grad, - {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, - 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, - 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0}); - test::FillValues( - &emb_variable, - {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, - 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, - 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, - 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, - 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, - 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, - 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, - 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, - 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, - 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, - 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, - 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, - 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0}); - test::FillValues(&sp_values, {3, 1, 4, 5, 7, 3, 12, 12, 15, 4}); - 
test::FillValues(&sp_values_offset, {0, 2, 5, 8}); - - AddInputFromArray(top_grad.shape(), top_grad.flat()); - AddInputFromArray(emb_variable.shape(), emb_variable.flat()); - AddInputFromArray(sp_values.shape(), sp_values.flat()); - AddInputFromArray(sp_values_offset.shape(), - sp_values_offset.flat()); - - TF_ASSERT_OK(RunOpKernel()); - - Tensor grad_expected(dtype, {nnz, emb_vector_dim}); - fill_grad_expected(&grad_expected); - - const Tensor& grad = *GetOutput(0); - TF_EXPECT_OK(device_->Sync()); - - test::ExpectTensorNear(grad_expected, grad, 1e-4); - } -}; - -#ifdef GOOGLE_CUDA -TEST_F(FusedEmbeddingLocalSparseLookUpOpTest, EmbeddingLocalSparseLookUpFloatSqrtnGpu) { - Run(Device::GPU); -} - -TEST_F(FusedEmbeddingLocalSparseLookUpOpTest, EmbeddingLocalSparseLookUpFloatMeanGpu) { - Run(Device::GPU); -} - -TEST_F(FusedEmbeddingLocalSparseLookUpOpTest, EmbeddingLocalSparseLookUpFloatSumGpu) { - Run(Device::GPU); -} - -TEST_F(FusedEmbeddingLocalSparseLookUpOpTest, - EmbeddingLocalSparseLookUpFloatSqrtnAndMaxNorm200Gpu) { - Run(Device::GPU); -} - -TEST_F(FusedEmbeddingLocalSparseLookUpGradOpTest, - EmbeddingLocalSparseLookUpGradFloatGpu) { - Run(Device::GPU); -} - -TEST_F(FusedEmbeddingLocalSparseLookUpGradOpTest, - EmbeddingLocalSparseLookUpGradFloatMeanGpu) { - Run(Device::GPU); -} - -TEST_F(FusedEmbeddingLocalSparseLookUpGradOpTest, - EmbeddingLocalSparseLookUpGradFloatSumGpu) { - Run(Device::GPU); -} - -TEST_F(FusedEmbeddingLocalSparseLookUpGradOpTest, - EmbeddingLocalSparseLookUpGradFloatMeanAndMaxNorm100Gpu) { - Run(Device::GPU); -} - -#endif - -} // namespace -} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_post_grad_ops_test.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_post_grad_ops_test.cc deleted file mode 100644 index acef29612fc..00000000000 --- a/tensorflow/core/kernels/fused_embedding/fused_embedding_post_grad_ops_test.cc +++ /dev/null @@ -1,243 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/conv_ops_gpu.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { -namespace { - -enum class Device { CPU, GPU }; - -class FusedEmbeddingSparsePostLookUpGradOpTest : public OpsTestBase { - protected: - void MakeOpAndSetDevice(Device device, int num_partitions, DataType dtype, - const std::string& combiner, const float max_norm, - const int default_id) { - if (device == Device::GPU) { - SetDevice(DEVICE_GPU, - std::unique_ptr(DeviceFactory::NewDevice( - "GPU", {}, "/job:a/replica:0/task:0"))); - } - - TF_EXPECT_OK(NodeDefBuilder("fused_embedding__sparse_post_look_up_grad", - "FusedEmbeddingSparsePostLookUpGrad") - .Attr("T", dtype) - .Attr("num_partitions", num_partitions) - .Attr("partition_axis", 0) - .Attr("combiner", combiner) - .Attr("max_norm", max_norm) - .Attr("default_id", default_id) - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(DT_INT32)) - .Input(FakeInput(DT_INT32)) - .Finalize(node_def())); - TF_EXPECT_OK(InitOp()); - } -}; - -TEST_F(FusedEmbeddingSparsePostLookUpGradOpTest, - Partition2_Mean_MaxNorm100_Float) { - const int nnz = 10; - const int batch_size = 4; - const int emb_vector_dim = 8; - const int entries = 8; - - MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "mean", 100.0, -1); - - // top_grad - AddInputFromArray( - TensorShape({batch_size, emb_vector_dim}), - {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, - 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, - 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0}); - - // emb_shards - AddInputFromArray( - TensorShape({6, emb_vector_dim}), - {8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 24.0, 25.0, 26.0, 27.0, - 28.0, 29.0, 30.0, 31.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, - 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 32.0, 33.0, 34.0, 35.0, - 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0}); - AddInputFromArray( - TensorShape({4, emb_vector_dim}), - {56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, - 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, - 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, - 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0}); - - // sp_values: 3, 1, 4, 5, 7, 3, 12, 12, 15, 4 - // partitioned_values: 1, 3, 3, 4, 4, 5 and 7, 12, 12, 15 - // partitioned_indices - AddInputFromArray(TensorShape({6, 2}), - {0, 5, 0, 1, 2, 1, 1, 2, 3, 6, 1, 1}); - AddInputFromArray(TensorShape({4, 2}), {1, 7, 2, 4, 2, 7, 3, 0}); - - // feature_nums - AddInputFromArray(TensorShape({batch_size}), {2, 3, 3, 2}); - - // row_empty_and_invalid_flags - AddInputFromArray(TensorShape({batch_size + nnz}), - {0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor grad_shards_1(allocator(), DT_FLOAT, - TensorShape({6, emb_vector_dim})); - test::FillValues( - &grad_shards_1, - {0.00000000, 0.50000000, 1.00000000, 1.50000000, 
2.00000000, - 2.50000000, 3.00000000, 3.50000000, 0.00000000, 0.50000000, - 1.00000000, 1.50000000, 2.00000000, 2.50000000, 3.00000000, - 3.50000000, 5.33333349, 5.66666651, 6.00000000, 6.33333349, - 6.66666651, 7.00000000, 7.33333349, 7.66666651, 2.65028572, - 2.98157120, 3.31285667, 3.64414287, 3.97542834, 4.30671406, - 4.63799953, 4.96928549, 11.92628479, 12.42321396, 12.92014217, - 13.41707039, 13.91399956, 14.41092777, 14.90785599, 15.40478516, - 2.16437674, 2.43492365, 2.70547056, 2.97601795, 3.24656487, - 3.51711202, 3.78765893, 4.05820608}); - test::ExpectTensorNear(grad_shards_1, *GetOutput(0), 1e-4); - } - - { - Tensor grad_shards_2(allocator(), DT_FLOAT, - TensorShape({4, emb_vector_dim})); - test::FillValues( - &grad_shards_2, - {1.58337951, 1.78130186, 1.97922409, 2.17714667, 2.37506914, 2.57299161, - 2.77091384, 2.96883631, 1.89459133, 2.01300311, 2.13141513, 2.24982715, - 2.36823893, 2.48665094, 2.60506320, 2.72347474, 1.89459133, 2.01300311, - 2.13141513, 2.24982715, 2.36823893, 2.48665094, 2.60506320, 2.72347474, - 3.43474555, 3.57786012, 3.72097445, 3.86408877, 4.00720310, 4.15031767, - 4.29343224, 4.43654633}); - test::ExpectTensorNear(grad_shards_2, *GetOutput(1), 1e-4); - } -} - -TEST_F(FusedEmbeddingSparsePostLookUpGradOpTest, - Partition2_SUM_Float_No_Default) { - const int nnz = 3; - const int batch_size = 3; - const int emb_vector_dim = 4; - const int entries = 8; - - MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, -1); - - // top_grad - AddInputFromArray( - TensorShape({batch_size, emb_vector_dim}), - {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0}); - - // emb_shards - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}); - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0}); - - // partitioned_indices - AddInputFromArray(TensorShape({2, 2}), {0, 0, 0, 5}); - AddInputFromArray(TensorShape({2, 2}), {1, 4, 2, 0}); - - // feature_nums - AddInputFromArray(TensorShape({batch_size}), {2, 1, 1}); - - // row_empty_and_invalid_flags - AddInputFromArray(TensorShape({batch_size + nnz}), {0, 0, 1, 1, 1, 1}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor grad_shards_1(allocator(), DT_FLOAT, - TensorShape({2, emb_vector_dim})); - test::FillValues(&grad_shards_1, - {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); - test::ExpectTensorNear(grad_shards_1, *GetOutput(0), 1e-4); - } - - { - Tensor grad_shards_2(allocator(), DT_FLOAT, - TensorShape({2, emb_vector_dim})); - test::FillValues(&grad_shards_2, - {2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0}); - test::ExpectTensorNear(grad_shards_2, *GetOutput(1), 1e-4); - } -} - -TEST_F(FusedEmbeddingSparsePostLookUpGradOpTest, - Partition2_SUM_Float_Default_0) { - const int nnz = 3; - const int batch_size = 3; - const int emb_vector_dim = 4; - const int entries = 8; - - MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, 0); - - // top_grad - AddInputFromArray( - TensorShape({batch_size, emb_vector_dim}), - {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0}); - - // emb_shards - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}); - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0}); - - // partitioned_indices - AddInputFromArray(TensorShape({2, 2}), {0, 0, 0, 5}); - AddInputFromArray(TensorShape({2, 2}), {1, 4, 2, 0}); - - // feature_nums - 
AddInputFromArray(TensorShape({batch_size}), {2, 1, 1}); - - // row_empty_and_invalid_flags - AddInputFromArray(TensorShape({batch_size + nnz}), {0, 0, 1, 1, 1, 1}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor grad_shards_1(allocator(), DT_FLOAT, - TensorShape({2, emb_vector_dim})); - test::FillValues(&grad_shards_1, - {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); - test::ExpectTensorNear(grad_shards_1, *GetOutput(0), 1e-4); - } - - { - Tensor grad_shards_2(allocator(), DT_FLOAT, - TensorShape({2, emb_vector_dim})); - test::FillValues(&grad_shards_2, - {2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0}); - test::ExpectTensorNear(grad_shards_2, *GetOutput(1), 1e-4); - } -} - -} // namespace -} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_post_ops_gpus.cu.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_post_ops_gpus.cu.cc deleted file mode 100644 index 3325bd69a6d..00000000000 --- a/tensorflow/core/kernels/fused_embedding/fused_embedding_post_ops_gpus.cu.cc +++ /dev/null @@ -1,328 +0,0 @@ -#include -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" - -#if GOOGLE_CUDA - -#define EIGEN_USE_GPU - -#include "tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h" -#include "tensorflow/core/util/gpu_kernel_helper.h" -#include "cub/thread/thread_operators.cuh" - -namespace tensorflow { -using GPUDevice = Eigen::GpuDevice; - -namespace { -__global__ void SumUpEmbeddingShard(const float* emb_shard, - const int64_t* partitioned_indice, - float* emb_vectors, int* feature_nums, - const float max_norm, - const int emb_vec_size) { - __shared__ float l2_sum[1]; - - const int64_t row_in_batch = partitioned_indice[2 * blockIdx.x]; - float emb_element = emb_shard[blockIdx.x * emb_vec_size + threadIdx.x]; - if (max_norm >= 0.0f) { - if (threadIdx.x == 0) { - l2_sum[0] = 0.0f; - } - __syncthreads(); - atomicAdd(l2_sum, emb_element * emb_element); - __syncthreads(); - float l2_norm = sqrtf(l2_sum[0]); - if (l2_norm > max_norm) { - emb_element *= max_norm / l2_norm; - } - } - - atomicAdd(emb_vectors + row_in_batch * emb_vec_size + threadIdx.x, - emb_element); - - if (threadIdx.x == 0) { - atomicAdd(feature_nums + row_in_batch, 1); - } -} - -template -__global__ void ApplyCombiner(float* emb_vectors, const int* row_emptiness_flag, - const bool set_empty_row_zero, - const int* feature_nums) { - const int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (set_empty_row_zero) { - if (row_emptiness_flag[blockIdx.x]) { - emb_vectors[offset] = 0.0f; - return; - } - } - const int feature_num = feature_nums[blockIdx.x]; - const float emb_element = emb_vectors[offset]; - emb_vectors[offset] = Combine(emb_element, feature_num); -} - -template -__global__ void DistributeGradToShard( - const float* top_grad, const float* emb_shard, - const int64_t* partitioned_indice, const int* feature_nums, - const int* row_emptiness_flag, const bool set_empty_row_zero, - float* grad_shard, const int64_t sub_nnz, const int64_t emb_vec_size, - const float max_norm) { - __shared__ int64_t row_in_batch_shared[1]; - __shared__ int feature_num_shared[1]; - __shared__ float l2_sum[1]; - int64_t row_in_batch; - if (threadIdx.x == 0) { - row_in_batch = partitioned_indice[2 * blockIdx.x]; - row_in_batch_shared[0] = row_in_batch; - feature_num_shared[0] = feature_nums[row_in_batch]; - } - __syncthreads(); - row_in_batch = row_in_batch_shared[0]; - const int feature_num = feature_num_shared[0]; - if 
(set_empty_row_zero) { - if (row_emptiness_flag[row_in_batch]) { - grad_shard[blockIdx.x * emb_vec_size + threadIdx.x] = 0.0f; - return; - } - } - float grad = top_grad[row_in_batch * emb_vec_size + threadIdx.x]; - grad = CombineGrad(grad, feature_num); - if (max_norm >= 0.0f) { - const float emb_element = - emb_shard[blockIdx.x * emb_vec_size + threadIdx.x]; - if (threadIdx.x == 0) { - l2_sum[0] = 0.0f; - } - __syncthreads(); - atomicAdd(l2_sum, emb_element * emb_element); - __syncthreads(); - float l2_norm = sqrtf(l2_sum[0]); - if (l2_norm > max_norm) { - grad *= max_norm / l2_norm; - } - } - grad_shard[blockIdx.x * emb_vec_size + threadIdx.x] = grad; -} -} // namespace - -class FusedEmbeddingSparsePostLookUpGPU : public OpKernel { - public: - explicit FusedEmbeddingSparsePostLookUpGPU(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_partitions", &num_partitions_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_axis", &partition_axis_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner", &combiner_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_norm", &max_norm_)); - int temp_default_id; - OP_REQUIRES_OK(ctx, ctx->GetAttr("default_id", &temp_default_id)); - default_id_ = int64_t(temp_default_id); - } - - void Compute(OpKernelContext* ctx) override { - auto stream = ctx->eigen_device().stream(); - - OpInputList emb_shards; - OP_REQUIRES_OK(ctx, ctx->input_list("emb_shards", &emb_shards)); - - OpInputList partitioned_indices; - OP_REQUIRES_OK( - ctx, ctx->input_list("partitioned_indices", &partitioned_indices)); - - Tensor const* dense_shape_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_dense_shape", &dense_shape_tensor)); - - Tensor const* row_empty_and_invalid_flags = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("row_empty_and_invalid_flags", - &row_empty_and_invalid_flags)); - - const int64_t emb_vec_size = emb_shards[0].shape().dim_size(1); - const int64_t batch_size = dense_shape_tensor->flat().data()[0]; - - // 1. 
sum up emb values from different entries and dump into output - Tensor* emb_vectors_tensor = nullptr; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, TensorShape({batch_size, emb_vec_size}), - &emb_vectors_tensor)); - // stream_executor::DeviceMemoryBase emb_vectors_wrapper( - // emb_vectors_tensor.flat().data(), - // emb_vectors_tensor->NumElements() * sizeof(float)); - // stream->ThenMemZero(&emb_vectors_wrapper, - // emb_vectors_tensor->NumElements() * sizeof(float)); - - cudaMemsetAsync(emb_vectors_tensor->flat().data(), 0x0, - sizeof(float) * emb_vectors_tensor->NumElements(), stream); - - Tensor* feature_nums; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(1, TensorShape({batch_size}), &feature_nums)); - // stream_executor::DeviceMemoryBase feature_nums_wrapper( - // feature_nums.flat().data(), - // feature_nums.NumElements() * sizeof(int)); - // stream->ThenMemZero(&feature_nums_wrapper, - // feature_nums.NumElements() * sizeof(int)); - cudaMemsetAsync(feature_nums->flat().data(), 0x0, - sizeof(int) * feature_nums->NumElements(), stream); - - for (int i = 0; i < num_partitions_; i++) { - const size_t sub_nnz = emb_shards[i].shape().dim_size(0); - OP_REQUIRES( - ctx, sub_nnz == partitioned_indices[i].shape().dim_size(0), - errors::InvalidArgument( - "emb_shard and partitioned_indice dosn't have the same length")); - - { - const int blocks = sub_nnz; - const int threads = emb_vec_size; - SumUpEmbeddingShard<<>>( - emb_shards[i].flat().data(), - reinterpret_cast( - partitioned_indices[i].flat().data()), - emb_vectors_tensor->flat().data(), - feature_nums->flat().data(), max_norm_, emb_vec_size); - CK_CUDA_THROW_(cudaGetLastError()); - } - } - - const bool set_empty_row_zero = default_id_ >= 0; - // 2. combiner - { - const int blocks = batch_size; - const int threads = emb_vec_size; - if (combiner_ == "sqrtn") { - ApplyCombiner<<>>( - emb_vectors_tensor->flat().data(), - row_empty_and_invalid_flags->flat().data(), set_empty_row_zero, - feature_nums->flat().data()); - } else if (combiner_ == "mean") { - ApplyCombiner<<>>( - emb_vectors_tensor->flat().data(), - row_empty_and_invalid_flags->flat().data(), set_empty_row_zero, - feature_nums->flat().data()); - } else { - ApplyCombiner<<>>( - emb_vectors_tensor->flat().data(), - row_empty_and_invalid_flags->flat().data(), set_empty_row_zero, - feature_nums->flat().data()); - } - CK_CUDA_THROW_(cudaGetLastError()); - } - } - - private: - int num_partitions_; - int partition_axis_; - std::string combiner_; - float max_norm_; - int64_t default_id_; -}; - -REGISTER_KERNEL_BUILDER(Name("FusedEmbeddingSparsePostLookUp") - .Device(DEVICE_GPU) - .HostMemory("sp_dense_shape"), - FusedEmbeddingSparsePostLookUpGPU); - -class FusedEmbeddingSparsePostLookUpGradGPU : public OpKernel { - public: - explicit FusedEmbeddingSparsePostLookUpGradGPU(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_partitions", &num_partitions_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_axis", &partition_axis_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner", &combiner_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_norm", &max_norm_)); - int temp_default_id; - OP_REQUIRES_OK(ctx, ctx->GetAttr("default_id", &temp_default_id)); - default_id_ = int64_t(temp_default_id); - } - - void Compute(OpKernelContext* ctx) override { - auto stream = ctx->eigen_device().stream(); - - Tensor const* top_grad_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("top_grad", &top_grad_tensor)); - - OpInputList emb_shards; - OP_REQUIRES_OK(ctx, 
ctx->input_list("emb_shards", &emb_shards)); - - OpInputList partitioned_indices; - OP_REQUIRES_OK( - ctx, ctx->input_list("partitioned_indices", &partitioned_indices)); - - Tensor const* feature_nums = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("feature_nums", &feature_nums)); - - Tensor const* row_empty_and_invalid_flags = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("row_empty_and_invalid_flags", - &row_empty_and_invalid_flags)); - - OpOutputList grad_shards; - OP_REQUIRES_OK(ctx, ctx->output_list("grad_shards", &grad_shards)); - - const int64_t batch_size = top_grad_tensor->shape().dim_size(0); - const int64_t emb_vec_size = emb_shards[0].shape().dim_size(1); - - const bool set_empty_row_zero = default_id_ >= 0; - - for (int i = 0; i < num_partitions_; i++) { - const int64_t sub_nnz = partitioned_indices[i].shape().dim_size(0); - - Tensor* grad_shard; - OP_REQUIRES_OK( - ctx, grad_shards.allocate(i, TensorShape({sub_nnz, emb_vec_size}), - &grad_shard)); - - { - const int blocks = sub_nnz; - const int threads = emb_vec_size; - if (combiner_ == "sqrtn") { - DistributeGradToShard<<>>( - top_grad_tensor->flat().data(), - emb_shards[i].flat().data(), - reinterpret_cast( - partitioned_indices[i].flat().data()), - feature_nums->flat().data(), - row_empty_and_invalid_flags->flat().data(), - set_empty_row_zero, grad_shard->flat().data(), sub_nnz, - emb_vec_size, max_norm_); - } else if (combiner_ == "mean") { - DistributeGradToShard<<>>( - top_grad_tensor->flat().data(), - emb_shards[i].flat().data(), - reinterpret_cast( - partitioned_indices[i].flat().data()), - feature_nums->flat().data(), - row_empty_and_invalid_flags->flat().data(), - set_empty_row_zero, grad_shard->flat().data(), sub_nnz, - emb_vec_size, max_norm_); - } else { - DistributeGradToShard<<>>( - top_grad_tensor->flat().data(), - emb_shards[i].flat().data(), - reinterpret_cast( - partitioned_indices[i].flat().data()), - feature_nums->flat().data(), - row_empty_and_invalid_flags->flat().data(), - set_empty_row_zero, grad_shard->flat().data(), sub_nnz, - emb_vec_size, max_norm_); - } - CK_CUDA_THROW_(cudaGetLastError()); - } - } - } - - private: - int num_partitions_; - int partition_axis_; - std::string combiner_; - float max_norm_; - int64_t default_id_; -}; - -REGISTER_KERNEL_BUILDER( - Name("FusedEmbeddingSparsePostLookUpGrad").Device(DEVICE_GPU), - FusedEmbeddingSparsePostLookUpGradGPU); - -} // namespace tensorflow - -#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_post_ops_test.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_post_ops_test.cc deleted file mode 100644 index 3321f3ff677..00000000000 --- a/tensorflow/core/kernels/fused_embedding/fused_embedding_post_ops_test.cc +++ /dev/null @@ -1,213 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/conv_ops_gpu.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { -namespace { - -enum class Device { CPU, GPU }; -class FusedEmbeddingSparsePostLookUpOpTest : public OpsTestBase { - protected: - void MakeOpAndSetDevice(Device device, int num_partitions, DataType dtype, - const std::string& combiner, const float max_norm, - const int default_id) { - if (device == Device::GPU) { - SetDevice(DEVICE_GPU, - std::unique_ptr(DeviceFactory::NewDevice( - "GPU", {}, "/job:a/replica:0/task:0"))); - } - - TF_EXPECT_OK(NodeDefBuilder("fused_embedding_sparse_post_look_up", - "FusedEmbeddingSparsePostLookUp") - .Attr("T", dtype) - .Attr("num_partitions", num_partitions) - .Attr("partition_axis", 0) - .Attr("combiner", combiner) - .Attr("max_norm", max_norm) - .Attr("default_id", default_id) - .Input(FakeInput(num_partitions, dtype)) - .Input(FakeInput(num_partitions, DT_INT64)) - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(DT_INT32)) - .Input(FakeInput(DT_INT64)) - .Finalize(node_def())); - TF_EXPECT_OK(InitOp()); - } -}; - -TEST_F(FusedEmbeddingSparsePostLookUpOpTest, - Partition3_Sqrtn_MaxNorm200_Float) { - const int nnz = 10; - const int batch_size = 4; - const int emb_vector_dim = 8; - const int entries = 8; - - MakeOpAndSetDevice(Device::GPU, 3, DT_FLOAT, "sqrtn", 200.0, -1); - - // emb_shards - AddInputFromArray( - TensorShape({6, emb_vector_dim}), - { - 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 24.0, 25.0, - 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 24.0, 25.0, 26.0, 27.0, - 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, - 38.0, 39.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, - 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, - }); - AddInputFromArray(TensorShape({1, emb_vector_dim}), - {56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0}); - AddInputFromArray( - TensorShape({3, emb_vector_dim}), - {96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, - 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, - 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0}); - - // partitioned_indices - AddInputFromArray(TensorShape({6, 2}), - {0, 5, 0, 1, 2, 1, 1, 2, 3, 6, 1, 1}); - AddInputFromArray(TensorShape({1, 2}), {1, 7}); - AddInputFromArray(TensorShape({3, 2}), {2, 4, 2, 7, 3, 0}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {batch_size, entries}); - - // row_empty_and_invalid_flags - AddInputFromArray(TensorShape({batch_size + nnz}), - {0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_emb_vectors(allocator(), DT_FLOAT, - TensorShape({batch_size, emb_vector_dim})); - test::FillValues( - &expected_emb_vectors, - {22.62741661, 24.04163170, 25.45584488, 26.87005806, 28.28427124, - 29.69848442, 31.11269951, 32.52691269, 73.90083313, 75.63288879, - 77.36493683, 79.09698486, 80.82904053, 82.56108856, 84.29314423, - 86.02519226, 92.61308289, 94.01081848, 95.40855408, 96.80628204, - 98.20401764, 
99.60175323, 100.99948120, 102.39721680, 71.20205688, - 72.31395721, 73.42584991, 74.53774261, 75.64963531, 76.76153564, - 77.87342834, 78.98532867}); - test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); - } - { - Tensor feature_nums_expected(allocator(), DT_INT32, - TensorShape({batch_size})); - test::FillValues(&feature_nums_expected, {2, 3, 3, 2}); - test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); - } -} - -TEST_F(FusedEmbeddingSparsePostLookUpOpTest, Partition2_Sum_No_Default) { - const int nnz = 3; - const int batch_size = 3; - const int emb_vector_dim = 4; - const int entries = 8; - - MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, -1); - - // emb_shards - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}); - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {10.0, 10.0, 10.0, 10.0, 13.0, 13.0, 13.0, 13.0}); - - // partitioned_indices - AddInputFromArray(TensorShape({2, 2}), {0, 0, 0, 5}); - AddInputFromArray(TensorShape({2, 2}), {1, 4, 2, 0}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {batch_size, entries}); - - // row_empty_and_invalid_flags - AddInputFromArray(TensorShape({batch_size + nnz}), {0, 0, 1, 1, 1, 1}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_emb_vectors(allocator(), DT_FLOAT, - TensorShape({batch_size, emb_vector_dim})); - test::FillValues( - &expected_emb_vectors, - {3.0, 3.0, 3.0, 3.0, 10.0, 10.0, 10.0, 10.0, 13.0, 13.0, 13.0, 13.0}); - test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); - } - { - Tensor feature_nums_expected(allocator(), DT_INT32, - TensorShape({batch_size})); - test::FillValues(&feature_nums_expected, {2, 1, 1}); - test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); - } -} - -TEST_F(FusedEmbeddingSparsePostLookUpOpTest, Partition2_Sum_Default_0) { - const int nnz = 3; - const int batch_size = 3; - const int emb_vector_dim = 4; - const int entries = 8; - - MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, 0); - - // emb_shards - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}); - AddInputFromArray(TensorShape({2, emb_vector_dim}), - {10.0, 10.0, 10.0, 10.0, 13.0, 13.0, 13.0, 13.0}); - - // partitioned_indices - AddInputFromArray(TensorShape({2, 2}), {0, 0, 0, 5}); - AddInputFromArray(TensorShape({2, 2}), {1, 4, 2, 0}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {batch_size, entries}); - - // row_empty_and_invalid_flags - AddInputFromArray(TensorShape({batch_size + nnz}), {0, 0, 1, 1, 1, 1}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_emb_vectors(allocator(), DT_FLOAT, - TensorShape({batch_size, emb_vector_dim})); - test::FillValues( - &expected_emb_vectors, - {3.0, 3.0, 3.0, 3.0, 10.0, 10.0, 10.0, 10.0, 0.0, 0.0, 0.0, 0.0}); - test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); - } - { - Tensor feature_nums_expected(allocator(), DT_INT32, - TensorShape({batch_size})); - test::FillValues(&feature_nums_expected, {2, 1, 1}); - test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); - } -} - -} // namespace -} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_pre_ops_gpus.cu.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_pre_ops_gpus.cu.cc deleted file mode 100644 index e8521600322..00000000000 --- 
a/tensorflow/core/kernels/fused_embedding/fused_embedding_pre_ops_gpus.cu.cc +++ /dev/null @@ -1,521 +0,0 @@ -#include -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" - -#if GOOGLE_CUDA - -#define EIGEN_USE_GPU - -#include "tensorflow/core/kernels/fused_embedding/fused_embedding_common.cu.h" -#include "tensorflow/core/util/gpu_kernel_helper.h" -#include "cub/device/device_radix_sort.cuh" -#include "cub/device/device_select.cuh" -#include "cub/iterator/constant_input_iterator.cuh" -#include "cub/thread/thread_operators.cuh" - -namespace tensorflow { -using GPUDevice = Eigen::GpuDevice; - -namespace { - -__global__ void InitFlagsToOneInt4(int length, int* flags) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (4 * offset + 3 < length) { - *((int4*)(flags + 4 * offset)) = make_int4(1, 1, 1, 1); - } else if (4 * offset < length) { - for (int i = 0; i < length - 4 * offset; i++) { - flags[4 * offset + i] = 1; - } - } -} - -__global__ void FusedMultiFunctionalKernel( - const IndicePair* indices, const int64_t* values, const int64_t nnz, - const int64_t batch_size, const bool prune_invalid_id, - const int64_t default_id, int* row_emptiness_flag, int* invalid_id_flag, - IndicePair* tmp_indices_buffer, int64_t* values_extended) { - // This kernel will do many things together - // 1. The first part of threads will do job 1(DetectRowEmptiness), others will - // do job2(InitBatchRowsBuffer) - // 2. Do job3 (set values extended to default id) - - const int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset < nnz) { - // do DetectRowEmptiness - if (prune_invalid_id) { - const int64_t value = values[offset]; - if (value < 0) { - // invalid, set invalid_id_flag - atomicAnd(invalid_id_flag + offset, 0); - } else { - // valid, set row_emptiness_flag - const int64_t row_in_batch = indices[offset].row_in_batch; - atomicAnd(row_emptiness_flag + row_in_batch, 0); - } - } else { - // set row_emptiness_flag - const int64_t row_in_batch = indices[offset].row_in_batch; - atomicAnd(row_emptiness_flag + row_in_batch, 0); - } - } else { - // do InitBatchRowsBuffer - const int other_offset = offset - nnz; - if (other_offset < batch_size) { - tmp_indices_buffer[other_offset].row_in_batch = other_offset; - // always set entry id to 0; - tmp_indices_buffer[other_offset].entry_in_column = 0; - } - } - - // set values extended to default id - if (2 * offset + 1 < nnz + batch_size) { - longlong2 l2 = make_longlong2(default_id, default_id); - *((longlong2*)(values_extended + 2 * offset)) = l2; - } else if (2 * offset < nnz + batch_size) { - values_extended[2 * offset] = default_id; - } -} - -__global__ void DetectInvalid(const int64_t* values, const int64_t nnz, - int* invalid_id_flag) { - const int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset < nnz) { - const int64_t value = values[offset]; - if (value < 0) { - atomicAnd(invalid_id_flag + offset, 0); - } - } -} - -__global__ void CalcElementsOffsetPerPartition( - const int64_t* values_sorted, int64_t* partition_sizes_accumulate, - int64_t* elements_offset_per_partition, int nnz) { - // dichotomy - const int64_t target = partition_sizes_accumulate[blockIdx.x]; - int roof = nnz; - int floor = 0; - - int pos = (roof + floor) / 2; - while (1) { - if (pos == 0) { - pos = -1; - break; - } else if (pos == nnz - 1) { - break; - } - int64_t value = values_sorted[pos]; - int64_t value_plus_1 = values_sorted[pos + 1]; - if (value < target && value_plus_1 >= target) { - break; - } - if (value < target) { - floor = pos; - } 
else { - roof = pos; - } - pos = (roof + floor) / 2; - } - elements_offset_per_partition[blockIdx.x] = int64_t(pos + 1); -} - -__global__ void GatherAndConvertToSubPartition( - const int64_t* sub_values_sorted, int64_t* sub_partitioned_values, - const int64_t partition_start_base, const int64_t partition_size) { - const int t_offset = blockIdx.x * blockDim.x + threadIdx.x; - if (t_offset < partition_size) { - int64_t value = sub_values_sorted[t_offset]; - // rebase value to it's corresponding sub partition - value = value - partition_start_base; - sub_partitioned_values[t_offset] = value; - } -} - -} // namespace - -class FusedEmbeddingSparsePreLookUpGPU : public OpKernel { - public: - explicit FusedEmbeddingSparsePreLookUpGPU(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_partitions", &num_partitions_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_axis", &partition_axis_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_empty_row", &fill_empty_row_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("prune_invalid_id", &prune_invalid_id_)); - int temp_default_id; - OP_REQUIRES_OK(ctx, ctx->GetAttr("default_id", &temp_default_id)); - default_id_ = int64_t(temp_default_id); - } - - void Compute(OpKernelContext* ctx) override { - auto stream = ctx->eigen_device().stream(); - - const int64_t default_id = default_id_ >= 0 ? default_id_ : 0; - const int linear_mapping_threads = 128; - - // 1. bind inputs - Tensor const* values_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_values", &values_tensor)); - const int64_t nnz = values_tensor->shape().dim_size(0); - - Tensor const* indices_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_indices", &indices_tensor)); - - Tensor const* dense_shape = nullptr; - OP_REQUIRES_OK(ctx, ctx->input("sp_dense_shape", &dense_shape)); - const int64_t batch_size = dense_shape->flat().data()[0]; - - OpInputList partition_shapes; - OP_REQUIRES_OK(ctx, ctx->input_list("partition_shapes", &partition_shapes)); - - partition_sizes_accumulate_.clear(); - for (const Tensor& shape : partition_shapes) { - OP_REQUIRES(ctx, shape.dims() <= 2, - errors::InvalidArgument( - "input partition_shapes must all less than rank 2")); - const int64_t accu = partition_sizes_accumulate_.empty() - ? shape.flat().data()[0] - : shape.flat().data()[0] + - partition_sizes_accumulate_.back(); - partition_sizes_accumulate_.push_back(accu); - } - - // 2. allocate cub tmp storage - Tensor cub_temp_storage; - size_t max_cub_bytes = 0; - size_t temp_storage_bytes = 0; - - if (num_partitions_ > 1) { - cub::DeviceRadixSort::SortPairs( - (void*)nullptr, temp_storage_bytes, (int64_t*)nullptr, - (int64_t*)nullptr, (IndicePair*)nullptr, (IndicePair*)nullptr, - int(nnz + batch_size), 0, sizeof(int64_t) * 8, stream); - max_cub_bytes = temp_storage_bytes > max_cub_bytes ? temp_storage_bytes - : max_cub_bytes; - } - - if (fill_empty_row_ || prune_invalid_id_) { - cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, (int64_t*)nullptr, - (int*)nullptr, (int64_t*)nullptr, - (int*)nullptr, nnz, stream); - - max_cub_bytes = temp_storage_bytes > max_cub_bytes ? temp_storage_bytes - : max_cub_bytes; - - cub::DeviceSelect::Flagged( - (void*)nullptr, temp_storage_bytes, (IndicePair*)nullptr, - (int*)nullptr, (IndicePair*)nullptr, (int*)nullptr, nnz, stream); - - max_cub_bytes = temp_storage_bytes > max_cub_bytes ? 
temp_storage_bytes - : max_cub_bytes; - - if (fill_empty_row_) { - cub::DeviceSelect::Flagged((void*)nullptr, temp_storage_bytes, - (IndicePair*)nullptr, (int*)nullptr, - (IndicePair*)nullptr, (int*)nullptr, - batch_size, stream); - max_cub_bytes = temp_storage_bytes > max_cub_bytes ? temp_storage_bytes - : max_cub_bytes; - } - } - - OP_REQUIRES_OK( - ctx, ctx->allocate_temp( - DT_INT8, TensorShape({static_cast(max_cub_bytes)}), - &cub_temp_storage)); - - // 3. fill_empty_row, prune, if avaliable. - Tensor values_extended; - Tensor indices_extended; - Tensor tmp_indices_buffer; - Tensor* all_flags; - Tensor selected_num_d; - int new_nnz = nnz; - - OP_REQUIRES_OK( - ctx, ctx->allocate_output(2 * num_partitions_, - TensorShape{batch_size + nnz}, &all_flags)); - - if (fill_empty_row_ || prune_invalid_id_) { - OP_REQUIRES_OK(ctx, - ctx->allocate_temp(DT_INT64, TensorShape{nnz + batch_size}, - &values_extended)); - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(DT_INT64, TensorShape{2 * (nnz + batch_size)}, - &indices_extended)); - OP_REQUIRES_OK(ctx, - ctx->allocate_temp(DT_INT64, TensorShape{2 * batch_size}, - &tmp_indices_buffer)); - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(DT_INT32, TensorShape{1}, &selected_num_d)); - - { - const int threads = linear_mapping_threads; - const int blocks = - CalcBlocksLinearMapping(batch_size + nnz, threads * 4); - InitFlagsToOneInt4<<>>( - batch_size + nnz, all_flags->flat().data()); - CK_CUDA_THROW_(cudaGetLastError()); - } - - // 3.1 set flags, init tmp_indices_buffer etc. - if (fill_empty_row_) { - { - const int threads = linear_mapping_threads; - const int blocks = CalcBlocksLinearMapping(nnz + batch_size, threads); - FusedMultiFunctionalKernel<<>>( - reinterpret_cast( - indices_tensor->flat().data()), - reinterpret_cast( - values_tensor->flat().data()), - nnz, batch_size, prune_invalid_id_, default_id, - all_flags->flat().data(), - all_flags->flat().data() + batch_size, - reinterpret_cast( - tmp_indices_buffer.flat().data()), - reinterpret_cast(values_extended.flat().data())); - CK_CUDA_THROW_(cudaGetLastError()); - } - } else if (prune_invalid_id_) { - { - const int threads = linear_mapping_threads; - const int blocks = CalcBlocksLinearMapping(nnz, threads); - DetectInvalid<<>>( - reinterpret_cast( - values_tensor->flat().data()), - nnz, all_flags->flat().data() + batch_size); - CK_CUDA_THROW_(cudaGetLastError()); - } - } - // 3.2 select copy valid id, select copy empty row indices - - cudaError_t cuda_ret = cudaSuccess; - cuda_ret = cub::DeviceSelect::Flagged( - cub_temp_storage.flat().data(), max_cub_bytes, - reinterpret_cast(values_tensor->flat().data()), - (const int*)(all_flags->flat().data() + batch_size), - reinterpret_cast(values_extended.flat().data()), - selected_num_d.flat().data(), int(nnz), stream); - CK_CUDA_THROW_(cudaGetLastError()); - - cub::DeviceSelect::Flagged( - cub_temp_storage.flat().data(), max_cub_bytes, - reinterpret_cast( - indices_tensor->flat().data()), - all_flags->flat().data() + batch_size, - reinterpret_cast(indices_extended.flat().data()), - selected_num_d.flat().data(), nnz, stream); - - if (prune_invalid_id_) { - int selected_num; - cudaMemcpyAsync(&selected_num, selected_num_d.flat().data(), - sizeof(int), cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); - new_nnz = selected_num; - } - - if (fill_empty_row_) { - cub::DeviceSelect::Flagged( - cub_temp_storage.flat().data(), max_cub_bytes, - reinterpret_cast( - tmp_indices_buffer.flat().data()), - all_flags->flat().data(), - reinterpret_cast( - 
indices_extended.flat().data()) + - new_nnz, - selected_num_d.flat().data(), batch_size, stream); - CK_CUDA_THROW_(cudaGetLastError()); - int selected_num; - cudaMemcpyAsync(&selected_num, selected_num_d.flat().data(), - sizeof(int), cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); - new_nnz += selected_num; - } - } - - // 3.5 set the correct pointer - const int64_t* values_in = (fill_empty_row_ || prune_invalid_id_) - ? reinterpret_cast( - values_extended.flat().data()) - : reinterpret_cast( - values_tensor->flat().data()); - const IndicePair* indices_in = - (fill_empty_row_ || prune_invalid_id_) - ? reinterpret_cast( - indices_extended.flat().data()) - : reinterpret_cast( - indices_tensor->flat().data()); - - OpOutputList partitioned_values; - OP_REQUIRES_OK(ctx, - ctx->output_list("partitioned_values", &partitioned_values)); - OpOutputList partitioned_indices; - OP_REQUIRES_OK( - ctx, ctx->output_list("partitioned_indices", &partitioned_indices)); - - // 4. set output - if (num_partitions_ == 1) { - // single partition case, just directly copy - Tensor* pv_out; - OP_REQUIRES_OK( - ctx, partitioned_values.allocate( - 0, TensorShape({static_cast(new_nnz)}), &pv_out)); - Tensor* pi_out; - OP_REQUIRES_OK( - ctx, - partitioned_indices.allocate( - 0, TensorShape({static_cast(new_nnz), 2}), &pi_out)); - - cudaMemcpyAsync(pv_out->flat().data(), values_in, - sizeof(int64_t) * new_nnz, cudaMemcpyDeviceToDevice, - stream); - cudaMemcpyAsync(pi_out->flat().data(), indices_in, - sizeof(IndicePair) * new_nnz, cudaMemcpyDeviceToDevice, - stream); - - } else { - // multi-partitions case, calcaulate indices and split them. - Tensor values_sorted; - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT64, TensorShape{new_nnz}, - &values_sorted)); - Tensor indices_sorted; - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT64, TensorShape{new_nnz, 2}, - &indices_sorted)); - - cub::DeviceRadixSort::SortPairs( - cub_temp_storage.flat().data(), max_cub_bytes, values_in, - reinterpret_cast(values_sorted.flat().data()), - indices_in, - reinterpret_cast(indices_sorted.flat().data()), - int(new_nnz), 0, sizeof(int64_t) * 8, stream); - CK_CUDA_THROW_(cudaGetLastError()); - - // 4.1 calculate how many elements for each - // partition - Tensor partition_sizes_accumulate; - OP_REQUIRES_OK( - ctx, - ctx->allocate_temp( - DT_INT64, TensorShape({static_cast(num_partitions_)}), - &partition_sizes_accumulate)); - cudaMemcpyAsync(partition_sizes_accumulate.flat().data(), - partition_sizes_accumulate_.data(), - num_partitions_ * sizeof(int64_t), cudaMemcpyHostToDevice, - stream); - - Tensor elements_offset_per_partition; - OP_REQUIRES_OK( - ctx, - ctx->allocate_temp( - DT_INT64, TensorShape({static_cast(num_partitions_)}), - &elements_offset_per_partition)); - - { - const int blocks = num_partitions_; - const int threads = 1; - CalcElementsOffsetPerPartition<<>>( - reinterpret_cast( - values_sorted.flat().data()), - reinterpret_cast( - partition_sizes_accumulate.flat().data()), - reinterpret_cast( - elements_offset_per_partition.flat().data()), - int(new_nnz)); - CK_CUDA_THROW_(cudaGetLastError()); - } - - elements_offset_per_partition_.clear(); - elements_offset_per_partition_.resize(num_partitions_); - // stream_executor::DeviceMemoryBase - // elements_offset_per_partition_wrapped( - // elements_offset_per_partition.flat().data(), - // num_partitions_); - // stream->ThenMemcpy(elements_offset_per_partition_.data(), - // elements_offset_per_partition_wrapped, - // num_partitions_ * - // sizeof(int64_t)); - // 
stream->BlockHostUntilDone(); - - cudaMemcpyAsync(elements_offset_per_partition_.data(), - elements_offset_per_partition.flat().data(), - num_partitions_ * sizeof(int64_t), cudaMemcpyDeviceToHost, - stream); - cudaStreamSynchronize(stream); - - // 4.2 set output - int64_t sub_start_offset = 0; - for (int i = 0; i < num_partitions_; i++) { - int64_t size = elements_offset_per_partition_[i] - sub_start_offset; - - Tensor* sub_partitioned_values; - OP_REQUIRES_OK(ctx, partitioned_values.allocate( - i, TensorShape({static_cast(size)}), - &sub_partitioned_values)); - - Tensor* sub_partitioned_indices; - OP_REQUIRES_OK(ctx, partitioned_indices.allocate( - i, TensorShape({static_cast(size), 2}), - &sub_partitioned_indices)); - - if (size > 0) { - // some partition does not have any - // element that falls in it - const int threads = linear_mapping_threads; - int blocks = CalcBlocksLinearMapping(size, threads); - - const int partition_start_base = - i == 0 ? 0 : partition_sizes_accumulate_[i - 1]; - GatherAndConvertToSubPartition<<>>( - reinterpret_cast( - values_sorted.flat().data()) + - sub_start_offset, - reinterpret_cast( - sub_partitioned_values->flat().data()), - partition_start_base, size); - - CK_CUDA_THROW_(cudaGetLastError()); - - // stream_executor::DeviceMemoryBase - // sub_indices_sorted_wrapped( - // reinterpret_cast(indices_sorted.flat().data()) - // + - // partition_start_base, - // size * sizeof(IndicePair)); - // stream_executor::DeviceMemoryBase - // sub_indices_out_wrapped( - // reinterpret_cast( - // sub_partitioned_indices.flat().data()), - // size * sizeof(IndicePair)); - // stream->ThenMemcpy(&sub_indices_out_wrapped, - // sub_indices_sorted_wrapped, - // size * 2 * - // sizeof(int64_t)); - cudaMemcpyAsync( - sub_partitioned_indices->flat().data(), - indices_sorted.flat().data() + 2 * sub_start_offset, - size * 2 * sizeof(int64_t), cudaMemcpyDeviceToDevice, stream); - } - sub_start_offset = elements_offset_per_partition_[i]; - } - } - // Op kernel execution done - } - - private: - int num_partitions_; - int partition_axis_; - bool fill_empty_row_; - bool prune_invalid_id_; - int64_t default_id_; - std::vector partition_sizes_accumulate_; - std::vector elements_offset_per_partition_; -}; - -REGISTER_KERNEL_BUILDER(Name("FusedEmbeddingSparsePreLookUp") - .Device(DEVICE_GPU) - .HostMemory("partition_shapes") - .HostMemory("sp_dense_shape"), - FusedEmbeddingSparsePreLookUpGPU); -} // namespace tensorflow - -#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/fused_embedding/fused_embedding_pre_ops_test.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_pre_ops_test.cc deleted file mode 100644 index e960330406c..00000000000 --- a/tensorflow/core/kernels/fused_embedding/fused_embedding_pre_ops_test.cc +++ /dev/null @@ -1,352 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { -namespace { - -enum class Device { CPU, GPU }; - -class FusedEmbeddingSparsePreLookUpOpTest : public OpsTestBase { - protected: - void MakeOpAndSetDevice(Device device, const int num_partitions, - const bool fill_empty_row, - const bool prune_invalid_id, const int default_id) { - if (device == Device::GPU) { - SetDevice(DEVICE_GPU, - std::unique_ptr(DeviceFactory::NewDevice( - "GPU", {}, "/job:a/replica:0/task:0"))); - } - - TF_EXPECT_OK(NodeDefBuilder("fused_embedding_sparse_pre_look_up", - "FusedEmbeddingSparsePreLookUp") - .Attr("num_partitions", num_partitions) - .Attr("partition_axis", 0) - .Attr("fill_empty_row", fill_empty_row) - .Attr("prune_invalid_id", prune_invalid_id) - .Attr("default_id", default_id) - .Input(FakeInput(num_partitions, DT_INT64)) - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(DT_INT64)) - .Input(FakeInput(DT_INT64)) - .Finalize(node_def())); - TF_EXPECT_OK(InitOp()); - } -}; - -TEST_F(FusedEmbeddingSparsePreLookUpOpTest, Partition3_Int64) { - MakeOpAndSetDevice(Device::GPU, 3, false, false, -1); - // partition_shapes 0 - AddInputFromArray(TensorShape({2}), {6, 16}); - // partition_shapes 1 - AddInputFromArray(TensorShape({2}), {3, 16}); - // partition_shapes 2 - AddInputFromArray(TensorShape({2}), {7, 16}); - // sp_values - AddInputFromArray(TensorShape({12}), - {1, 5, 3, 6, 12, 14, 15, 0, 5, 5, 11, 7}); - // sp_indices - AddInputFromArray(TensorShape({12, 2}), - {2, 3, 4, 6, 1, 6, 12, 12, 12, 12, 11, 5, - 15, 0, 11, 6, 7, 9, 11, 8, 12, 13, 13, 0}); - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {16, 16}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({6})); - test::FillValues(&expected_values, {0, 1, 3, 5, 5, 5}); - test::ExpectTensorEqual(expected_values, *GetOutput(0)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({6, 2})); - test::FillValues(&expected_indices, - {11, 6, 2, 3, 1, 6, 4, 6, 7, 9, 11, 8}); - test::ExpectTensorEqual(expected_indices, *GetOutput(3)); - } - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({2})); - test::FillValues(&expected_values, {0, 1}); - test::ExpectTensorEqual(expected_values, *GetOutput(1)); - Tensor expected_indices(allocator(), DT_INT64, TensorShape({2, 2})); - test::FillValues(&expected_indices, {12, 12, 13, 0}); - test::ExpectTensorEqual(expected_indices, *GetOutput(4)); - } - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({4})); - test::FillValues(&expected_values, {2, 3, 5, 6}); - test::ExpectTensorEqual(expected_values, *GetOutput(2)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({4, 2})); - test::FillValues(&expected_indices, {12, 13, 12, 12, 11, 5, 15, 0}); - test::ExpectTensorEqual(expected_indices, *GetOutput(5)); - } -} - -TEST_F(FusedEmbeddingSparsePreLookUpOpTest, Partition2_Fill_Empty) { - MakeOpAndSetDevice(Device::GPU, 2, 
true, false, -1); - // partition_shapes 0 - AddInputFromArray(TensorShape({2}), {5, 8}); - // partition_shapes 1 - AddInputFromArray(TensorShape({2}), {5, 8}); - - // sp_values - AddInputFromArray(TensorShape({10}), - {0, 4, 3, -2, 5, -3, -4, 9, -6, 2}); - - // sp_indices - AddInputFromArray( - TensorShape({10, 2}), - {0, 0, 0, 4, 1, 2, 3, 0, 3, 4, 4, 0, 5, 2, 6, 0, 6, 1, 6, 7}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {7, 8}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({9})); - test::FillValues(&expected_values, {-6, -4, -3, -2, 0, 0, 2, 3, 4}); - test::ExpectTensorEqual(expected_values, *GetOutput(0)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({9, 2})); - test::FillValues(&expected_indices, {6, 1, 5, 2, 4, 0, 3, 0, 0, 0, 2, - 0, 6, 7, 1, 2, 0, 4}); - test::ExpectTensorEqual(expected_indices, *GetOutput(2)); - } - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({2})); - test::FillValues(&expected_values, {0, 4}); - test::ExpectTensorEqual(expected_values, *GetOutput(1)); - Tensor expected_indices(allocator(), DT_INT64, TensorShape({2, 2})); - test::FillValues(&expected_indices, {3, 4, 6, 0}); - test::ExpectTensorEqual(expected_indices, *GetOutput(3)); - } -} - -TEST_F(FusedEmbeddingSparsePreLookUpOpTest, - Partition2_Fill_Empty_Prune_Invalid) { - MakeOpAndSetDevice(Device::GPU, 2, true, true, -1); - // partition_shapes 0 - AddInputFromArray(TensorShape({2}), {5, 8}); - // partition_shapes 1 - AddInputFromArray(TensorShape({2}), {5, 8}); - - // sp_values - AddInputFromArray(TensorShape({10}), - {0, 4, 3, -2, 5, -3, -4, 9, -6, 2}); - - // sp_indices - AddInputFromArray( - TensorShape({10, 2}), - {0, 0, 0, 4, 1, 2, 3, 0, 3, 4, 4, 0, 5, 2, 6, 0, 6, 1, 6, 7}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {7, 8}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({7})); - test::FillValues(&expected_values, {0, 0, 0, 0, 2, 3, 4}); - test::ExpectTensorEqual(expected_values, *GetOutput(0)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({7, 2})); - test::FillValues(&expected_indices, - {0, 0, 2, 0, 4, 0, 5, 0, 6, 7, 1, 2, 0, 4}); - test::ExpectTensorEqual(expected_indices, *GetOutput(2)); - } - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({2})); - test::FillValues(&expected_values, {0, 4}); - test::ExpectTensorEqual(expected_values, *GetOutput(1)); - Tensor expected_indices(allocator(), DT_INT64, TensorShape({2, 2})); - test::FillValues(&expected_indices, {3, 4, 6, 0}); - test::ExpectTensorEqual(expected_indices, *GetOutput(3)); - } -} - -TEST_F(FusedEmbeddingSparsePreLookUpOpTest, - Partition2_Fill_Empty_Prune_Invalid_Default_7) { - MakeOpAndSetDevice(Device::GPU, 2, true, true, 7); - // partition_shapes 0 - AddInputFromArray(TensorShape({2}), {5, 8}); - // partition_shapes 1 - AddInputFromArray(TensorShape({2}), {5, 8}); - - // sp_values - AddInputFromArray(TensorShape({10}), - {0, 4, 3, -2, 5, -3, -4, 9, -6, 2}); - - // sp_indices - AddInputFromArray( - TensorShape({10, 2}), - {0, 0, 0, 4, 1, 2, 3, 0, 3, 4, 4, 0, 5, 2, 6, 0, 6, 1, 6, 7}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {7, 8}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({4})); - test::FillValues(&expected_values, {0, 2, 3, 4}); - 
test::ExpectTensorEqual(expected_values, *GetOutput(0)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({4, 2})); - test::FillValues(&expected_indices, {0, 0, 6, 7, 1, 2, 0, 4}); - test::ExpectTensorEqual(expected_indices, *GetOutput(2)); - } - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({5})); - test::FillValues(&expected_values, {0, 2, 2, 2, 4}); - test::ExpectTensorEqual(expected_values, *GetOutput(1)); - Tensor expected_indices(allocator(), DT_INT64, TensorShape({5, 2})); - test::FillValues(&expected_indices, {3, 4, 2, 0, 4, 0, 5, 0, 6, 0}); - test::ExpectTensorEqual(expected_indices, *GetOutput(3)); - } -} - -TEST_F(FusedEmbeddingSparsePreLookUpOpTest, - Partition2_Prune_Invalid_Default_3) { - MakeOpAndSetDevice(Device::GPU, 2, false, true, 3); - // partition_shapes 0 - AddInputFromArray(TensorShape({2}), {5, 8}); - // partition_shapes 1 - AddInputFromArray(TensorShape({2}), {5, 8}); - - // sp_values - AddInputFromArray(TensorShape({10}), - {0, 4, 3, -2, 5, -3, -4, 9, -6, 2}); - - // sp_indices - AddInputFromArray( - TensorShape({10, 2}), - {0, 0, 0, 4, 1, 2, 3, 0, 3, 4, 4, 0, 5, 2, 6, 0, 6, 1, 6, 7}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {7, 8}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({4})); - test::FillValues(&expected_values, {0, 2, 3, 4}); - test::ExpectTensorEqual(expected_values, *GetOutput(0)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({4, 2})); - test::FillValues(&expected_indices, {0, 0, 6, 7, 1, 2, 0, 4}); - test::ExpectTensorEqual(expected_indices, *GetOutput(2)); - } - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({2})); - test::FillValues(&expected_values, {0, 4}); - test::ExpectTensorEqual(expected_values, *GetOutput(1)); - Tensor expected_indices(allocator(), DT_INT64, TensorShape({2, 2})); - test::FillValues(&expected_indices, {3, 4, 6, 0}); - test::ExpectTensorEqual(expected_indices, *GetOutput(3)); - } -} - -TEST_F(FusedEmbeddingSparsePreLookUpOpTest, Partition1) { - MakeOpAndSetDevice(Device::GPU, 1, false, false, -1); - // partition_shapes 0 - AddInputFromArray(TensorShape({2}), {10, 8}); - - // sp_values - AddInputFromArray(TensorShape({10}), - {0, 4, 3, -2, 5, -3, -4, 9, -6, 2}); - - // sp_indices - AddInputFromArray( - TensorShape({10, 2}), - {0, 0, 0, 4, 1, 2, 3, 0, 3, 4, 4, 0, 5, 2, 6, 0, 6, 1, 6, 7}); - - // sp_dense_shape - AddInputFromArray(TensorShape({2}), {7, 8}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({10})); - test::FillValues(&expected_values, - {0, 4, 3, -2, 5, -3, -4, 9, -6, 2}); - test::ExpectTensorEqual(expected_values, *GetOutput(0)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({10, 2})); - test::FillValues(&expected_indices, {0, 0, 0, 4, 1, 2, 3, 0, 3, 4, - 4, 0, 5, 2, 6, 0, 6, 1, 6, 7}); - test::ExpectTensorEqual(expected_indices, *GetOutput(1)); - } -} - -TEST_F(FusedEmbeddingSparsePreLookUpOpTest, - Partition1_Fill_Empty_Prune_Invalid_Default_3) { - MakeOpAndSetDevice(Device::GPU, 1, true, true, 3); - // partition_shapes 0 - AddInputFromArray(TensorShape({2}), {10, 8}); - - // sp_values - AddInputFromArray(TensorShape({10}), - {0, 4, 3, -2, 5, -3, -4, 9, -6, 2}); - - // sp_indices - AddInputFromArray( - TensorShape({10, 2}), - {0, 0, 0, 4, 1, 2, 3, 0, 3, 4, 4, 0, 5, 2, 6, 0, 6, 1, 6, 7}); - - // sp_dense_shape - 
AddInputFromArray(TensorShape({2}), {7, 8}); - - TF_ASSERT_OK(RunOpKernel()); - TF_EXPECT_OK(device_->Sync()); - - { - Tensor expected_values(allocator(), DT_INT64, TensorShape({9})); - test::FillValues(&expected_values, {0, 4, 3, 5, 9, 2, 3, 3, 3}); - test::ExpectTensorEqual(expected_values, *GetOutput(0)); - - Tensor expected_indices(allocator(), DT_INT64, TensorShape({9, 2})); - test::FillValues(&expected_indices, {0, 0, 0, 4, 1, 2, 3, 4, 6, 0, 6, - 7, 2, 0, 4, 0, 5, 0}); - test::ExpectTensorEqual(expected_indices, *GetOutput(1)); - } -} - -} // namespace -} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/common.cu.h b/tensorflow/core/kernels/fused_embedding/gpu/common.cu.h new file mode 100644 index 00000000000..5a96fa15893 --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/common.cu.h @@ -0,0 +1,55 @@ +#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_COMMON_CU_H_ +#define TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_COMMON_CU_H_ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/op_kernel.h" + +#define CK_CUDA_THROW_(x) \ + do { \ + cudaError_t retval = (x); \ + if (retval != cudaSuccess) { \ + throw std::runtime_error(std::string("Runtime error: ") + \ + (cudaGetErrorString(retval)) + " " + __FILE__ + \ + ":" + std::to_string(__LINE__) + " \n"); \ + } \ + } while (0) + +namespace tensorflow { + +namespace fused_embedding { + +template +inline T* data_p_with_type(Tensor& t) { + return reinterpret_cast(t.data()); +} + +template +inline T* data_p_with_type(const Tensor& t) { + return reinterpret_cast(t.data()); +} + +template +inline T* data_p_with_type(Tensor* t) { + return reinterpret_cast(t->data()); +} + +template +inline T* data_p_with_type(const Tensor* t) { + return reinterpret_cast(t->data()); +} + +struct IndicePair { + int64_t row_in_batch; + int64_t entry_in_column; +}; + +} // namespace fused_embedding + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_COMMON_CU_H_ \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/functions/hash_functions.cu.h b/tensorflow/core/kernels/fused_embedding/gpu/functions/hash_functions.cu.h new file mode 100644 index 00000000000..7a5a9ff5e68 --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/functions/hash_functions.cu.h @@ -0,0 +1,128 @@ +#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_HASH_FUNCTIONS_CU_H_ +#define TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_HASH_FUNCTIONS_CU_H_ + +// MurmurHash3_32 implementation from +// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. +// Note - The x86 and x64 versions do _not_ produce the same results, as the +// algorithms are optimized for their respective platforms. You can still +// compile and run any of them on any platform, but your performance with the +// non-native version will be less than optimal. 
+ +namespace tensorflow { + +namespace gpu_unique_with_counts { + +template +struct MurmurHash3_32 { + using argument_type = Key; + using result_type = uint32_t; + + /*__forceinline__ + __host__ __device__ + MurmurHash3_32() : m_seed( 0 ) {}*/ + + __forceinline__ __host__ __device__ static uint32_t rotl32(uint32_t x, + int8_t r) { + return (x << r) | (x >> (32 - r)); + } + + __forceinline__ __host__ __device__ static uint32_t fmix32(uint32_t h) { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + return h; + } + + /* --------------------------------------------------------------------------*/ + /** + * @Synopsis Combines two hash values into a new single hash value. Called + * repeatedly to create a hash value from several variables. + * Taken from the Boost hash_combine function + * https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html + * + * @Param lhs The first hash value to combine + * @Param rhs The second hash value to combine + * + * @Returns A hash value that intelligently combines the lhs and rhs hash + * values + */ + /* ----------------------------------------------------------------------------*/ + __host__ __device__ static result_type hash_combine(result_type lhs, + result_type rhs) { + result_type combined{lhs}; + + combined ^= rhs + 0x9e3779b9 + (combined << 6) + (combined >> 2); + + return combined; + } + + __forceinline__ __host__ __device__ static result_type hash(const Key& key) { + constexpr int len = sizeof(argument_type); + const uint8_t* const data = (const uint8_t*)&key; + constexpr int nblocks = len / 4; + uint32_t h1 = m_seed; + constexpr uint32_t c1 = 0xcc9e2d51; + constexpr uint32_t c2 = 0x1b873593; + //---------- + // body + const uint32_t* const blocks = (const uint32_t*)(data + nblocks * 4); + for (int i = -nblocks; i; i++) { + uint32_t k1 = blocks[i]; // getblock32(blocks,i); + k1 *= c1; + k1 = rotl32(k1, 15); + k1 *= c2; + h1 ^= k1; + h1 = rotl32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + } + //---------- + // tail + const uint8_t* tail = (const uint8_t*)(data + nblocks * 4); + uint32_t k1 = 0; + switch (len & 3) { + case 3: + k1 ^= tail[2] << 16; + case 2: + k1 ^= tail[1] << 8; + case 1: + k1 ^= tail[0]; + k1 *= c1; + k1 = rotl32(k1, 15); + k1 *= c2; + h1 ^= k1; + }; + //---------- + // finalization + h1 ^= len; + h1 = fmix32(h1); + return h1; + } +}; + +template +struct Fix_Hash { + using result_type = index_type; + + __forceinline__ __host__ __device__ static index_type hash( + const key_type& key) { + return result; + } +}; + +template +struct Mod_Hash { + __forceinline__ __host__ __device__ static result_type hash( + const key_type& key) { + return (result_type)key; + } +}; + +} // namespace gpu_unique_with_counts +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_HASH_FUNCTIONS_CU_H_ diff --git a/tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.cc b/tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.cc new file mode 100644 index 00000000000..6cf7c223aeb --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.cc @@ -0,0 +1,614 @@ + + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h" + +#include + +#include "tensorflow/core/kernels/fused_embedding/gpu/common.cu.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +namespace 
fused_embedding { + +using GPUDevice = Eigen::GpuDevice; + +#define LINER_MAPPING_THREADS 128 + +inline int CalcBlocksLinearMapping(const int problem_size, const int threads) { + return problem_size % threads == 0 ? (problem_size / threads) + : (problem_size / threads + 1); +} + +__global__ void InitFlagsToOneInt4Kernel(int length, int* flags) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (4 * offset + 3 < length) { + *((int4*)(flags + 4 * offset)) = make_int4(1, 1, 1, 1); + } else if (4 * offset < length) { + for (int i = 0; i < length - 4 * offset; i++) { + flags[4 * offset + i] = 1; + } + } +} + +void InitFlagsToOneInt4(const GPUDevice& d, int length, int* flags) { + const int threads = LINER_MAPPING_THREADS; + const int blocks = CalcBlocksLinearMapping(length, threads * 4); + TF_CHECK_OK(GpuLaunchKernel(InitFlagsToOneInt4Kernel, blocks, threads, 0, + d.stream(), length, flags)); +} + +__global__ void DetectInvalidKernel(const int64_t* values, const int64_t nnz, + int* invalid_id_flag) { + const int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset < nnz) { + const int64_t value = values[offset]; + if (value < 0) { + atomicAnd(invalid_id_flag + offset, 0); + } + } +} + +void DetectInvalid(const GPUDevice& d, const int64_t* values, const int64_t nnz, + int* invalid_id_flag) { + const int threads = LINER_MAPPING_THREADS; + const int blocks = CalcBlocksLinearMapping(nnz, threads); + TF_CHECK_OK(GpuLaunchKernel(DetectInvalidKernel, blocks, threads, 0, + d.stream(), values, nnz, invalid_id_flag)); +} + +__global__ void FusedMultiFunctionalKernel( + const IndicePair* indices, const int64_t* values, const int64_t nnz, + const int64_t batch_size, const bool prune, const int64_t default_id, + int* row_emptiness_flag, int* invalid_id_flag, + IndicePair* tmp_indices_buffer, int64_t* values_extended) { + // This kernel will do many things together + // 1. The first part of threads will do job 1(DetectRowEmptiness), others will + // do job2(InitBatchRowsBuffer) + // 2. 
Do job3 (set values extended to default id) + + const int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset < nnz) { + // do DetectRowEmptiness + if (prune) { + const int64_t value = values[offset]; + if (value < 0) { + // invalid, set invalid_id_flag + atomicAnd(invalid_id_flag + offset, 0); + } else { + // valid, set row_emptiness_flag + const int64_t row_in_batch = indices[offset].row_in_batch; + atomicAnd(row_emptiness_flag + row_in_batch, 0); + } + } else { + // set row_emptiness_flag + const int64_t row_in_batch = indices[offset].row_in_batch; + atomicAnd(row_emptiness_flag + row_in_batch, 0); + } + } else { + // do InitBatchRowsBuffer + const int other_offset = offset - nnz; + if (other_offset < batch_size) { + tmp_indices_buffer[other_offset].row_in_batch = other_offset; + // always set entry id to 0; + tmp_indices_buffer[other_offset].entry_in_column = 0; + } + } + + // set values extended to default id + if (2 * offset + 1 < nnz + batch_size) { + longlong2 l2 = make_longlong2(default_id, default_id); + *((longlong2*)(values_extended + 2 * offset)) = l2; + } else if (2 * offset < nnz + batch_size) { + values_extended[2 * offset] = default_id; + } +} + +void FusedMultiFunctional(const GPUDevice& d, const IndicePair* indices, + const int64_t* values, const int64_t nnz, + const int64_t batch_size, const bool prune, + const int64_t default_id, int* row_emptiness_flag, + int* invalid_id_flag, IndicePair* tmp_indices_buffer, + int64_t* values_extended) { + const int threads = LINER_MAPPING_THREADS; + const int blocks = CalcBlocksLinearMapping(nnz + batch_size, threads); + TF_CHECK_OK(GpuLaunchKernel( + FusedMultiFunctionalKernel, blocks, threads, 0, d.stream(), indices, + values, nnz, batch_size, prune, default_id, row_emptiness_flag, + invalid_id_flag, tmp_indices_buffer, values_extended)); +} + +template +__global__ void InitFillEmptyBuffersKernel( + int64_t batch_size, int64_t nnz, int64_t default_id, float default_weight, + const int64_t* sp_values, const int64_t* sp_indices, int64_t* sp_values_out, + int64_t* sp_indices_out, float* sp_weights_values_out, bool* is_row_empty, + int64_t* tmp_indices) { + const int global_tid = threadIdx.x + blockDim.x * blockIdx.x; + if (global_tid < batch_size) { + // init is_row_empty + is_row_empty[global_tid] = true; + } else if (global_tid < 3 * batch_size) { + // init tmp indices + const int new_global_tid = global_tid - batch_size; + // even tid keep batch_id, odd tid keep 0 + const int64_t data = ((new_global_tid + 1) % 2) * (new_global_tid / 2); + tmp_indices[new_global_tid] = data; + } + + if (global_tid < (batch_size + nnz)) { + sp_values_out[global_tid] = default_id; + } + + // using template here to let compiler decide whether to optimize this section + // out + if (use_sparse_weights) { + if (global_tid < (batch_size + nnz)) { + sp_weights_values_out[global_tid] = default_weight; + } + } + + // using template here to let compiler decide whether to optimize this section + // out + if (!prune) { + if (global_tid < nnz) { + sp_values_out[global_tid] = sp_values[global_tid]; + } + + if (global_tid < 2 * nnz) { + sp_indices_out[global_tid] = sp_indices[global_tid]; + } + } +} + +void InitFillEmptyBuffers(const GPUDevice& d, const int64_t batch_size, + const int64_t nnz, const int64_t default_id, + const float default_weight, const bool prune, + const bool use_sparse_weights, + const int64_t* sp_values, const int64_t* sp_indices, + int64_t* sp_values_out, int64_t* sp_indices_out, + float* sp_weights_values_out, bool* is_row_empty, + 
int64_t* tmp_indices) { + const int threads = 32; + const int blocks = CalcBlocksLinearMapping( + std::max(std::max(3 * batch_size, batch_size + nnz), 2 * nnz), threads); + +#define LAUNCH_KERNEL(prune, use_sparse_weights) \ + TF_CHECK_OK(GpuLaunchKernel( \ + InitFillEmptyBuffersKernel, blocks, threads, \ + 0, d.stream(), batch_size, nnz, default_id, default_weight, sp_values, \ + sp_indices, sp_values_out, sp_indices_out, sp_weights_values_out, \ + is_row_empty, tmp_indices)); + + if (prune && use_sparse_weights) { + LAUNCH_KERNEL(true, true); + } else if (prune && !use_sparse_weights) { + LAUNCH_KERNEL(true, false); + } else if (!prune && use_sparse_weights) { + LAUNCH_KERNEL(false, true); + } else if (!prune && !use_sparse_weights) { + LAUNCH_KERNEL(false, false); + } +#undef LAUNCH_KERNEL +} + +template +void __global__ DetectEmptyRowKernel(const int64_t* indices, + const int64_t* sp_values, + const float* sp_weights_values, + const int64_t nnz, bool* is_row_empty) { + const int global_tid = threadIdx.x + blockIdx.x * blockDim.x; + if (global_tid < nnz) { + const int64_t row_in_batch = indices[2 * global_tid]; + // use template for compiler to optimize + if (prune) { + if (prune_sparse_weights) { + if (sp_values[global_tid] >= 0 && sp_weights_values[global_tid] > 0.0) { + is_row_empty[row_in_batch] = false; + } + } else { + if (sp_values[global_tid] >= 0) { + is_row_empty[row_in_batch] = false; + } + } + } else { + is_row_empty[row_in_batch] = false; + } + } +} + +void DetectEmptyRow(const GPUDevice& d, const int64_t* indices, + const int64_t* sp_values, const float* sp_weights_values, + const bool prune, const bool prune_sparse_weights, + const int64_t nnz, bool* is_row_empty) { + const int threads = 32; + const int blocks = CalcBlocksLinearMapping(nnz, threads); + +#define LAUNCH_KERNEL(prune, prune_sparse_weights) \ + TF_CHECK_OK(GpuLaunchKernel( \ + DetectEmptyRowKernel, blocks, threads, 0, \ + d.stream(), indices, sp_values, sp_weights_values, nnz, is_row_empty)); + + if (prune) { + if (prune_sparse_weights) { + LAUNCH_KERNEL(true, true); + } else { + LAUNCH_KERNEL(true, false); + } + } else { + LAUNCH_KERNEL(false, false); + } +#undef LAUNCH_KERNEL +} + +template +__global__ void RangeInitKernel(const int64_t length, T* out) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < length) { + out[idx] = T(idx); + } +} + +template +void RangeInit(const GPUDevice& d, const int64_t length, T* out) { + const int threads = LINER_MAPPING_THREADS; + const int blocks = CalcBlocksLinearMapping(length, threads); + TF_CHECK_OK(GpuLaunchKernel(RangeInitKernel, blocks, threads, 0, + d.stream(), length, out)); +} + +template void RangeInit(const GPUDevice& d, const int64_t length, + int64_t* out); + +__global__ void SumUpEmbeddingShardSinglePartitionKernel( + const float* emb_shard, const int64_t* indices_before_unique, + const int* unique_idxs, const float* sp_weights_values, + const bool use_sparse_weights, const int nnz, const float max_norm, + const int emb_vec_size, float* emb_vectors, int* feature_nums) { + __shared__ float l2_sum[1]; + + if (blockIdx.x < nnz) { + const int64_t row_in_batch = indices_before_unique[2 * blockIdx.x]; + const int unique_id = unique_idxs[blockIdx.x]; + float emb_element = emb_shard[unique_id * emb_vec_size + threadIdx.x]; + if (max_norm >= 0.0f) { + if (threadIdx.x == 0) { + l2_sum[0] = 0.0f; + } + __syncthreads(); + atomicAdd(l2_sum, emb_element * emb_element); + __syncthreads(); + float l2_norm = sqrtf(l2_sum[0]); + if (l2_norm > max_norm) { + 
emb_element *= max_norm / l2_norm; + } + } + + if (use_sparse_weights) { + atomicAdd(emb_vectors + row_in_batch * emb_vec_size + threadIdx.x, + emb_element * sp_weights_values[blockIdx.x]); + } else { + atomicAdd(emb_vectors + row_in_batch * emb_vec_size + threadIdx.x, + emb_element); + } + + if (threadIdx.x == 0) { + atomicAdd(feature_nums + row_in_batch, 1); + } + } +} + +void SumUpEmbeddingShardSinglePartition( + const GPUDevice& d, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const float max_norm, const int emb_vec_size, + float* emb_vectors, int* feature_nums) { + const int blocks = nnz; + const int threads = emb_vec_size; + TF_CHECK_OK( + GpuLaunchKernel(SumUpEmbeddingShardSinglePartitionKernel, blocks, threads, + 0, d.stream(), emb_shard, indices_before_unique, + unique_idxs, sp_weights_values, use_sparse_weights, nnz, + max_norm, emb_vec_size, emb_vectors, feature_nums)); +} + +__global__ void SumUpEmbeddingShardMultiPartitionKernel( + const void* const* emb_shard_ptrs, const int* partition_permutation, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const float max_norm, const int emb_vec_size, + float* emb_vectors, int* feature_nums) { + __shared__ float l2_sum[1]; + + if (blockIdx.x < nnz) { + const int64_t row_in_batch = indices_before_unique[2 * blockIdx.x]; + const int unique_id = unique_idxs[blockIdx.x]; + const int partition_id = partition_permutation[2 * unique_id]; + const int64_t offset_in_partition = + partition_permutation[2 * unique_id + 1]; + + float emb_element = + ((const float*)(emb_shard_ptrs[partition_id]))[offset_in_partition * + emb_vec_size + + threadIdx.x]; + if (max_norm >= 0.0f) { + if (threadIdx.x == 0) { + l2_sum[0] = 0.0f; + } + __syncthreads(); + atomicAdd(l2_sum, emb_element * emb_element); + __syncthreads(); + float l2_norm = sqrtf(l2_sum[0]); + if (l2_norm > max_norm) { + emb_element *= max_norm / l2_norm; + } + } + + if (use_sparse_weights) { + atomicAdd(emb_vectors + row_in_batch * emb_vec_size + threadIdx.x, + emb_element * sp_weights_values[blockIdx.x]); + } else { + atomicAdd(emb_vectors + row_in_batch * emb_vec_size + threadIdx.x, + emb_element); + } + + if (threadIdx.x == 0) { + atomicAdd(feature_nums + row_in_batch, 1); + } + } +} + +void SumUpEmbeddingShardMultiPartition( + const GPUDevice& d, const void* const* emb_shard_ptrs, + const int* partition_permutation, const int64_t* indices_before_unique, + const int* unique_idxs, const float* sp_weights_values, + const bool use_sparse_weights, const int nnz, const float max_norm, + const int emb_vec_size, float* emb_vectors, int* feature_nums) { + const int blocks = nnz; + const int threads = emb_vec_size; + TF_CHECK_OK(GpuLaunchKernel( + SumUpEmbeddingShardMultiPartitionKernel, blocks, threads, 0, d.stream(), + emb_shard_ptrs, partition_permutation, indices_before_unique, unique_idxs, + sp_weights_values, use_sparse_weights, nnz, max_norm, emb_vec_size, + emb_vectors, feature_nums)); +} + +template +__global__ void ApplyCombinerKernel(const bool* is_row_empty, + const bool set_empty_row_zero, + int* feature_nums, float* emb_vectors) { + const int offset = blockIdx.x * blockDim.x + threadIdx.x; + const int feature_num = feature_nums[blockIdx.x]; + if (set_empty_row_zero) { + if (is_row_empty[blockIdx.x]) { + feature_nums[blockIdx.x] = 0; + emb_vectors[offset] = 0.0f; + 
return; + } + } + const float emb_element = emb_vectors[offset]; + emb_vectors[offset] = Combine(emb_element, feature_num); +} + +template +void ApplyCombiner(const GPUDevice& d, const int batch_size, + const int emb_vec_size, const bool* is_row_empty, + const bool set_empty_row_zero, int* feature_nums, + float* emb_vectors) { + const int blocks = batch_size; + const int threads = emb_vec_size; + TF_CHECK_OK(GpuLaunchKernel(ApplyCombinerKernel, blocks, threads, 0, + d.stream(), is_row_empty, set_empty_row_zero, + feature_nums, emb_vectors)); +} + +template void ApplyCombiner(const GPUDevice& d, const int batch_size, + const int emb_vec_size, + const bool* is_row_empty, + const bool set_empty_row_zero, + int* feature_nums, float* emb_vectors); +template void ApplyCombiner(const GPUDevice& d, const int batch_size, + const int emb_vec_size, + const bool* is_row_empty, + const bool set_empty_row_zero, + int* feature_nums, float* emb_vectors); +template void ApplyCombiner(const GPUDevice& d, const int batch_size, + const int emb_vec_size, + const bool* is_row_empty, + const bool set_empty_row_zero, + int* feature_nums, float* emb_vectors); + +template +__global__ void DistributeGradToShardSinglePartitionKernel( + const float* top_grad, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, float* grad_shard) { + __shared__ float l2_sum[1]; + float l2_norm = -1.0f; + + if (blockIdx.x < nnz) { + const int64_t row_in_batch = indices_before_unique[2 * blockIdx.x]; + if (set_empty_row_zero && is_row_empty[row_in_batch]) { + return; + } + + const int unique_id = unique_idxs[blockIdx.x]; + + if (max_norm >= 0.0f) { + const float emb_element = + emb_shard[unique_id * emb_vec_size + threadIdx.x]; + if (threadIdx.x == 0) { + l2_sum[0] = 0.0f; + } + __syncthreads(); + atomicAdd(l2_sum, emb_element * emb_element); + __syncthreads(); + l2_norm = sqrtf(l2_sum[0]); + } + + float grad = top_grad[row_in_batch * emb_vec_size + threadIdx.x]; + const int feature_num = feature_nums[row_in_batch]; + grad = CombineGrad(grad, feature_num); + if (use_sparse_weights) { + grad = grad * sp_weights_values[blockIdx.x]; + } + if (max_norm >= 0.0f && l2_norm > max_norm) { + grad *= max_norm / l2_norm; + } + + atomicAdd(grad_shard + unique_id * emb_vec_size + threadIdx.x, grad); + } +} + +template +void DistributeGradToShardSinglePartition( + const GPUDevice& d, const float* top_grad, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, float* grad_shard) { + const int blocks = nnz; + const int threads = emb_vec_size; + TF_CHECK_OK(GpuLaunchKernel( + DistributeGradToShardSinglePartitionKernel, blocks, threads, 0, + d.stream(), top_grad, emb_shard, indices_before_unique, unique_idxs, + sp_weights_values, use_sparse_weights, nnz, emb_vec_size, max_norm, + set_empty_row_zero, feature_nums, is_row_empty, grad_shard)); +} + +template void DistributeGradToShardSinglePartition( + const GPUDevice& d, const float* top_grad, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const 
bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, float* grad_shard); + +template void DistributeGradToShardSinglePartition( + const GPUDevice& d, const float* top_grad, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, float* grad_shard); + +template void DistributeGradToShardSinglePartition( + const GPUDevice& d, const float* top_grad, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, float* grad_shard); + +template +__global__ void DistributeGradToShardMultiPartitionKernel( + const float* top_grad, const void* const* emb_shard_ptrs, + const int* partition_permutation, const int64_t* indices_before_unique, + const int* unique_idxs, const float* sp_weights_values, + const bool use_sparse_weights, const int nnz, const int emb_vec_size, + const float max_norm, const bool set_empty_row_zero, + const int* feature_nums, const bool* is_row_empty, void** grad_shard_ptrs) { + __shared__ float l2_sum[1]; + float l2_norm = -1.0f; + + if (blockIdx.x < nnz) { + const int64_t row_in_batch = indices_before_unique[2 * blockIdx.x]; + if (set_empty_row_zero && is_row_empty[row_in_batch]) { + return; + } + const int unique_id = unique_idxs[blockIdx.x]; + const int partition_id = partition_permutation[2 * unique_id]; + const int64_t offset_in_partition = + partition_permutation[2 * unique_id + 1]; + + if (max_norm >= 0.0f) { + float emb_element = + ((const float*)(emb_shard_ptrs[partition_id]))[offset_in_partition * + emb_vec_size + + threadIdx.x]; + if (threadIdx.x == 0) { + l2_sum[0] = 0.0f; + } + __syncthreads(); + atomicAdd(l2_sum, emb_element * emb_element); + __syncthreads(); + l2_norm = sqrtf(l2_sum[0]); + } + + float grad = top_grad[row_in_batch * emb_vec_size + threadIdx.x]; + const int feature_num = feature_nums[row_in_batch]; + grad = CombineGrad(grad, feature_num); + if (use_sparse_weights) { + grad = grad * sp_weights_values[blockIdx.x]; + } + if (max_norm >= 0.0f && l2_norm > max_norm) { + grad *= max_norm / l2_norm; + } + + atomicAdd(((float*)(grad_shard_ptrs[partition_id])) + + offset_in_partition * emb_vec_size + threadIdx.x, + grad); + } +} + +template +void DistributeGradToShardMultiPartition( + const GPUDevice& d, const float* top_grad, + const void* const* emb_shard_ptrs, const int* partition_permutation, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, void** grad_shard_ptrs) { + const int blocks = nnz; + const int threads = emb_vec_size; + TF_CHECK_OK(GpuLaunchKernel( + DistributeGradToShardMultiPartitionKernel, blocks, threads, 0, + d.stream(), top_grad, emb_shard_ptrs, partition_permutation, + indices_before_unique, unique_idxs, sp_weights_values, use_sparse_weights, + nnz, emb_vec_size, max_norm, set_empty_row_zero, feature_nums, + 
is_row_empty, grad_shard_ptrs)); +} + +template void DistributeGradToShardMultiPartition( + const GPUDevice& d, const float* top_grad, + const void* const* emb_shard_ptrs, const int* partition_permutation, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, void** grad_shard_ptrs); + +template void DistributeGradToShardMultiPartition( + const GPUDevice& d, const float* top_grad, + const void* const* emb_shard_ptrs, const int* partition_permutation, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, void** grad_shard_ptrs); + +template void DistributeGradToShardMultiPartition( + const GPUDevice& d, const float* top_grad, + const void* const* emb_shard_ptrs, const int* partition_permutation, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, void** grad_shard_ptrs); + +} // namespace fused_embedding + +} // namespace tensorflow + +#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h b/tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h new file mode 100644 index 00000000000..6b444562cbe --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h @@ -0,0 +1,136 @@ +#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_KERNELS_CU_H_ +#define TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_KERNELS_CU_H_ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/fused_embedding/gpu/common.cu.h" + +namespace tensorflow { + +namespace fused_embedding { + +using GPUDevice = Eigen::GpuDevice; + +enum Combiner { Mean, Sum, Sqrtn }; + +template +__forceinline__ __device__ float Combine(const float in, const int feature_num); + +template <> +__forceinline__ __device__ float Combine(const float in, + const int feature_num) { + return in / sqrtf(feature_num); +} + +template <> +__forceinline__ __device__ float Combine(const float in, + const int feature_num) { + return in / feature_num; +} + +template <> +__forceinline__ __device__ float Combine(const float in, + const int feature_num) { + return in; +} + +template +__forceinline__ __device__ float CombineGrad(const float grad, + const int feature_num); + +template <> +__forceinline__ __device__ float CombineGrad(const float grad, + const int feature_num) { + return grad / sqrtf(feature_num); +} + +template <> +__forceinline__ __device__ float CombineGrad(const float grad, + const int feature_num) { + return grad / feature_num; +} + +template <> +__forceinline__ __device__ float CombineGrad(const float grad, + const int feature_num) { + return grad; +} + +void InitFlagsToOneInt4(const GPUDevice& d, int length, int* flags); + +void DetectInvalid(const GPUDevice& d, const int64_t* values, const int64_t nnz, + int* invalid_id_flag); + +void FusedMultiFunctional(const GPUDevice& d, const 
IndicePair* indices, + const int64_t* values, const int64_t nnz, + const int64_t batch_size, const bool prune_invalid_id, + const int64_t default_id, int* row_emptiness_flag, + int* invalid_id_flag, IndicePair* tmp_indices_buffer, + int64_t* values_extended); + +void InitFillEmptyBuffers(const GPUDevice& d, const int64_t batch_size, + const int64_t nnz, const int64_t default_id, + const float default_weight, const bool prune, + const bool use_sparse_weights, + const int64_t* sp_values, const int64_t* sp_indices, + int64_t* sp_values_out, int64_t* sp_indices_out, + float* sp_weights_values_out, bool* is_row_empty, + int64_t* tmp_indices); + +void DetectEmptyRow(const GPUDevice& d, const int64_t* indices, + const int64_t* sp_values, const float* sp_weights_values, + const bool prune, const bool prune_sparse_weights, + const int64_t nnz, bool* is_row_empty); + +template +void RangeInit(const GPUDevice& d, const int64_t length, T* out); + +void SumUpEmbeddingShardSinglePartition( + const GPUDevice& d, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const float max_norm, const int emb_vec_size, + float* emb_vectors, int* feature_nums); + +void SumUpEmbeddingShardMultiPartition( + const GPUDevice& d, const void* const* emb_shard_ptrs, + const int* partition_permutation, const int64_t* indices_before_unique, + const int* unique_idxs, const float* sp_weights_values, + const bool use_sparse_weights, const int nnz, const float max_norm, + const int emb_vec_size, float* emb_vectors, int* feature_nums); + +template +void ApplyCombiner(const GPUDevice& d, const int batch_size, + const int emb_vec_size, const bool* is_row_empty, + const bool set_empty_row_zero, int* feature_nums, + float* emb_vectors); + +template +void DistributeGradToShardSinglePartition( + const GPUDevice& d, const float* top_grad, const float* emb_shard, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, float* grad_shard); + +template +void DistributeGradToShardMultiPartition( + const GPUDevice& d, const float* top_grad, + const void* const* emb_shard_ptrs, const int* partition_permutation, + const int64_t* indices_before_unique, const int* unique_idxs, + const float* sp_weights_values, const bool use_sparse_weights, + const int nnz, const int emb_vec_size, const float max_norm, + const bool set_empty_row_zero, const int* feature_nums, + const bool* is_row_empty, void** grad_shard_ptrs); + +} // namespace fused_embedding + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_KERNELS_CU_H_ \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.cc b/tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.cc new file mode 100644 index 00000000000..5131f3ddb28 --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.cc @@ -0,0 +1,364 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.h" + +#include +#include +#include + +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include 
"tensorflow/core/kernels/fused_embedding/gpu/common.cu.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace fused_embedding { + +// A macro helper to declare SelectScanKernel, because they just have only +// a little bit differences. Need to define macros before using this +#define DeclareSelectScanKernel \ + template \ + __global__ void SelectScanKernelName( \ + const T* keys, SelectScanKernelArgs const int64 num_partitions, \ + const int64 length, const int64 predicates_length, \ + const int64 counters_length, unsigned int* predicates, \ + unsigned int* accumulate_counters) { \ + __shared__ unsigned int warp_cnt_sum[1]; \ + int lnid = threadIdx.x % 32; \ + if (lnid == 0) warp_cnt_sum[0] = 0; /* init shared mem of sum */ \ + const int warp_iteration = WarpWorkload / 32; \ + const int select_scan_warps = length % WarpWorkload == 0 \ + ? (length / WarpWorkload) \ + : (length / WarpWorkload + 1); \ + const int partition_id = blockIdx.x / select_scan_warps; \ + int warp_id_in_partition = blockIdx.x % select_scan_warps; \ + unsigned int mask; \ + unsigned int cnt; \ + SelectScanKernelLoadCodeBlock; \ + _Pragma("unroll") for (int i = 0; i < warp_iteration; i++) { \ + int selected; \ + int load_offset = warp_id_in_partition * WarpWorkload + i * 32 + lnid; \ + if (load_offset < length) { \ + T key = keys[load_offset]; \ + SelectScanKernelEvalCodeBlock; \ + } else { \ + selected = 0; \ + } \ + mask = __ballot_sync(0xffffffff, selected); \ + if (lnid == 0) \ + predicates[partition_id * predicates_length + \ + warp_id_in_partition * (WarpWorkload / 32) + i] = mask; \ + cnt = __popc(mask); \ + if (lnid == 0) atomicAdd(warp_cnt_sum, cnt); \ + } \ + /* use different threads in warp to accumulate to different location of \ + * accumulate_counters, this will do a prefix sum on global mem */ \ + const int counters_slots_need_accumulate = \ + select_scan_warps - warp_id_in_partition; \ + const unsigned int warp_cnt = warp_cnt_sum[0]; \ + for (int i = 0; i < counters_slots_need_accumulate; i += 32) { \ + if (i + lnid < counters_slots_need_accumulate) { \ + atomicAdd(accumulate_counters + partition_id * counters_length + \ + warp_id_in_partition + lnid, \ + warp_cnt); \ + } \ + } \ + } + +// A macro helper to declare SelectKernel, because they just have only +// a little bit differences. Need to define macros before using this +#define DeclareSelectKernel \ + template \ + __global__ void SelectKernelName( \ + const T* keys, SelectKernelArgs unsigned int* predicates, \ + unsigned int* accumulate_counters, const int64 num_partitions, \ + const int64 length, const int64 predicates_length, \ + const int64 counters_length, void** output_ptrs, TIndex* permutation) { \ + const int lnid = threadIdx.x % 32; \ + const int warp_iteration = WarpWorkload / 32; \ + const int select_scan_warps = length % WarpWorkload == 0 \ + ? 
(length / WarpWorkload) \ + : (length / WarpWorkload + 1); \ + const int partition_id = blockIdx.x / select_scan_warps; \ + int warp_id_in_partition = blockIdx.x % select_scan_warps; \ + unsigned int predmask = 0; \ + unsigned int cnt = 0; \ + SelectKernelLoadCodeBlock; \ + T* keys_output_ptr = (T*)(output_ptrs[partition_id]); \ + if (lnid < warp_iteration) { \ + predmask = \ + predicates[partition_id * predicates_length + \ + (warp_id_in_partition * (WarpWorkload / 32)) + lnid]; \ + cnt = __popc(predmask); \ + } \ + _Pragma("unroll") for (int offset = 1; offset < warp_iteration; \ + offset <<= 1) { \ + /* prefix sum */ \ + int n = __shfl_up_sync(0xffffffff, cnt, offset); \ + if (lnid >= offset) cnt += n; \ + } \ + unsigned int global_index = 0; \ + if (warp_id_in_partition > 0) \ + global_index = accumulate_counters[partition_id * counters_length + \ + warp_id_in_partition - 1]; \ + _Pragma("unroll") for (int i = 0; i < warp_iteration; i++) { \ + unsigned int mask = __shfl_sync(0xffffffff, predmask, i); \ + unsigned int inner_warp_index = 0; \ + if (i > 0) inner_warp_index = __shfl_sync(0xffffffff, cnt, i - 1); \ + if (mask & (1 << lnid)) { \ + int load_offset = warp_id_in_partition * WarpWorkload + i * 32 + lnid; \ + T key = keys[load_offset]; /* Will not cause out of boundry access, \ + because mask will be 0 for this thread*/ \ + SelectKernelRecalcKeyCodeBlock; \ + int output_offset = \ + global_index + inner_warp_index + \ + __popc(mask & (((unsigned int)(1) << lnid) - (unsigned int)(1))); \ + keys_output_ptr[output_offset] = new_key; \ + permutation[2 * load_offset] = partition_id; \ + permutation[2 * load_offset + 1] = (TIndex)output_offset; \ + } \ + } \ + } + +// A macro helper to declare DefinePartitionSelect. Need to define +// macros before using this +#define DeclareSelect \ + template \ + void SelectName(OpKernelContext* ctx, const Tensor* keys, \ + SelectArgs const int64 num_partitions, \ + cudaEvent_t memcpy_event, OpOutputList& selected_keys, \ + Tensor* permutation) { \ + OP_REQUIRES(ctx, keys->dims() == 1, \ + errors::InvalidArgument("Tensor keys must ranks 1")); \ + OP_REQUIRES( \ + ctx, \ + WarpWorkload >= 32 && WarpWorkload <= 1024 && \ + (WarpWorkload && !(WarpWorkload & (WarpWorkload - 1))), \ + errors::InvalidArgument( \ + "WarpWorkload must be larger than warp size and less than 1024 " \ + "32 and is exponential of 2, 32, 64, 128, i.e.")); \ + const GPUDevice& device = ctx->eigen_gpu_device(); \ + const int64 length = keys->NumElements(); \ + const int64 warp_iteration = WarpWorkload / 32; \ + Tensor predicates; \ + Tensor accumulate_counters; \ + const int64 select_scan_warps = length % WarpWorkload == 0 \ + ? 
(length / WarpWorkload) \ + : (length / WarpWorkload + 1); \ + const int64 counters_length = select_scan_warps; \ + const int64 predicates_length = select_scan_warps * warp_iteration; \ + OP_REQUIRES_OK( \ + ctx, ctx->allocate_temp(DT_UINT32, \ + TensorShape{counters_length * num_partitions}, \ + &accumulate_counters)); \ + CK_CUDA_THROW_(cudaMemsetAsync( \ + data_p_with_type(accumulate_counters), 0, \ + accumulate_counters.NumElements() * sizeof(unsigned int), \ + device.stream())); \ + OP_REQUIRES_OK( \ + ctx, ctx->allocate_temp( \ + DT_UINT32, TensorShape{predicates_length * num_partitions}, \ + &predicates)); \ + \ + { \ + const int64 threads = 32; \ + const int64 blocks = select_scan_warps * num_partitions; \ + OP_REQUIRES_OK( \ + ctx, \ + GpuLaunchKernel( \ + SelectScanKernelName, blocks, threads, 0, \ + device.stream(), data_p_with_type(keys), \ + SelectScanPassArgs num_partitions, length, predicates_length, \ + counters_length, data_p_with_type(predicates), \ + data_p_with_type(accumulate_counters))); \ + } \ + std::vector selected_nums_host; \ + selected_nums_host.resize(num_partitions); \ + /* copy the last element(which is the sum of previous) with stride */ \ + CK_CUDA_THROW_(cudaMemcpy2DAsync( \ + selected_nums_host.data(), 1 * sizeof(unsigned int), \ + data_p_with_type(accumulate_counters) + \ + counters_length - 1, \ + counters_length * sizeof(unsigned int), 1 * sizeof(unsigned int), \ + num_partitions, cudaMemcpyDeviceToHost, device.stream())); \ + CK_CUDA_THROW_(cudaEventRecord(memcpy_event, device.stream())); \ + CK_CUDA_THROW_(cudaEventSynchronize(memcpy_event)); \ + \ + std::vector output_ptrs_host; \ + output_ptrs_host.resize(num_partitions); \ + for (int i = 0; i < num_partitions; i++) { \ + Tensor* tmp_out; \ + OP_REQUIRES_OK( \ + ctx, selected_keys.allocate( \ + i, TensorShape({int64(selected_nums_host[i])}), &tmp_out)); \ + output_ptrs_host[i] = data_p_with_type(tmp_out); \ + } \ + Tensor output_ptrs; \ + OP_REQUIRES_OK( \ + ctx, ctx->allocate_temp(DT_UINT64, TensorShape{int64(num_partitions)}, \ + &output_ptrs)); \ + CK_CUDA_THROW_(cudaMemcpyAsync(data_p_with_type(output_ptrs), \ + output_ptrs_host.data(), \ + num_partitions * sizeof(size_t), \ + cudaMemcpyHostToDevice, device.stream())); \ + { \ + const int64 threads = 32; \ + const int64 blocks = select_scan_warps * num_partitions; \ + OP_REQUIRES_OK( \ + ctx, GpuLaunchKernel( \ + SelectKernelName, blocks, threads, \ + 0, device.stream(), data_p_with_type(keys), \ + SelectPassArgs data_p_with_type(predicates), \ + data_p_with_type(accumulate_counters), \ + num_partitions, length, predicates_length, counters_length, \ + data_p_with_type(output_ptrs), \ + data_p_with_type(permutation))); \ + } \ + } + +// =============== Div Selection =============== // +#define SelectName PartitionSelectDiv +#define SelectArgs const Tensor &accu_div, +#define SelectScanPassArgs data_p_with_type(accu_div), +#define SelectPassArgs SelectScanPassArgs + +#define SelectScanKernelName SelectScanDivKernel +#define SelectScanKernelArgs const int64 *accu_div, +#define SelectScanKernelLoadCodeBlock \ + int64 lower_bound = partition_id > 0 ? accu_div[partition_id - 1] : 0; \ + int64 upper_bound = accu_div[partition_id]; +#define SelectScanKernelEvalCodeBlock \ + selected = int(key >= lower_bound && key < upper_bound); + +#define SelectKernelName SelectDivKernel +#define SelectKernelArgs SelectScanKernelArgs + +#define SelectKernelLoadCodeBlock \ + int64 lower_bound = partition_id > 0 ? 
accu_div[partition_id - 1] : 0; +#define SelectKernelRecalcKeyCodeBlock T new_key = key - lower_bound; + +DeclareSelectScanKernel; +DeclareSelectKernel; +DeclareSelect; + +template void PartitionSelectDiv( + OpKernelContext* ctx, const Tensor* keys, const Tensor& accu_div, + const int64 num_partitions, cudaEvent_t memcpy_event, + OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectDiv( + OpKernelContext* ctx, const Tensor* keys, const Tensor& accu_div, + const int64 num_partitions, cudaEvent_t memcpy_event, + OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectDiv( + OpKernelContext* ctx, const Tensor* keys, const Tensor& accu_div, + const int64 num_partitions, cudaEvent_t memcpy_event, + OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectDiv( + OpKernelContext* ctx, const Tensor* keys, const Tensor& accu_div, + const int64 num_partitions, cudaEvent_t memcpy_event, + OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectDiv( + OpKernelContext* ctx, const Tensor* keys, const Tensor& accu_div, + const int64 num_partitions, cudaEvent_t memcpy_event, + OpOutputList& selected_keys, Tensor* permutation); + +// =============== Mod Selection =============== // + +#define SelectName PartitionSelectMod +#define SelectArgs +#define SelectScanPassArgs +#define SelectPassArgs SelectScanPassArgs + +#define SelectScanKernelName SelectScanModKernel +#define SelectScanKernelArgs +#define SelectScanKernelLoadCodeBlock +#define SelectScanKernelEvalCodeBlock \ + selected = int(key % num_partitions == partition_id); + +#define SelectKernelName SelectModKernel +#define SelectKernelArgs SelectScanKernelArgs + +#define SelectKernelLoadCodeBlock +#define SelectKernelRecalcKeyCodeBlock T new_key = key / num_partitions; + +DeclareSelectScanKernel; +DeclareSelectKernel; +DeclareSelect; + +template void PartitionSelectMod( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectMod( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectMod( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectMod( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectMod( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +// =============== Mod EV Selection =============== // + +#define SelectName PartitionSelectModEV +#define SelectArgs +#define SelectScanPassArgs +#define SelectPassArgs SelectScanPassArgs + +#define SelectScanKernelName SelectScanModEVKernel +#define SelectScanKernelArgs +#define SelectScanKernelLoadCodeBlock +#define SelectScanKernelEvalCodeBlock \ + selected = int(key % 1000 % num_partitions == partition_id); + +#define SelectKernelName SelectModEVKernel +#define SelectKernelArgs SelectScanKernelArgs + +#define SelectKernelLoadCodeBlock +#define SelectKernelRecalcKeyCodeBlock T new_key = key; + +DeclareSelectScanKernel; +DeclareSelectKernel; +DeclareSelect; + +template void 
PartitionSelectModEV( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectModEV( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectModEV( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectModEV( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +template void PartitionSelectModEV( + OpKernelContext* ctx, const Tensor* keys, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, Tensor* permutation); + +} // namespace fused_embedding + +} // namespace tensorflow + +#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.h b/tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.h new file mode 100644 index 00000000000..0f67a3fd15b --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.h @@ -0,0 +1,35 @@ +#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_PARTITION_SELECT_CU_H_ +#define TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_PARTITION_SELECT_CU_H_ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace fused_embedding { +template +void PartitionSelectDiv(OpKernelContext* ctx, const Tensor* keys, + const Tensor& accu_div, const int64 num_partitions, + cudaEvent_t memcpy_event, OpOutputList& selected_keys, + Tensor* permutation); + +template +void PartitionSelectMod(OpKernelContext* ctx, const Tensor* keys, + const int64 num_partitions, cudaEvent_t memcpy_event, + OpOutputList& selected_keys, Tensor* permutation); + +template +void PartitionSelectModEV(OpKernelContext* ctx, const Tensor* keys, + const int64 num_partitions, cudaEvent_t memcpy_event, + OpOutputList& selected_keys, Tensor* permutation); + +} // namespace fused_embedding + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_EMBEDDING_GPU_FUNCTIONS_PARTITION_SELECT_CU_H_ \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/fused_embedding_post_v2_ops_gpus.cu.cc b/tensorflow/core/kernels/fused_embedding/gpu/fused_embedding_post_v2_ops_gpus.cu.cc new file mode 100644 index 00000000000..66ac3ff5964 --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/fused_embedding_post_v2_ops_gpus.cu.cc @@ -0,0 +1,336 @@ +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "cub/thread/thread_operators.cuh" +#include "tensorflow/core/kernels/fused_embedding/gpu/common.cu.h" +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h" +#include "tensorflow/core/profiler/nvtx_utils.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { +using GPUDevice = Eigen::GpuDevice; + +class FusedEmbeddingSparsePostLookUpV2GPU : public OpKernel { + public: + explicit FusedEmbeddingSparsePostLookUpV2GPU(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, 
ctx->GetAttr("num_partitions", &num_partitions_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_axis", &partition_axis_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner", &combiner_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_norm", &max_norm_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_empty_row", &fill_empty_row_)); + int temp_default_id; + OP_REQUIRES_OK(ctx, ctx->GetAttr("default_id", &temp_default_id)); + default_id_ = int64_t(temp_default_id); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("use_sparse_weights", &use_sparse_weights_)); + } + + void Compute(OpKernelContext* ctx) override { + nvtx::ScopedRangeIfEnabled nvtx_range(this); + + using namespace fused_embedding; + auto device = ctx->eigen_device(); + + OpInputList emb_shards; + OP_REQUIRES_OK(ctx, ctx->input_list("emb_shards", &emb_shards)); + + Tensor const* partition_permutation = nullptr; + OP_REQUIRES_OK(ctx, + ctx->input("partition_permutation", &partition_permutation)); + + Tensor const* dense_shape_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("sp_dense_shape", &dense_shape_tensor)); + + Tensor const* is_row_empty = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("is_row_empty", &is_row_empty)); + + Tensor const* indices_before_unique = nullptr; + OP_REQUIRES_OK(ctx, + ctx->input("indices_before_unique", &indices_before_unique)); + + Tensor const* unique_idxs = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("unique_idxs", &unique_idxs)); + + Tensor const* sp_weights_values = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("sp_weights_values", &sp_weights_values)); + + const int emb_vec_size = emb_shards[0].shape().dim_size(1); + const int batch_size = dense_shape_tensor->flat().data()[0]; + const int nnz = indices_before_unique->shape().dim_size(0); + const bool set_empty_row_zero = default_id_ < 0 && fill_empty_row_; + + // = 1. 
sum up emb values from emb_shards and dump into output = // + Tensor* emb_vectors_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("emb_vectors", + TensorShape({int64(batch_size), + int64(emb_vec_size)}), + &emb_vectors_tensor)); + CK_CUDA_THROW_(cudaMemsetAsync( + emb_vectors_tensor->flat().data(), 0x0, + sizeof(float) * emb_vectors_tensor->NumElements(), device.stream())); + + Tensor* feature_nums; + OP_REQUIRES_OK(ctx, ctx->allocate_output("feature_nums", + TensorShape({int64(batch_size)}), + &feature_nums)); + CK_CUDA_THROW_(cudaMemsetAsync(feature_nums->flat().data(), 0x0, + sizeof(int) * feature_nums->NumElements(), + device.stream())); + + Tensor* emb_shard_ptrs; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("emb_shard_ptrs", + TensorShape({int64(num_partitions_)}), + &emb_shard_ptrs)); + + if (num_partitions_ == 1) { + SumUpEmbeddingShardSinglePartition( + device, data_p_with_type(emb_shards[0]), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), use_sparse_weights_, nnz, + max_norm_, emb_vec_size, data_p_with_type(emb_vectors_tensor), + data_p_with_type(feature_nums)); + } else { + std::vector emb_shard_ptrs_host; + emb_shard_ptrs_host.resize(num_partitions_); + for (int i = 0; i < num_partitions_; i++) { + emb_shard_ptrs_host[i] = data_p_with_type(emb_shards[i]); + } + + CK_CUDA_THROW_(cudaMemcpyAsync(data_p_with_type(emb_shard_ptrs), + emb_shard_ptrs_host.data(), + num_partitions_ * sizeof(size_t), + cudaMemcpyHostToDevice, device.stream())); + + SumUpEmbeddingShardMultiPartition( + device, data_p_with_type(emb_shard_ptrs), + data_p_with_type(partition_permutation), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), use_sparse_weights_, nnz, + max_norm_, emb_vec_size, data_p_with_type(emb_vectors_tensor), + data_p_with_type(feature_nums)); + } + + // ================================================================ // + + // ========================= 2. 
combiner ========================== // + if (combiner_ == "sqrtn") { + ApplyCombiner(device, batch_size, emb_vec_size, + data_p_with_type(is_row_empty), + set_empty_row_zero, + data_p_with_type(feature_nums), + data_p_with_type(emb_vectors_tensor)); + } else if (combiner_ == "mean") { + ApplyCombiner(device, batch_size, emb_vec_size, + data_p_with_type(is_row_empty), + set_empty_row_zero, + data_p_with_type(feature_nums), + data_p_with_type(emb_vectors_tensor)); + } else { + ApplyCombiner(device, batch_size, emb_vec_size, + data_p_with_type(is_row_empty), + set_empty_row_zero, + data_p_with_type(feature_nums), + data_p_with_type(emb_vectors_tensor)); + } + // ================================================================ // + } + + private: + int num_partitions_; + int partition_axis_; + std::string combiner_; + float max_norm_; + bool fill_empty_row_; + int64_t default_id_; + bool use_sparse_weights_; +}; + +REGISTER_KERNEL_BUILDER(Name("FusedEmbeddingSparsePostLookUpV2") + .Device(DEVICE_GPU) + .HostMemory("sp_dense_shape"), + FusedEmbeddingSparsePostLookUpV2GPU); + +class FusedEmbeddingSparsePostLookUpV2GradGPU : public OpKernel { + public: + explicit FusedEmbeddingSparsePostLookUpV2GradGPU(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_partitions", &num_partitions_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_axis", &partition_axis_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("combiner", &combiner_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_norm", &max_norm_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_empty_row", &fill_empty_row_)); + int temp_default_id; + OP_REQUIRES_OK(ctx, ctx->GetAttr("default_id", &temp_default_id)); + default_id_ = int64_t(temp_default_id); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("use_sparse_weights", &use_sparse_weights_)); + } + + void Compute(OpKernelContext* ctx) override { + nvtx::ScopedRangeIfEnabled nvtx_range(this); + + using namespace fused_embedding; + auto device = ctx->eigen_device(); + + Tensor const* top_grad_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("top_grad", &top_grad_tensor)); + + OpInputList emb_shards; + OP_REQUIRES_OK(ctx, ctx->input_list("emb_shards", &emb_shards)); + + Tensor const* emb_shard_ptrs = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("emb_shard_ptrs", &emb_shard_ptrs)); + + Tensor const* partition_permutation = nullptr; + OP_REQUIRES_OK(ctx, + ctx->input("partition_permutation", &partition_permutation)); + + Tensor const* feature_nums = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("feature_nums", &feature_nums)); + + Tensor const* indices_before_unique = nullptr; + OP_REQUIRES_OK(ctx, + ctx->input("indices_before_unique", &indices_before_unique)); + + Tensor const* unique_idxs = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("unique_idxs", &unique_idxs)); + + Tensor const* is_row_empty = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("is_row_empty", &is_row_empty)); + + Tensor const* sp_weights_values = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("sp_weights_values", &sp_weights_values)); + + OpOutputList grad_shards; + OP_REQUIRES_OK(ctx, ctx->output_list("grad_shards", &grad_shards)); + + const int batch_size = top_grad_tensor->shape().dim_size(0); + const int emb_vec_size = emb_shards[0].shape().dim_size(1); + const int nnz = indices_before_unique->shape().dim_size(0); + const bool set_empty_row_zero = default_id_ < 0 && fill_empty_row_; + + std::vector grad_shard_ptrs_host; + grad_shard_ptrs_host.resize(num_partitions_); + for (int i = 0; i < num_partitions_; i++) { + Tensor* grad_out; + 
grad_shards.allocate(i, emb_shards[i].shape(), &grad_out); + grad_shard_ptrs_host[i] = data_p_with_type(grad_out); + CK_CUDA_THROW_(cudaMemsetAsync(data_p_with_type(grad_out), 0x0, + sizeof(float) * grad_out->NumElements(), + device.stream())); + } + + if (num_partitions_ == 1) { + if (combiner_ == "mean") { + DistributeGradToShardSinglePartition( + device, data_p_with_type(top_grad_tensor), + data_p_with_type(emb_shards[0]), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), + use_sparse_weights_, nnz, emb_vec_size, max_norm_, + set_empty_row_zero, data_p_with_type(feature_nums), + data_p_with_type(is_row_empty), + data_p_with_type(grad_shards[0])); + } else if (combiner_ == "sqrt") { + DistributeGradToShardSinglePartition( + device, data_p_with_type(top_grad_tensor), + data_p_with_type(emb_shards[0]), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), + use_sparse_weights_, nnz, emb_vec_size, max_norm_, + set_empty_row_zero, data_p_with_type(feature_nums), + data_p_with_type(is_row_empty), + data_p_with_type(grad_shards[0])); + } else { + DistributeGradToShardSinglePartition( + device, data_p_with_type(top_grad_tensor), + data_p_with_type(emb_shards[0]), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), + use_sparse_weights_, nnz, emb_vec_size, max_norm_, + set_empty_row_zero, data_p_with_type(feature_nums), + data_p_with_type(is_row_empty), + data_p_with_type(grad_shards[0])); + } + + } else { + Tensor grad_shard_ptrs; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DT_UINT64, TensorShape({int64(num_partitions_)}), + &grad_shard_ptrs)); + CK_CUDA_THROW_(cudaMemcpyAsync(data_p_with_type(grad_shard_ptrs), + grad_shard_ptrs_host.data(), + num_partitions_ * sizeof(size_t), + cudaMemcpyHostToDevice, device.stream())); + + if (combiner_ == "mean") { + DistributeGradToShardMultiPartition( + device, data_p_with_type(top_grad_tensor), + data_p_with_type(emb_shard_ptrs), + data_p_with_type(partition_permutation), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), + use_sparse_weights_, nnz, emb_vec_size, max_norm_, + set_empty_row_zero, data_p_with_type(feature_nums), + data_p_with_type(is_row_empty), + data_p_with_type(grad_shard_ptrs)); + } else if (combiner_ == "sqrt") { + DistributeGradToShardMultiPartition( + device, data_p_with_type(top_grad_tensor), + data_p_with_type(emb_shard_ptrs), + data_p_with_type(partition_permutation), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), + use_sparse_weights_, nnz, emb_vec_size, max_norm_, + set_empty_row_zero, data_p_with_type(feature_nums), + data_p_with_type(is_row_empty), + data_p_with_type(grad_shard_ptrs)); + } else { + DistributeGradToShardMultiPartition( + device, data_p_with_type(top_grad_tensor), + data_p_with_type(emb_shard_ptrs), + data_p_with_type(partition_permutation), + data_p_with_type(indices_before_unique), + data_p_with_type(unique_idxs), + data_p_with_type(sp_weights_values), + use_sparse_weights_, nnz, emb_vec_size, max_norm_, + set_empty_row_zero, data_p_with_type(feature_nums), + data_p_with_type(is_row_empty), + data_p_with_type(grad_shard_ptrs)); + } + } + } + + private: + int num_partitions_; + int partition_axis_; + std::string combiner_; + float max_norm_; + bool fill_empty_row_; + int64_t default_id_; + bool 
use_sparse_weights_; +}; + +REGISTER_KERNEL_BUILDER( + Name("FusedEmbeddingSparsePostLookUpV2Grad").Device(DEVICE_GPU), + FusedEmbeddingSparsePostLookUpV2GradGPU); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/partition_with_permutation_ops.cu.cc b/tensorflow/core/kernels/fused_embedding/gpu/partition_with_permutation_ops.cu.cc new file mode 100644 index 00000000000..afd0edea540 --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/partition_with_permutation_ops.cu.cc @@ -0,0 +1,161 @@ +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/fused_embedding/gpu/common.cu.h" +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h" +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/partition_select.cu.h" +#include "tensorflow/core/profiler/nvtx_utils.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "cub/device/device_radix_sort.cuh" +#include "cub/device/device_select.cuh" +#include "cub/iterator/constant_input_iterator.cuh" +#include "cub/thread/thread_operators.cuh" + +namespace tensorflow { +using GPUDevice = Eigen::GpuDevice; + +class PartitionWithPermutationGPU : public OpKernel { + public: + explicit PartitionWithPermutationGPU(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_partitions", &num_partitions_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_axis", &partition_axis_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("partition_strategy", &partition_strategy_)); + cudaEventCreateWithFlags(&memcpy_event_, cudaEventDisableTiming); + } + + void Compute(OpKernelContext* ctx) override { + using namespace fused_embedding; + auto device = ctx->eigen_device(); + + nvtx::ScopedRangeIfEnabled nvtx_range(this); + + OpOutputList partitioned_values; + OP_REQUIRES_OK(ctx, + ctx->output_list("partitioned_values", &partitioned_values)); + + Tensor const* input = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("input", &input)); + const int64 input_size = input->NumElements(); + + Tensor* partition_permutation; + OP_REQUIRES_OK(ctx, ctx->allocate_output("partition_permutation", + TensorShape{input_size, 2}, + &partition_permutation)); + + if (partition_strategy_ == "div") { + OpInputList partition_shapes; + OP_REQUIRES_OK(ctx, + ctx->input_list("partition_shapes", &partition_shapes)); + + std::vector accu_div_host; + accu_div_host.resize(num_partitions_); + for (int i = 0; i < partition_shapes.size(); i++) { + OP_REQUIRES(ctx, partition_shapes[i].NumElements() == 2, + errors::InvalidArgument( + "input partition_shapes must all have 2 elements")); + const int64_t div = partition_shapes[i].flat().data()[0]; + accu_div_host[i] = i == 0 ? 
div : accu_div_host[i - 1] + div; + } + + Tensor accu_div; + OP_REQUIRES_OK( + ctx, + ctx->allocate_temp( + DT_INT64, TensorShape({static_cast(num_partitions_)}), + &accu_div)); + CK_CUDA_THROW_(cudaMemcpyAsync(data_p_with_type(accu_div), + accu_div_host.data(), + num_partitions_ * sizeof(int64_t), + cudaMemcpyHostToDevice, device.stream())); + + if (input_size < 512) { + PartitionSelectDiv( + ctx, input, accu_div, num_partitions_, memcpy_event_, + partitioned_values, partition_permutation); + } else if (input_size < 1024) { + PartitionSelectDiv( + ctx, input, accu_div, num_partitions_, memcpy_event_, + partitioned_values, partition_permutation); + } else if (input_size < 2048) { + PartitionSelectDiv( + ctx, input, accu_div, num_partitions_, memcpy_event_, + partitioned_values, partition_permutation); + } else if (input_size < 4096) { + PartitionSelectDiv( + ctx, input, accu_div, num_partitions_, memcpy_event_, + partitioned_values, partition_permutation); + } else { + PartitionSelectDiv( + ctx, input, accu_div, num_partitions_, memcpy_event_, + partitioned_values, partition_permutation); + } + } else if (partition_strategy_ == "mod") { + if (input_size < 512) { + PartitionSelectMod(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else if (input_size < 1024) { + PartitionSelectMod(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else if (input_size < 2048) { + PartitionSelectMod(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else if (input_size < 4096) { + PartitionSelectMod(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else { + PartitionSelectMod(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } + } else if (partition_strategy_ == "mod_ev") { + if (input_size < 512) { + PartitionSelectModEV(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else if (input_size < 1024) { + PartitionSelectModEV(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else if (input_size < 2048) { + PartitionSelectModEV(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else if (input_size < 4096) { + PartitionSelectModEV(ctx, input, num_partitions_, + memcpy_event_, partitioned_values, + partition_permutation); + } else { + PartitionSelectModEV( + ctx, input, num_partitions_, memcpy_event_, partitioned_values, + partition_permutation); + } + } + } + + private: + int num_partitions_; + int partition_axis_; + std::string partition_strategy_; + cudaEvent_t memcpy_event_; +}; + +REGISTER_KERNEL_BUILDER(Name("PartitionWithPermutation") + .Device(DEVICE_GPU) + .HostMemory("partition_shapes"), + PartitionWithPermutationGPU); +} // namespace tensorflow + +#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/prune_invalid_and_fill_empty_rows_ops.cu.cc b/tensorflow/core/kernels/fused_embedding/gpu/prune_invalid_and_fill_empty_rows_ops.cu.cc new file mode 100644 index 00000000000..9bb334f369b --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/prune_invalid_and_fill_empty_rows_ops.cu.cc @@ -0,0 +1,288 @@ +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include +#include + +#include 
"cub/device/device_radix_sort.cuh" +#include "cub/device/device_select.cuh" +#include "cub/iterator/constant_input_iterator.cuh" +#include "cub/thread/thread_operators.cuh" +#include "tensorflow/core/kernels/fused_embedding/gpu/common.cu.h" +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h" +#include "tensorflow/core/profiler/nvtx_utils.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { +using GPUDevice = Eigen::GpuDevice; + +class PruneInvalidAndFillEmptyRowsGPU : public OpKernel { + public: + struct PruneInvalidSelectOp { + template + __host__ __device__ __forceinline__ bool operator()( + ThurstTupleT const& tuple) const { + return thrust::get<0>(tuple) >= 0; + } + }; + + struct PruneInvalidWithWeightSelectOp { + template + __host__ __device__ __forceinline__ bool operator()( + ThurstTupleT const& tuple) const { + return thrust::get<0>(tuple) >= 0 && thrust::get<2>(tuple) > 0; + } + }; + + explicit PruneInvalidAndFillEmptyRowsGPU(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_empty_row", &fill_empty_row_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("prune", &prune_)); + int temp_default_id; + OP_REQUIRES_OK(ctx, ctx->GetAttr("default_id", &temp_default_id)); + default_id_ = int64(temp_default_id); + + OP_REQUIRES_OK(ctx, + ctx->GetAttr("use_sparse_weights", &use_sparse_weights_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("prune_sparse_weights", &prune_sparse_weights_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("default_weight", &default_weight_)); + + cudaEventCreateWithFlags(&memcpy_event_, cudaEventDisableTiming); + } + + void Compute(OpKernelContext* ctx) override { + using namespace fused_embedding; + auto device = ctx->eigen_device(); + + const int64 default_id = default_id_ >= 0 ? default_id_ : 0; + + nvtx::ScopedRangeIfEnabled nvtx_range(this); + + // 1. bind & set inputs, vars, outputs and Init buffers. 
+ Tensor const* sp_values = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("sp_values", &sp_values)); + const int64 nnz = sp_values->shape().dim_size(0); + + Tensor const* sp_indices = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("sp_indices", &sp_indices)); + + Tensor const* sp_dense_shape = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("sp_dense_shape", &sp_dense_shape)); + const int64 batch_size = sp_dense_shape->flat().data()[0]; + + Tensor const* sp_weights_values = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("sp_weights_values", &sp_weights_values)); + + if (!prune_ && !fill_empty_row_) { + ctx->set_output("sp_values_out", *sp_values); + ctx->set_output("sp_indices_out", *sp_indices); + ctx->set_output("sp_weights_values_out", *sp_weights_values); + return; + } + + Tensor sp_values_out; + Tensor sp_indices_out; + Tensor sp_weights_values_out; + + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_INT64, TensorShape{nnz + batch_size}, + &sp_values_out)); + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DT_INT64, TensorShape{nnz + batch_size, 2}, + &sp_indices_out)); + + if (use_sparse_weights_) { + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_FLOAT, TensorShape{nnz + batch_size}, + &sp_weights_values_out)); + } + + Tensor tmp_indices; + Tensor* is_row_empty; + Tensor selected_num_d; + + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT64, TensorShape{batch_size, 2}, + &tmp_indices)); + + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DT_INT32, TensorShape{1}, &selected_num_d)); + + if (fill_empty_row_) { + OP_REQUIRES_OK( + ctx, ctx->allocate_output("is_row_empty", TensorShape{batch_size}, + &is_row_empty)); + + InitFillEmptyBuffers(device, batch_size, nnz, default_id, default_weight_, + prune_, use_sparse_weights_, + data_p_with_type(sp_values), + data_p_with_type(sp_indices), + data_p_with_type(sp_values_out), + data_p_with_type(sp_indices_out), + data_p_with_type(sp_weights_values_out), + data_p_with_type(is_row_empty), + data_p_with_type(tmp_indices)); + + DetectEmptyRow(device, data_p_with_type(sp_indices), + data_p_with_type(sp_values), + data_p_with_type(sp_weights_values), + prune_, prune_sparse_weights_, nnz, + data_p_with_type(is_row_empty)); + + } else { + OP_REQUIRES_OK(ctx, ctx->allocate_output("is_row_empty", TensorShape{1}, + &is_row_empty)); + } + + // 2. 
Allocate cub tmp + + // nnz = number of non zero + int new_nnz = nnz; + Tensor cub_temp_storage; + size_t max_cub_bytes = 0; + size_t temp_storage_bytes = 0; + auto triple_input_iter = thrust::make_zip_iterator( + thrust::make_tuple(data_p_with_type(sp_values), + data_p_with_type(sp_indices), + data_p_with_type(sp_weights_values))); + + auto double_input_iter = thrust::make_zip_iterator( + thrust::make_tuple(data_p_with_type(sp_values), + data_p_with_type(sp_indices))); + + auto triple_output_iter = thrust::make_zip_iterator( + thrust::make_tuple(data_p_with_type(sp_values_out), + data_p_with_type(sp_indices_out), + data_p_with_type(sp_weights_values_out))); + + auto double_output_iter = thrust::make_zip_iterator( + thrust::make_tuple(data_p_with_type(sp_values_out), + data_p_with_type(sp_indices_out))); + + auto with_weight_select_op = PruneInvalidWithWeightSelectOp(); + auto select_op = PruneInvalidSelectOp(); + + if (prune_) { + if (use_sparse_weights_) { + if (prune_sparse_weights_) { + cub::DeviceSelect::If(nullptr, temp_storage_bytes, triple_input_iter, + triple_output_iter, (int*)nullptr, int(nnz), + with_weight_select_op, device.stream()); + } else { + cub::DeviceSelect::If(nullptr, temp_storage_bytes, triple_input_iter, + triple_output_iter, (int*)nullptr, int(nnz), + select_op, device.stream()); + } + } else { + cub::DeviceSelect::If(nullptr, temp_storage_bytes, double_input_iter, + double_output_iter, (int*)nullptr, int(nnz), + select_op, device.stream()); + } + max_cub_bytes = temp_storage_bytes > max_cub_bytes ? temp_storage_bytes + : max_cub_bytes; + } + + if (fill_empty_row_) { + cub::DeviceSelect::Flagged((void*)nullptr, temp_storage_bytes, + (IndicePair*)nullptr, (int*)nullptr, + (IndicePair*)nullptr, (int*)nullptr, + batch_size, device.stream()); + max_cub_bytes = temp_storage_bytes > max_cub_bytes ? temp_storage_bytes + : max_cub_bytes; + } + + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(max_cub_bytes)}), + &cub_temp_storage)); + + // 3. 
select valid id & empty row indices + if (prune_) { + if (use_sparse_weights_) { + if (prune_sparse_weights_) { + cub::DeviceSelect::If(data_p_with_type(cub_temp_storage), + max_cub_bytes, triple_input_iter, + triple_output_iter, + data_p_with_type(selected_num_d), nnz, + with_weight_select_op, device.stream()); + } else { + cub::DeviceSelect::If(data_p_with_type(cub_temp_storage), + max_cub_bytes, triple_input_iter, + triple_output_iter, + data_p_with_type(selected_num_d), nnz, + select_op, device.stream()); + } + + } else { + cub::DeviceSelect::If(data_p_with_type(cub_temp_storage), + max_cub_bytes, double_input_iter, + double_output_iter, + data_p_with_type(selected_num_d), nnz, + select_op, device.stream()); + } + int selected_num; + CK_CUDA_THROW_(cudaMemcpyAsync( + &selected_num, data_p_with_type(selected_num_d), sizeof(int), + cudaMemcpyDeviceToHost, device.stream())); + CK_CUDA_THROW_(cudaEventRecord(memcpy_event_, device.stream())); + CK_CUDA_THROW_(cudaEventSynchronize(memcpy_event_)); + new_nnz = selected_num; + } + + if (fill_empty_row_) { + cub::DeviceSelect::Flagged( + data_p_with_type(cub_temp_storage), max_cub_bytes, + data_p_with_type(tmp_indices), + data_p_with_type(is_row_empty), + data_p_with_type(sp_indices_out) + new_nnz, + data_p_with_type(selected_num_d), batch_size, device.stream()); + int selected_num; + CK_CUDA_THROW_(cudaMemcpyAsync( + &selected_num, data_p_with_type(selected_num_d), sizeof(int), + cudaMemcpyDeviceToHost, device.stream())); + CK_CUDA_THROW_(cudaEventRecord(memcpy_event_, device.stream())); + CK_CUDA_THROW_(cudaEventSynchronize(memcpy_event_)); + new_nnz += selected_num; + } + + Tensor new_sp_values_out = sp_values_out.Slice(0, new_nnz); + Tensor new_sp_indices_out = sp_indices_out.Slice(0, new_nnz); + + ctx->set_output("sp_values_out", new_sp_values_out); + ctx->set_output("sp_indices_out", new_sp_indices_out); + + if (use_sparse_weights_) { + Tensor new_sp_weights_values_out = + sp_weights_values_out.Slice(0, new_nnz); + ctx->set_output("sp_weights_values_out", new_sp_weights_values_out); + } else { + Tensor* unused; + OP_REQUIRES_OK(ctx, ctx->allocate_output("sp_weights_values_out", + TensorShape{1}, &unused)); + } + } + + private: + bool fill_empty_row_; + bool prune_; + int64 default_id_; + bool use_sparse_weights_; + bool prune_sparse_weights_; + float default_weight_; + cudaEvent_t memcpy_event_; +}; + +REGISTER_KERNEL_BUILDER(Name("PruneInvalidAndFillEmptyRows") + .Device(DEVICE_GPU) + .HostMemory("sp_dense_shape"), + PruneInvalidAndFillEmptyRowsGPU); +} // namespace tensorflow + +#endif // GOOGLE_CUDA‰ \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/tests/fused_embedding_post_v2_grad_ops_test.cc b/tensorflow/core/kernels/fused_embedding/gpu/tests/fused_embedding_post_v2_grad_ops_test.cc new file mode 100644 index 00000000000..176b8ef7862 --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/tests/fused_embedding_post_v2_grad_ops_test.cc @@ -0,0 +1,464 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { + +enum class Device { GPU }; + +class FusedEmbeddingSparsePostLookUpV2GradOpTest : public OpsTestBase { + protected: + void MakeOpAndSetDevice(Device device, int num_partitions, DataType dtype, + const std::string& combiner, const float max_norm, + const bool fill_empty_row, const int default_id, + const bool use_sparse_weights) { + if (device == Device::GPU) { + SetDevice(DEVICE_GPU, + std::unique_ptr(DeviceFactory::NewDevice( + "GPU", {}, "/job:a/replica:0/task:0"))); + } + + TF_EXPECT_OK(NodeDefBuilder("fused_embedding_sparse_post_look_up_v2_grad", + "FusedEmbeddingSparsePostLookUpV2Grad") + .Attr("T", dtype) + .Attr("num_partitions", num_partitions) + .Attr("partition_axis", 0) + .Attr("combiner", combiner) + .Attr("max_norm", max_norm) + .Attr("fill_empty_row", fill_empty_row) + .Attr("default_id", default_id) + .Attr("use_sparse_weights", use_sparse_weights) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_partitions, dtype)) + .Input(FakeInput(DT_UINT64)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(FusedEmbeddingSparsePostLookUpV2GradOpTest, Partition2MeanMaxNorm100) { + const int nnz = 10; + const int batch_size = 4; + const int emb_vector_dim = 8; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "mean", 100.0, false, -1, false); + + // top_grad + AddInputFromArray( + TensorShape({batch_size, emb_vector_dim}), + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, + 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, + 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0}); + + // emb_shards 0 + AddInputFromArray( + TensorShape({6, emb_vector_dim}), + {8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 24.0, 25.0, 26.0, 27.0, + 28.0, 29.0, 30.0, 31.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, + 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 32.0, 33.0, 34.0, 35.0, + 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0}); + + // make same input to dump to emb_shard_ptrs + Tensor emb_shards_0(allocator(), DT_FLOAT, TensorShape({6, emb_vector_dim})); + test::FillValues( + &emb_shards_0, + {8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 24.0, 25.0, 26.0, 27.0, + 28.0, 29.0, 30.0, 31.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, + 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 32.0, 33.0, 34.0, 35.0, + 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0}); + + // emb_shards 1 + AddInputFromArray( + TensorShape({4, emb_vector_dim}), + {56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, + 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, + 96.0, 97.0, 98.0, 99.0, 100.0, 
101.0, 102.0, 103.0, + 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0}); + + // make same input to dump to emb_shard_ptrs + Tensor emb_shards_1(allocator(), DT_FLOAT, TensorShape({4, emb_vector_dim})); + test::FillValues( + &emb_shards_1, {56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, + 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, + 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, + 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0}); + + // emb_shard_ptrs + AddInputFromArray(TensorShape({2}), + {reinterpret_cast(emb_shards_0.data()), + reinterpret_cast(emb_shards_1.data())}); + + // partition_permutation + AddInputFromArray(TensorShape({10, 2}), {0, 0, 1, 0, 0, 1, 1, 1, 0, 2, + 1, 2, 0, 3, 1, 3, 0, 4, 0, 5}); + + // feature_nums + AddInputFromArray(TensorShape({batch_size}), {2, 3, 3, 2}); + + // indices_before_unique + AddInputFromArray( + TensorShape({nnz, 2}), + {0, 5, 1, 7, 0, 1, 2, 4, 2, 1, 2, 7, 1, 2, 3, 0, 3, 6, 1, 1}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, false, false}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + Tensor grad_shards_1(allocator(), DT_FLOAT, + TensorShape({6, emb_vector_dim})); + test::FillValues( + &grad_shards_1, + {0.00000000, 0.50000000, 1.00000000, 1.50000000, 2.00000000, + 2.50000000, 3.00000000, 3.50000000, 0.00000000, 0.50000000, + 1.00000000, 1.50000000, 2.00000000, 2.50000000, 3.00000000, + 3.50000000, 5.33333349, 5.66666651, 6.00000000, 6.33333349, + 6.66666651, 7.00000000, 7.33333349, 7.66666651, 2.65028572, + 2.98157120, 3.31285667, 3.64414287, 3.97542834, 4.30671406, + 4.63799953, 4.96928549, 11.92628479, 12.42321396, 12.92014217, + 13.41707039, 13.91399956, 14.41092777, 14.90785599, 15.40478516, + 2.16437674, 2.43492365, 2.70547056, 2.97601795, 3.24656487, + 3.51711202, 3.78765893, 4.05820608}); + test::ExpectTensorNear(grad_shards_1, *GetOutput(0), 1e-4); + } + + { + Tensor grad_shards_2(allocator(), DT_FLOAT, + TensorShape({4, emb_vector_dim})); + test::FillValues( + &grad_shards_2, + {1.58337951, 1.78130186, 1.97922409, 2.17714667, 2.37506914, 2.57299161, + 2.77091384, 2.96883631, 1.89459133, 2.01300311, 2.13141513, 2.24982715, + 2.36823893, 2.48665094, 2.60506320, 2.72347474, 1.89459133, 2.01300311, + 2.13141513, 2.24982715, 2.36823893, 2.48665094, 2.60506320, 2.72347474, + 3.43474555, 3.57786012, 3.72097445, 3.86408877, 4.00720310, 4.15031767, + 4.29343224, 4.43654633}); + test::ExpectTensorNear(grad_shards_2, *GetOutput(1), 1e-4); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2GradOpTest, Partition2SUMUnique) { + const int nnz = 6; + const int batch_size = 4; + const int emb_vector_dim = 1; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, true, -1, false); + + // top_grad + AddInputFromArray(TensorShape({batch_size, emb_vector_dim}), + {1.0, 2.0, 3.0, 4.0}); + + // emb_shards 0 + AddInputFromArray(TensorShape({3, emb_vector_dim}), {4.0, 5.0, 6.0}); + // make same input to dump to emb_shard_ptrs + Tensor emb_shards_0(allocator(), DT_FLOAT, TensorShape({3, emb_vector_dim})); + test::FillValues(&emb_shards_0, {4.0, 5.0, 6.0}); + + // emb_shards 1 + AddInputFromArray(TensorShape({2, emb_vector_dim}), {6.0, 7.0}); + Tensor emb_shards_1(allocator(), DT_FLOAT, TensorShape({2, emb_vector_dim})); + test::FillValues(&emb_shards_1, {6.0, 7.0}); + + // 
emb_shard_ptrs + AddInputFromArray(TensorShape({2}), + {reinterpret_cast(emb_shards_0.data()), + reinterpret_cast(emb_shards_1.data())}); + + // partition_permutation + AddInputFromArray(TensorShape({5, 2}), {1, 1, 0, 2, 0, 0, 1, 0, 0, 1}); + + // feature_nums + AddInputFromArray(TensorShape({batch_size}), {2, 2, 1, 2}); + + // values after fill empty: 1, 1, 2, 3, 4, 2, 0 + // after unique 1, 2, 3, 4, 0 + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), + {0, 1, 0, 3, 1, 2, 1, 3, 3, 2, 3, 6, 2, 0}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 0, 1, 2, 3, 1, 4}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, true, false}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + /* + permute 2 -> unique_counts: 1, unique_offsets: 4 -> + idx_of_input_to_unique: 3 -> batch: 1, grad: 2.0 + + permute 4 -> unique_counts: 1, unique_offsets: 6 -> + idx_of_input_to_unique: 6 -> batch: 2: grad: 0.0, because fill_empty + row + + permute 1 -> unique_counts: 2, unique_offsets: 2 + -> idx_of_input_to_unique: 2 -> batch: 1 -> grad : 2.0 + -> idx_of_input_to_unique: 5 -> batch: 3 -> grad : 4.0 + sum grad: 6.0 + */ + Tensor grad_shards_1(allocator(), DT_FLOAT, + TensorShape({3, emb_vector_dim})); + test::FillValues(&grad_shards_1, {2.0, 0.0, 6.0}); + test::ExpectTensorNear(grad_shards_1, *GetOutput(0), 1e-4); + } + + { + /* + permute 3 -> unique_counts: 1, unique_offsets: 5 -> + idx_of_input_to_unique: 4 -> batch: 3 -> grad: 4.0 + + permute 0 -> unique_counts: 2, unique_offsets: 0 + -> idx_of_input_to_unique 0 -> batch: 0 -> grad: 1.0 + -> idx_of_input_to_unique 1 -> batch: 0 -> grad: 1.0 + sum grad: 2.0 + */ + Tensor grad_shards_2(allocator(), DT_FLOAT, + TensorShape({2, emb_vector_dim})); + test::FillValues(&grad_shards_2, {4.0, 2.0}); + test::ExpectTensorNear(grad_shards_2, *GetOutput(1), 1e-4); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2GradOpTest, + Partition2SUMUniqueDefault4) { + const int nnz = 6; + const int batch_size = 4; + const int emb_vector_dim = 1; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, true, 4, false); + + // top_grad + AddInputFromArray(TensorShape({batch_size, emb_vector_dim}), + {1.0, 2.0, 3.0, 4.0}); + + // emb_shards 0 + AddInputFromArray(TensorShape({3, emb_vector_dim}), {4.0, 5.0, 6.0}); + // make same input to dump to emb_shard_ptrs + Tensor emb_shards_0(allocator(), DT_FLOAT, TensorShape({3, emb_vector_dim})); + test::FillValues(&emb_shards_0, {4.0, 5.0, 6.0}); + + // emb_shards 1 + AddInputFromArray(TensorShape({2, emb_vector_dim}), {6.0, 7.0}); + Tensor emb_shards_1(allocator(), DT_FLOAT, TensorShape({2, emb_vector_dim})); + test::FillValues(&emb_shards_1, {6.0, 7.0}); + + // emb_shard_ptrs + AddInputFromArray(TensorShape({2}), + {reinterpret_cast(emb_shards_0.data()), + reinterpret_cast(emb_shards_1.data())}); + + // partition_permutation + AddInputFromArray(TensorShape({5, 2}), {1, 1, 0, 2, 0, 0, 1, 0, 0, 1}); + + // feature_nums + AddInputFromArray(TensorShape({batch_size}), {2, 2, 1, 2}); + + // values after fill empty: 1, 1, 2, 3, 4, 2, 0 + // after unique 1, 2, 3, 4, 0 + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), + {0, 1, 0, 3, 1, 2, 1, 3, 3, 2, 3, 6, 2, 0}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 0, 1, 2, 3, 1, 4}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + 
{false, false, true, false}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + /* + permute 2 -> unique_counts: 1, unique_offsets: 4 -> + idx_of_input_to_unique: 3 -> batch: 1, grad: 2.0 + + permute 4 -> unique_counts: 1, unique_offsets: 6 -> + idx_of_input_to_unique: 6 -> batch: 2: grad: 3.0 + + permute 1 -> unique_counts: 2, unique_offsets: 2 + -> idx_of_input_to_unique: 2 -> batch: 1 -> grad : 2.0 + -> idx_of_input_to_unique: 5 -> batch: 3 -> grad : 4.0 + sum grad: 6.0 + */ + Tensor grad_shards_1(allocator(), DT_FLOAT, + TensorShape({3, emb_vector_dim})); + test::FillValues(&grad_shards_1, {2.0, 3.0, 6.0}); + test::ExpectTensorNear(grad_shards_1, *GetOutput(0), 1e-4); + } + + { + /* + permute 3 -> unique_counts: 1, unique_offsets: 5 -> + idx_of_input_to_unique: 4 -> batch: 3 -> grad: 4.0 + + permute 0 -> unique_counts: 2, unique_offsets: 0 + -> idx_of_input_to_unique 0 -> batch: 0 -> grad: 1.0 + -> idx_of_input_to_unique 1 -> batch: 0 -> grad: 1.0 + sum grad: 2.0 + */ + Tensor grad_shards_2(allocator(), DT_FLOAT, + TensorShape({2, emb_vector_dim})); + test::FillValues(&grad_shards_2, {4.0, 2.0}); + test::ExpectTensorNear(grad_shards_2, *GetOutput(1), 1e-4); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2GradOpTest, SinglePartitionSUMUnique) { + const int nnz = 6; + const int batch_size = 4; + const int emb_vector_dim = 1; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 1, DT_FLOAT, "sum", -1.0, true, -1, false); + + // top_grad + AddInputFromArray(TensorShape({batch_size, emb_vector_dim}), + {1.0, 2.0, 3.0, 4.0}); + + // emb_shards 0 + AddInputFromArray(TensorShape({5, emb_vector_dim}), + {7.0, 6.0, 4.0, 6.0, 5.0}); + + // emb_shard_ptrs, whatever, will not be used + AddInputFromArray(TensorShape({1}), {0}); + + // partition_permutation, whatever, will not be used + AddInputFromArray(TensorShape({1, 1}), {1}); + + // feature_nums + AddInputFromArray(TensorShape({batch_size}), {2, 2, 1, 2}); + + // values after fill empty: 1, 1, 2, 3, 4, 2, 0 + // after unique 1, 2, 3, 4, 0 + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), + {0, 1, 0, 3, 1, 2, 1, 3, 3, 2, 3, 6, 2, 0}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 0, 1, 2, 3, 1, 4}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, true, false}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + Tensor grad_shards_0(allocator(), DT_FLOAT, TensorShape({5, emb_vector_dim})); + test::FillValues(&grad_shards_0, {2.0, 6.0, 2.0, 4.0, 0.0}); + test::ExpectTensorNear(grad_shards_0, *GetOutput(0), 1e-4); +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2GradOpTest, + SinglePartitionSUMUniqueSparseWeight) { + const int nnz = 6; + const int batch_size = 4; + const int emb_vector_dim = 1; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 1, DT_FLOAT, "sum", -1.0, true, -1, true); + + // top_grad + AddInputFromArray(TensorShape({batch_size, emb_vector_dim}), + {1.0, 1.0, 1.0, 1.0}); + + // emb_shards 0 + AddInputFromArray(TensorShape({5, emb_vector_dim}), + {7.0, 6.0, 4.0, 6.0, 5.0}); + + // emb_shard_ptrs, whatever, will not be used + AddInputFromArray(TensorShape({1}), {0}); + + // partition_permutation, whatever, will not be used + AddInputFromArray(TensorShape({1, 1}), {1}); + + // feature_nums + AddInputFromArray(TensorShape({batch_size}), {2, 2, 1, 2}); + + // 
values after fill empty: 1, 1, 2, 3, 4, 2, 0 + // after unique 1, 2, 3, 4, 0 + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), + {0, 1, 0, 3, 1, 2, 1, 3, 3, 2, 3, 6, 2, 0}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 0, 1, 2, 3, 1, 4}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, true, false}); + + // sp_weights_values + AddInputFromArray(TensorShape({nnz + 1}), + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + Tensor grad_shards_0(allocator(), DT_FLOAT, TensorShape({5, emb_vector_dim})); + test::FillValues(&grad_shards_0, {3.0, 9.0, 4.0, 5.0, 0.0}); + test::ExpectTensorNear(grad_shards_0, *GetOutput(0), 1e-4); +} + +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/tests/fused_embedding_post_v2_ops_test.cc b/tensorflow/core/kernels/fused_embedding/gpu/tests/fused_embedding_post_v2_ops_test.cc new file mode 100644 index 00000000000..aa4c0b361bc --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/tests/fused_embedding_post_v2_ops_test.cc @@ -0,0 +1,410 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { + +enum class Device { GPU }; +class FusedEmbeddingSparsePostLookUpV2OpTest : public OpsTestBase { + protected: + void MakeOpAndSetDevice(Device device, int num_partitions, DataType dtype, + const std::string& combiner, const float max_norm, + const bool fill_empty_row, const int default_id, + const bool use_sparse_weights) { + if (device == Device::GPU) { + SetDevice(DEVICE_GPU, + std::unique_ptr(DeviceFactory::NewDevice( + "GPU", {}, "/job:a/replica:0/task:0"))); + } + + TF_EXPECT_OK(NodeDefBuilder("fused_embedding_sparse_post_look_up_v2", + "FusedEmbeddingSparsePostLookUpV2") + .Attr("T", dtype) + .Attr("num_partitions", num_partitions) + .Attr("partition_axis", 0) + .Attr("combiner", combiner) + .Attr("max_norm", max_norm) + .Attr("fill_empty_row", fill_empty_row) + .Attr("default_id", default_id) + .Attr("use_sparse_weights", use_sparse_weights) + .Input(FakeInput(num_partitions, dtype)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(FusedEmbeddingSparsePostLookUpV2OpTest, + Partition3CombinerSqrtnMaxNorm200) { + const int nnz = 10; + const int batch_size = 4; + const int emb_vector_dim = 8; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 3, DT_FLOAT, "sqrtn", 200.0, false, -1, + false); + + // emb_shards 0 + AddInputFromArray( + TensorShape({6, emb_vector_dim}), + { + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 24.0, 25.0, + 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 24.0, 25.0, 26.0, 27.0, + 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, + 38.0, 39.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, + 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, + }); + // emb_shards 1 + AddInputFromArray(TensorShape({1, emb_vector_dim}), + {56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0}); + // emb_shards 2 + AddInputFromArray( + TensorShape({3, emb_vector_dim}), + {96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, + 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, + 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0}); + + // partition_permutation + AddInputFromArray(TensorShape({nnz, 2}), {0, 0, 0, 1, 0, 2, 0, 3, 0, 4, + 0, 5, 1, 0, 2, 0, 2, 1, 2, 2}); + + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {batch_size, entries}); + + // indices_before_unique + AddInputFromArray( + TensorShape({nnz, 2}), + {0, 5, 0, 1, 2, 1, 1, 2, 3, 6, 1, 1, 1, 7, 2, 4, 2, 7, 3, 0}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, false, false}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + Tensor expected_emb_vectors(allocator(), 
DT_FLOAT, + TensorShape({batch_size, emb_vector_dim})); + test::FillValues( + &expected_emb_vectors, + {22.62741661, 24.04163170, 25.45584488, 26.87005806, 28.28427124, + 29.69848442, 31.11269951, 32.52691269, 73.90083313, 75.63288879, + 77.36493683, 79.09698486, 80.82904053, 82.56108856, 84.29314423, + 86.02519226, 92.61308289, 94.01081848, 95.40855408, 96.80628204, + 98.20401764, 99.60175323, 100.99948120, 102.39721680, 71.20205688, + 72.31395721, 73.42584991, 74.53774261, 75.64963531, 76.76153564, + 77.87342834, 78.98532867}); + test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); + } + { + Tensor feature_nums_expected(allocator(), DT_INT32, + TensorShape({batch_size})); + test::FillValues(&feature_nums_expected, {2, 3, 3, 2}); + test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2OpTest, Partition2SumFillEmpty) { + const int nnz = 3; + const int batch_size = 3; + const int emb_vector_dim = 4; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, true, -1, false); + + // emb_shards 0 + AddInputFromArray(TensorShape({2, emb_vector_dim}), + {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}); + // emb_shards 1 + AddInputFromArray(TensorShape({2, emb_vector_dim}), + {10.0, 10.0, 10.0, 10.0, 13.0, 13.0, 13.0, 13.0}); + + // partition_permutation + AddInputFromArray(TensorShape({nnz + 1, 2}), {1, 1, 0, 0, 0, 1, 1, 0}); + + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {batch_size, entries}); + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), {2, 0, 0, 0, 0, 5, 1, 4}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), {false, false, true}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 1, 2, 3}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + Tensor expected_emb_vectors(allocator(), DT_FLOAT, + TensorShape({batch_size, emb_vector_dim})); + test::FillValues( + &expected_emb_vectors, + {3.0, 3.0, 3.0, 3.0, 10.0, 10.0, 10.0, 10.0, 0.0, 0.0, 0.0, 0.0}); + test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); + } + { + Tensor feature_nums_expected(allocator(), DT_INT32, + TensorShape({batch_size})); + test::FillValues(&feature_nums_expected, {2, 1, 0}); + test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2OpTest, Partition2SumFillEmptyDefault2) { + const int nnz = 3; + const int batch_size = 3; + const int emb_vector_dim = 4; + const int entries = 8; + + MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "sum", -1.0, true, 2, false); + + // emb_shards 0 + AddInputFromArray(TensorShape({2, emb_vector_dim}), + {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}); + // emb_shards 1 + AddInputFromArray(TensorShape({2, emb_vector_dim}), + {10.0, 10.0, 10.0, 10.0, 13.0, 13.0, 13.0, 13.0}); + + // partition_permutation + AddInputFromArray(TensorShape({nnz + 1, 2}), {1, 1, 0, 0, 0, 1, 1, 0}); + + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {batch_size, entries}); + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), {2, 0, 0, 0, 0, 5, 1, 4}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), {false, false, true}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 1, 2, 3}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + 
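+    // With default_id = 2 (>= 0), the empty third row is not zeroed out:
+    // its filled entry (unique id 0) maps through partition_permutation to
+    // shard 1, entry 1, i.e. the 13.0 vector, and the row's feature count
+    // becomes 1 instead of 0.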
Tensor expected_emb_vectors(allocator(), DT_FLOAT, + TensorShape({batch_size, emb_vector_dim})); + test::FillValues( + &expected_emb_vectors, + {3.0, 3.0, 3.0, 3.0, 10.0, 10.0, 10.0, 10.0, 13.0, 13.0, 13.0, 13.0}); + test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); + } + { + Tensor feature_nums_expected(allocator(), DT_INT32, + TensorShape({batch_size})); + test::FillValues(&feature_nums_expected, {2, 1, 1}); + test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2OpTest, + Partition2MeanFillEmptyDefault2Unique) { + const int nnz = 7; + const int batch_size = 5; + const int emb_vector_dim = 2; + const int entries = 4; + + MakeOpAndSetDevice(Device::GPU, 2, DT_FLOAT, "mean", -1.0, true, 2, false); + + // emb_shards 0 + AddInputFromArray(TensorShape({2, emb_vector_dim}), + {1.0, 1.0, 2.0, 2.0}); + // emb_shards 1 + AddInputFromArray(TensorShape({2, emb_vector_dim}), + {3.0, 3.0, 4.0, 4.0}); + + // partition_permutation + AddInputFromArray(TensorShape({4, 2}), {1, 0, 0, 1, 1, 1, 0, 0}); + + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {batch_size, entries}); + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), + {0, 1, 0, 2, 1, 0, 1, 1, 3, 0, 3, 1, 4, 0, 2, 0}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, true, false, false}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 1, 1, 0, 2, 3, 3, 3}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + // {3 + 2, 3 + 2, 3 + 2, 3 + 2, 1, 1, 4 + 1, 4 + 1, 1, 1} + // {2, 2, 1, 2, 1} + + { + Tensor expected_emb_vectors(allocator(), DT_FLOAT, + TensorShape({batch_size, emb_vector_dim})); + test::FillValues(&expected_emb_vectors, + {2.5, 2.5, 2.5, 2.5, 1.0, 1.0, 2.5, 2.5, 1.0, 1.0}); + test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); + } + { + Tensor feature_nums_expected(allocator(), DT_INT32, + TensorShape({batch_size})); + test::FillValues(&feature_nums_expected, {2, 2, 1, 2, 1}); + test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2OpTest, + SinglePartitionMeanFillEmptyDefault2Unique) { + const int nnz = 7; + const int batch_size = 5; + const int emb_vector_dim = 2; + const int entries = 4; + + MakeOpAndSetDevice(Device::GPU, 1, DT_FLOAT, "mean", -1.0, true, 2, false); + + // emb_shards 0 + AddInputFromArray(TensorShape({4, emb_vector_dim}), + {3.0, 3.0, 2.0, 2.0, 4.0, 4.0, 1.0, 1.0}); + + // partition_permutation, whatever, will not use this + AddInputFromArray(TensorShape({1, 1}), {1}); + + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {batch_size, entries}); + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), + {0, 1, 0, 2, 1, 0, 1, 1, 3, 0, 3, 1, 4, 0, 2, 0}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, true, false, false}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 1, 1, 0, 2, 3, 3, 3}); + + // sp_weights_values + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + // {3 + 2, 3 + 2, 3 + 2, 3 + 2, 1, 1, 4 + 1, 4 + 1, 1, 1} + // {2, 2, 1, 2, 1} + + { + Tensor expected_emb_vectors(allocator(), DT_FLOAT, + TensorShape({batch_size, emb_vector_dim})); + test::FillValues(&expected_emb_vectors, + {2.5, 2.5, 2.5, 2.5, 1.0, 1.0, 2.5, 2.5, 1.0, 1.0}); + 
test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); + } + { + Tensor feature_nums_expected(allocator(), DT_INT32, + TensorShape({batch_size})); + test::FillValues(&feature_nums_expected, {2, 2, 1, 2, 1}); + test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); + } +} + +TEST_F(FusedEmbeddingSparsePostLookUpV2OpTest, + SinglePartitionMeanFillEmptyDefault2UniqueSparseWeights) { + const int nnz = 7; + const int batch_size = 5; + const int emb_vector_dim = 2; + const int entries = 4; + + MakeOpAndSetDevice(Device::GPU, 1, DT_FLOAT, "sum", -1.0, true, 2, true); + + // emb_shards 0 + AddInputFromArray(TensorShape({4, emb_vector_dim}), + {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + + // partition_permutation, whatever, will not use this + AddInputFromArray(TensorShape({1, 1}), {1}); + + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {batch_size, entries}); + + // indices_before_unique + AddInputFromArray(TensorShape({nnz + 1, 2}), + {0, 1, 0, 2, 1, 0, 1, 1, 3, 0, 3, 1, 4, 0, 2, 0}); + + // is_row_empty + AddInputFromArray(TensorShape({batch_size}), + {false, false, true, false, false}); + + // unique_idxs + AddInputFromArray(TensorShape({nnz + 1}), {0, 1, 1, 0, 2, 3, 3, 3}); + + // sp_weights_values + AddInputFromArray(TensorShape({nnz + 1}), + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + { + Tensor expected_emb_vectors(allocator(), DT_FLOAT, + TensorShape({batch_size, emb_vector_dim})); + test::FillValues(&expected_emb_vectors, + {3.0, 3.0, 7.0, 7.0, 8.0, 8.0, 11.0, 11.0, 7.0, 7.0}); + test::ExpectTensorNear(expected_emb_vectors, *GetOutput(0), 1e-4); + } + { + Tensor feature_nums_expected(allocator(), DT_INT32, + TensorShape({batch_size})); + test::FillValues(&feature_nums_expected, {2, 2, 1, 2, 1}); + test::ExpectTensorEqual(feature_nums_expected, *GetOutput(1)); + } +} + +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/tests/partition_with_permutation_ops_test.cc b/tensorflow/core/kernels/fused_embedding/gpu/tests/partition_with_permutation_ops_test.cc new file mode 100644 index 00000000000..ed7952e2faa --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/tests/partition_with_permutation_ops_test.cc @@ -0,0 +1,151 @@ +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { + +enum class Device { GPU }; + +class PartitionWithPermutationOpTest : public OpsTestBase { + protected: + void MakeOpAndSetDevice(Device device, const int num_partitions, + const std::string& partition_strategy) { + if (device == Device::GPU) { + SetDevice(DEVICE_GPU, + std::unique_ptr(DeviceFactory::NewDevice( + "GPU", {}, "/job:a/replica:0/task:0"))); + } + + TF_EXPECT_OK( + NodeDefBuilder("partition_with_permutation", "PartitionWithPermutation") + .Attr("num_partitions", num_partitions) + .Attr("partition_axis", 0) + .Attr("partition_strategy", partition_strategy) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(num_partitions, 
DT_INT64)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(PartitionWithPermutationOpTest, Partition3Div) { + MakeOpAndSetDevice(Device::GPU, 3, std::string("div")); + // sp_values + AddInputFromArray(TensorShape({12}), + {1, 5, 3, 6, 12, 14, 15, 0, 5, 5, 11, 7}); + // partition_shapes 0 + AddInputFromArray(TensorShape({2}), {6, 16}); + // partition_shapes 1 + AddInputFromArray(TensorShape({2}), {3, 16}); + // partition_shapes 2 + AddInputFromArray(TensorShape({2}), {7, 16}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // partitioned_values 0 + { + Tensor expected(allocator(), DT_INT64, TensorShape({6})); + test::FillValues(&expected, {1, 5, 3, 0, 5, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // partitioned_values 1 + { + Tensor expected(allocator(), DT_INT64, TensorShape({2})); + test::FillValues(&expected, {6 - 6, 7 - 6}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + // partitioned_values 2 + { + Tensor expected(allocator(), DT_INT64, TensorShape({4})); + test::FillValues(&expected, {12 - 9, 14 - 9, 15 - 9, 11 - 9}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } + + // partition_permutation + { + Tensor expected(allocator(), DT_INT32, TensorShape({12, 2})); + test::FillValues(&expected, {0, 0, 0, 1, 0, 2, 1, 0, 2, 0, 2, 1, + 2, 2, 0, 3, 0, 4, 0, 5, 2, 3, 1, 1}); + test::ExpectTensorEqual(expected, *GetOutput(3)); + } +} + +TEST_F(PartitionWithPermutationOpTest, Partition2Mod) { + MakeOpAndSetDevice(Device::GPU, 2, std::string("mod")); + // sp_values + AddInputFromArray(TensorShape({12}), + {1, 5, 3, 6, 12, 14, 15, 0, 5, 5, 11, 7}); + // partition_shapes 0 + AddInputFromArray(TensorShape({2}), {6, 16}); + // partition_shapes 1 + AddInputFromArray(TensorShape({2}), {6, 16}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // partitioned_values 0 + { + Tensor expected(allocator(), DT_INT64, TensorShape({4})); + test::FillValues(&expected, {6 / 2, 12 / 2, 14 / 2, 0 / 2}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // partitioned_values 1 + { + Tensor expected(allocator(), DT_INT64, TensorShape({8})); + test::FillValues( + &expected, {1 / 2, 5 / 2, 3 / 2, 15 / 2, 5 / 2, 5 / 2, 11 / 2, 7 / 2}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + + // partition_permutation + { + Tensor expected(allocator(), DT_INT32, TensorShape({12, 2})); + test::FillValues(&expected, {1, 0, 1, 1, 1, 2, 0, 0, 0, 1, 0, 2, + 1, 3, 0, 3, 1, 4, 1, 5, 1, 6, 1, 7}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } +} + +TEST_F(PartitionWithPermutationOpTest, Partition2ModEV) { + MakeOpAndSetDevice(Device::GPU, 2, std::string("mod_ev")); + // sp_values + AddInputFromArray(TensorShape({6}), {5, 28, 1003, 2004, 1834, 17833}); + // partition_shapes 0 + AddInputFromArray(TensorShape({2}), {10000, 8}); + // partition_shapes 1 + AddInputFromArray(TensorShape({2}), {10000, 8}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // partitioned_values 0 + { + Tensor expected(allocator(), DT_INT64, TensorShape({3})); + test::FillValues(&expected, {28, 2004, 1834}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // partitioned_values 1 + { + Tensor expected(allocator(), DT_INT64, TensorShape({3})); + test::FillValues(&expected, {5, 1003, 17833}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + + // partition_permutation + { + Tensor expected(allocator(), DT_INT32, TensorShape({6, 2})); + test::FillValues(&expected, {1, 0, 0, 0, 1, 1, 0, 1, 0, 2, 1, 2}); + 
test::ExpectTensorEqual(expected, *GetOutput(2)); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fused_embedding/gpu/tests/prune_invalid_and_fill_empty_rows_ops_test.cc b/tensorflow/core/kernels/fused_embedding/gpu/tests/prune_invalid_and_fill_empty_rows_ops_test.cc new file mode 100644 index 00000000000..90c9ef19cd4 --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/tests/prune_invalid_and_fill_empty_rows_ops_test.cc @@ -0,0 +1,250 @@ +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { + +enum class Device { GPU }; + +class PruneInvalidAndFillEmptyRowsOpTest : public OpsTestBase { + protected: + void MakeOpAndSetDevice(Device device, const bool fill_empty_row, + const bool prune, const int default_id, + const bool use_sparse_weights, + const bool prune_sparse_weights, + const float default_weight) { + if (device == Device::GPU) { + SetDevice(DEVICE_GPU, + std::unique_ptr(DeviceFactory::NewDevice( + "GPU", {}, "/job:a/replica:0/task:0"))); + } + + TF_EXPECT_OK(NodeDefBuilder("prune_invalid_and_fill_empty_rows", + "PruneInvalidAndFillEmptyRows") + .Attr("fill_empty_row", fill_empty_row) + .Attr("prune", prune) + .Attr("default_id", default_id) + .Attr("use_sparse_weights", use_sparse_weights) + .Attr("prune_sparse_weights", prune_sparse_weights) + .Attr("default_weight", default_weight) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_INT64)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(PruneInvalidAndFillEmptyRowsOpTest, NothingHappend) { + MakeOpAndSetDevice(Device::GPU, false, false, -1, false, false, 1.0); + // sp_values + AddInputFromArray(TensorShape({12}), + {1, 5, 3, 6, 12, 14, 15, 0, 5, 5, 11, 7}); + // sp_indices + AddInputFromArray(TensorShape({12, 2}), + {2, 3, 4, 6, 1, 6, 12, 12, 12, 12, 11, 5, + 15, 0, 11, 6, 7, 9, 11, 8, 12, 13, 13, 0}); + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {16, 16}); + + // sp_weights_values, even shape does not match sp_values, it doesn't matter + AddInputFromArray(TensorShape({1}), {1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // sp_values_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({12})); + test::FillValues(&expected, + {1, 5, 3, 6, 12, 14, 15, 0, 5, 5, 11, 7}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // sp_indices_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({12, 2})); + test::FillValues(&expected, + {2, 3, 4, 6, 1, 6, 12, 12, 12, 12, 11, 5, + 15, 0, 11, 6, 7, 9, 11, 8, 12, 13, 13, 0}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + + // sp_weights_values_out + { + Tensor expected(allocator(), DT_FLOAT, TensorShape({1})); + test::FillValues(&expected, {1.0}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } +} + +TEST_F(PruneInvalidAndFillEmptyRowsOpTest, PruneWithAllValid) { + MakeOpAndSetDevice(Device::GPU, false, true, -1, true, false, 1.0); + // sp_values + 
AddInputFromArray(TensorShape({12}), + {1, 5, 3, 6, 12, 14, 15, 0, 5, 5, 11, 7}); + // sp_indices + AddInputFromArray(TensorShape({12, 2}), + {2, 3, 4, 6, 1, 6, 12, 12, 12, 12, 11, 5, + 15, 0, 11, 6, 7, 9, 11, 8, 12, 13, 13, 0}); + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {16, 16}); + // sp_weights_values + AddInputFromArray(TensorShape({12}), {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // sp_values_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({12})); + test::FillValues(&expected, + {1, 5, 3, 6, 12, 14, 15, 0, 5, 5, 11, 7}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // sp_indices_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({12, 2})); + test::FillValues(&expected, + {2, 3, 4, 6, 1, 6, 12, 12, 12, 12, 11, 5, + 15, 0, 11, 6, 7, 9, 11, 8, 12, 13, 13, 0}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + // sp_weights_values_out + { + Tensor expected(allocator(), DT_FLOAT, TensorShape({12})); + test::FillValues(&expected, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } +} + +TEST_F(PruneInvalidAndFillEmptyRowsOpTest, FillEmptyRows) { + MakeOpAndSetDevice(Device::GPU, true, false, -1, false, false, 3.0); + // sp_values + AddInputFromArray(TensorShape({4}), {1, 5, 3, 6}); + // sp_indices + AddInputFromArray(TensorShape({4, 2}), {0, 1, 0, 2, 2, 3, 3, 1}); + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {4, 8}); + // sp_weights_values + AddInputFromArray(TensorShape({4}), {1.0, 1.0, 1.0, 1.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // sp_values_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({5})); + test::FillValues(&expected, {1, 5, 3, 6, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // sp_indices_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({5, 2})); + test::FillValues(&expected, {0, 1, 0, 2, 2, 3, 3, 1, 1, 0}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + // sp_weights_values_out, don't care + // { + // Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); + // test::FillValues(&expected, {1.0, 1.0, 1.0, 1.0, 3.0}); + // test::ExpectTensorEqual(expected, *GetOutput(2)); + // } + // is_row_empty + { + Tensor expected(allocator(), DT_BOOL, TensorShape({4})); + test::FillValues(&expected, {false, true, false, false}); + test::ExpectTensorEqual(expected, *GetOutput(3)); + } +} + +TEST_F(PruneInvalidAndFillEmptyRowsOpTest, PruneAndFillEmptyRowsWithDefaultId) { + MakeOpAndSetDevice(Device::GPU, true, true, 8, true, false, 10.0); + // sp_values + AddInputFromArray(TensorShape({4}), {1, 5, 3, -5}); + // sp_indices + AddInputFromArray(TensorShape({4, 2}), {0, 1, 0, 2, 2, 3, 3, 1}); + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {4, 8}); + // sp_weights_values + AddInputFromArray(TensorShape({4}), {1.0, 2.0, 3.0, 4.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // sp_values_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({5})); + test::FillValues(&expected, {1, 5, 3, 8, 8}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // sp_indices_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({5, 2})); + test::FillValues(&expected, {0, 1, 0, 2, 2, 3, 1, 0, 3, 0}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + // sp_weights_values_out + { + Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); + 
test::FillValues(&expected, {1.0, 2.0, 3.0, 10.0, 10.0}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } + // is_row_empty + { + Tensor expected(allocator(), DT_BOOL, TensorShape({4})); + test::FillValues(&expected, {false, true, false, true}); + test::ExpectTensorEqual(expected, *GetOutput(3)); + } +} + +TEST_F(PruneInvalidAndFillEmptyRowsOpTest, + PruneAndFillEmptyRowsWithDefaultIdWithSparseWeights) { + MakeOpAndSetDevice(Device::GPU, true, true, 8, true, true, 10.0); + // sp_values + AddInputFromArray(TensorShape({4}), {1, 5, 3, -5}); + // sp_indices + AddInputFromArray(TensorShape({4, 2}), {0, 1, 0, 2, 2, 3, 3, 1}); + // sp_dense_shape + AddInputFromArray(TensorShape({2}), {4, 8}); + // sp_weights + AddInputFromArray(TensorShape({4}), {-1.0, 2.0, 3.0, 4.0}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + // sp_values_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({4})); + test::FillValues(&expected, {5, 3, 8, 8}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + } + // sp_indices_out + { + Tensor expected(allocator(), DT_INT64, TensorShape({4, 2})); + test::FillValues(&expected, {0, 2, 2, 3, 1, 0, 3, 0}); + test::ExpectTensorEqual(expected, *GetOutput(1)); + } + // sp_weights_values_out + { + Tensor expected(allocator(), DT_FLOAT, TensorShape({4})); + test::FillValues(&expected, {2.0, 3.0, 10.0, 10.0}); + test::ExpectTensorEqual(expected, *GetOutput(2)); + } + // is_row_empty + { + Tensor expected(allocator(), DT_BOOL, TensorShape({4})); + test::FillValues(&expected, {false, true, false, true}); + test::ExpectTensorEqual(expected, *GetOutput(3)); + } +} + +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/tests/unique_with_count_v3_ops_test.cc b/tensorflow/core/kernels/fused_embedding/gpu/tests/unique_with_count_v3_ops_test.cc new file mode 100644 index 00000000000..2efae49656b --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/tests/unique_with_count_v3_ops_test.cc @@ -0,0 +1,105 @@ +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { + +enum class Device { GPU }; + +class UniqueWithCountsV3OpTest : public OpsTestBase { + protected: + void MakeOpAndSetDevice(Device device) { + if (device == Device::GPU) { + SetDevice(DEVICE_GPU, + std::unique_ptr(DeviceFactory::NewDevice( + "GPU", {}, "/job:a/replica:0/task:0"))); + } + + TF_EXPECT_OK(NodeDefBuilder("unique_with_counts_v3", "UniqueWithCountsV3") + .Attr("CounterType", DT_INT32) + .Input(FakeInput(DT_INT64)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(UniqueWithCountsV3OpTest, UniqueWithCount) { + MakeOpAndSetDevice(Device::GPU); + const int input_size = 20; + const int uniq_size = 14; + // input + AddInputFromArray( + TensorShape({input_size}), + {1, 3, 2, 1, 4, 5, 6, 5, 7, 8, 1, 9, 10, 2, 13, 15, 17, 13, 12, 8}); + + TF_ASSERT_OK(RunOpKernel()); + TF_EXPECT_OK(device_->Sync()); + + std::vector input = {1, 3, 2, 1, 4, 5, 6, 5, 7, 8, + 1, 9, 10, 2, 13, 15, 17, 13, 12, 8}; + 
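+  // The test does not assume any particular ordering of the unique keys:
+  // the checks below match keys by value, and sort the collected keys
+  // before comparing them against the expected set.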
+ std::vector unique_keys; + std::vector expected_unique_keys = {1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 12, 13, 15, 17}; + + std::vector expected_unique_counts = {3, 2, 1, 1, 2, 1, 1, + 2, 1, 1, 1, 2, 1, 1}; + + auto unique_keys_tensor = GetOutput(0); + auto unique_idxs_tensor = GetOutput(1); + auto unique_counts_tenosr = GetOutput(2); + + // unique_idxs + { + test::internal::Expector::Equal(unique_idxs_tensor->dim_size(0), + input_size); + for (int i = 0; i < input_size; i++) { + test::internal::Expector::Equal( + unique_keys_tensor->flat() + .data()[unique_idxs_tensor->flat().data()[i]], + input[i]); + } + } + // unique_counts + { + for (int i = 0; i < uniq_size; i++) { + const int count = expected_unique_counts[i]; + const int64 expected_key = expected_unique_keys[i]; + for (int j = 0; j < uniq_size; j++) { + if (unique_keys_tensor->flat().data()[j] == expected_key) { + test::internal::Expector::Equal( + unique_counts_tenosr->flat().data()[j], count); + } + } + } + } + + // test unique_keys + { + test::internal::Expector::Equal(unique_keys_tensor->dim_size(0), + uniq_size); + for (int i = 0; i < uniq_size; i++) { + unique_keys.push_back(unique_keys_tensor->flat().data()[i]); + } + + std::sort(unique_keys.begin(), unique_keys.end()); + + for (int i = 0; i < uniq_size; i++) { + test::internal::Expector::Equal(unique_keys[i], + expected_unique_keys[i]); + } + } +} + +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/fused_embedding/gpu/unique_with_count_v3_ops.cu.cc b/tensorflow/core/kernels/fused_embedding/gpu/unique_with_count_v3_ops.cu.cc new file mode 100644 index 00000000000..4086275da1b --- /dev/null +++ b/tensorflow/core/kernels/fused_embedding/gpu/unique_with_count_v3_ops.cu.cc @@ -0,0 +1,409 @@ +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "cub/device/device_radix_sort.cuh" +#include "cub/device/device_select.cuh" +#include "cub/iterator/constant_input_iterator.cuh" +#include "cub/thread/thread_operators.cuh" +#include "tensorflow/core/kernels/fused_embedding/gpu/common.cu.h" +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/hash_functions.cu.h" +#include "tensorflow/core/kernels/fused_embedding/gpu/functions/kernels.cu.h" +#include "tensorflow/core/profiler/nvtx_utils.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +// Overload CUDA atomic for other 64bit unsinged/signed integer type +__forceinline__ __device__ long atomicAdd(long* address, long val) { + return (long)atomicAdd((unsigned long long*)address, (unsigned long long)val); +} + +__forceinline__ __device__ long long atomicAdd(long long* address, + long long val) { + return (long long)atomicAdd((unsigned long long*)address, + (unsigned long long)val); +} + +__forceinline__ __device__ unsigned long atomicAdd(unsigned long* address, + unsigned long val) { + return (unsigned long)atomicAdd((unsigned long long*)address, + (unsigned long long)val); +} + +__forceinline__ __device__ long atomicCAS(long* address, long compare, + long val) { + return (long)atomicCAS((unsigned long long*)address, + (unsigned long long)compare, (unsigned long long)val); +} + +__forceinline__ __device__ long long atomicCAS(long long* address, + long long compare, + long long val) { + return (long long)atomicCAS((unsigned long long*)address, + (unsigned long long)compare, + (unsigned long long)val); +} + +__forceinline__ __device__ unsigned long atomicCAS(unsigned long* address, + 
unsigned long compare, + unsigned long val) { + return (unsigned long)atomicCAS((unsigned long long*)address, + (unsigned long long)compare, + (unsigned long long)val); +} + +namespace tensorflow { +using GPUDevice = Eigen::GpuDevice; + +namespace gpu_unique_with_counts { + +const static int block_size = 64; + +template +__global__ void InitKernel(KeyType* keys, CounterType* vals, + CounterType* counts, CounterType* counter, + const size_t capacity, const KeyType empty_key, + const CounterType empty_val, + const CounterType empty_counts, + const CounterType init_counter_val) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < capacity) { + // Simply store every element a unused pair + keys[idx] = empty_key; + vals[idx] = empty_val; + counts[idx] = empty_counts; + } + if (idx == 0) { + counter[idx] = init_counter_val; + } +} + +template +void Init(const GPUDevice& d, KeyType* keys, CounterType* vals, + CounterType* counts, CounterType* counter, const size_t capacity, + const KeyType empty_key, const CounterType empty_val, + const CounterType empty_counts, const CounterType init_counter_val) { + const int threads = block_size; + const int blocks = (capacity - 1) / block_size + 1; + TF_CHECK_OK(GpuLaunchKernel(InitKernel, blocks, threads, + 0, d.stream(), keys, vals, counts, counter, + capacity, empty_key, empty_val, empty_counts, + init_counter_val)); +} + +template +__global__ void GetSizeKernel(const KeyType* keys, const size_t capacity, + size_t* d_size, const KeyType empty_key) { + /* Per block accumulator */ + __shared__ size_t block_acc; + + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + /* Initialize */ + if (threadIdx.x == 0) { + block_acc = 0; + } + __syncthreads(); + + /* Whether the slot mapping to the current thread is empty? 
do nothing : + * Atomically add to counter */ + if (idx < capacity) { + if (keys[idx] != empty_key) { + atomicAdd(&block_acc, 1); + } + } + __syncthreads(); + + /* Atomically reduce block counter to global conuter */ + if (threadIdx.x == 0) { + atomicAdd(d_size, block_acc); + } +} + +template +void GetSize(const GPUDevice& d, const KeyType* keys, const size_t capacity, + size_t* d_size, const KeyType empty_key) { + const int threads = block_size; + const int blocks = (capacity - 1) / block_size + 1; + TF_CHECK_OK(GpuLaunchKernel(GetSizeKernel, blocks, threads, 0, + d.stream(), keys, capacity, d_size, empty_key)); +} + +template > +__global__ void GetInsertKernel(const KeyType* d_key, CounterType* d_val, + const size_t len, KeyType* keys, + CounterType* vals, CounterType* counts, + const size_t capacity, + CounterType* d_global_counter, + KeyType empty_key, + const CounterType empty_val) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < len) { + KeyType target_key = d_key[idx]; + size_t hash_index = hasher::hash(target_key) % capacity; + size_t counter = 0; + +#if __CUDA_ARCH__ < 700 + // pre-volta + bool thread_finished = false; +#endif + while (true) { + // Have searched all the slot in the hashtable, but all slots in the + // hashtable are occupied by other keys + if (counter >= capacity) { + assert(false && "error: unique op fails: hashtable is full"); + } + // Try to set the key for the current slot to target key + const KeyType old_key = + atomicCAS(keys + hash_index, empty_key, target_key); + volatile CounterType& target_val_pos = vals[hash_index]; +#if __CUDA_ARCH__ >= 700 + // volta & post-volta, independent scheduling + if (empty_key == old_key) { + CounterType result_val; + result_val = atomicAdd(d_global_counter, 1); + d_val[idx] = result_val; + target_val_pos = result_val; + atomicAdd(counts + hash_index, 1); + break; + } else if (target_key == old_key) { + while (target_val_pos == empty_val) + ; + d_val[idx] = target_val_pos; + atomicAdd(counts + hash_index, 1); + break; + } +#else + // pre-volta + if (empty_key == old_key || target_key == old_key) { + while (true) { + if (empty_key == old_key) { + CounterType result_val; + result_val = atomicAdd(d_global_counter, 1); + d_val[idx] = result_val; + target_val_pos = result_val; + atomicAdd(counts + hash_index, 1); + break; + } else { + if (target_val_pos != empty_val) { + d_val[idx] = target_val_pos; + atomicAdd(counts + hash_index, 1); + break; + } + } + } + thread_finished = true; + } + if (thread_finished) break; +#endif + counter++; + hash_index = (hash_index + 1) % capacity; + } + } +} + +template > +void GetInsert(const GPUDevice& d, const KeyType* d_key, CounterType* d_val, + const size_t input_size, KeyType* keys, CounterType* vals, + CounterType* counts, const size_t capacity, + CounterType* d_global_counter, KeyType empty_key, + const CounterType empty_val) { + const int threads = block_size; + const int blocks = (input_size - 1) / block_size + 1; + TF_CHECK_OK(GpuLaunchKernel(GetInsertKernel, + blocks, threads, 0, d.stream(), d_key, d_val, + input_size, keys, vals, counts, capacity, + d_global_counter, empty_key, empty_val)); +} + +template +__global__ void DumpKernel(KeyType* d_key, CounterType* d_counts, + const KeyType* keys, const CounterType* vals, + const CounterType* counts, const size_t offset, + const size_t capacity, size_t* d_dump_counter, + const KeyType empty_key) { + /* Per block accumulator */ + __shared__ size_t block_acc; + + const size_t idx = blockIdx.x * blockDim.x + 
threadIdx.x; + + /* Initialize */ + if (threadIdx.x == 0) { + block_acc = 0; + } + __syncthreads(); + + KeyType read_key; + CounterType read_val; + CounterType read_count; + bool valid_slot = false; + // Each thread gather the key and value from slot assigned to them. + if (idx < capacity) { + read_key = keys[offset + idx]; + if (read_key != empty_key) { + valid_slot = true; + atomicAdd(&block_acc, 1); + read_val = vals[offset + idx]; + read_count = counts[offset + idx]; + } + } + __syncthreads(); + + // Each block accumulate the dump count to global counter + if (threadIdx.x == 0) { + atomicAdd(d_dump_counter, block_acc); + } + + // Each thread store one slot's data back to global memory, d_dump_counter + // is how many slots in total dumped. + if (valid_slot) { + d_key[read_val] = read_key; + d_counts[read_val] = read_count; + } +} + +template +void Dump(const GPUDevice& d, KeyType* d_key, CounterType* d_counts, + const KeyType* keys, const CounterType* vals, + const CounterType* counts, const size_t offset, const size_t capacity, + size_t* d_dump_counter, const KeyType empty_key) { + const int threads = block_size; + const int blocks = (capacity - 1) / block_size + 1; + TF_CHECK_OK(GpuLaunchKernel(DumpKernel, blocks, threads, + 0, d.stream(), d_key, d_counts, keys, vals, + counts, offset, capacity, d_dump_counter, + empty_key)); +} + +} // namespace gpu_unique_with_counts + +template +class UniqueWithCountsV3 : public OpKernel { + public: + explicit UniqueWithCountsV3(OpKernelConstruction* ctx) : OpKernel(ctx) { + cudaEventCreateWithFlags(&memcpy_event_, cudaEventDisableTiming); + } + + void Compute(OpKernelContext* ctx) override { + using namespace gpu_unique_with_counts; + using fused_embedding::data_p_with_type; + + auto device = ctx->eigen_device(); + + nvtx::ScopedRangeIfEnabled nvtx_range(this); + + Tensor const* input = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("input", &input)); + + const size_t input_size = input->NumElements(); + const KeyType empty_key = std::numeric_limits::max(); + const CounterType empty_val = std::numeric_limits::max(); + const CounterType empty_counts = 0; + const CounterType init_counter_val = 0; + const float load_factor = 1.3; + const size_t capacity = input_size * load_factor; + + Tensor keys_storage; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({int64(capacity)}), &keys_storage)); + + Tensor vals_storage; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({int64(capacity)}), &vals_storage)); + + Tensor counts_storage; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({int64(capacity)}), + &counts_storage)); + + Tensor counter; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({int64(1)}), &counter)); + + Tensor dump_counter; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT64, TensorShape({int64(1)}), + &dump_counter)); + + Tensor* unique_idxs; + OP_REQUIRES_OK(ctx, ctx->allocate_output("unique_idxs", + TensorShape({int64(input_size)}), + &unique_idxs)); + + Init(device, data_p_with_type(keys_storage), + data_p_with_type(vals_storage), + data_p_with_type(counts_storage), + data_p_with_type(counter), capacity, empty_key, empty_val, + empty_counts, init_counter_val); + + GetInsert(device, data_p_with_type(input), + data_p_with_type(unique_idxs), input_size, + data_p_with_type(keys_storage), + data_p_with_type(vals_storage), + data_p_with_type(counts_storage), capacity, + data_p_with_type(counter), empty_key, empty_val); + + 
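+    // The number of unique keys is only known on the device after GetInsert,
+    // so it is copied back to the host and synchronized on before the
+    // unique_keys / unique_counts outputs can be allocated with that size.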
CounterType uniq_size; + CK_CUDA_THROW_(cudaMemcpyAsync( + &uniq_size, data_p_with_type(counter), sizeof(CounterType), + cudaMemcpyDeviceToHost, device.stream())); + CK_CUDA_THROW_(cudaEventRecord(memcpy_event_, device.stream())); + CK_CUDA_THROW_(cudaEventSynchronize(memcpy_event_)); + + Tensor* unique_keys; + OP_REQUIRES_OK( + ctx, ctx->allocate_output( + "unique_keys", TensorShape({int64(uniq_size)}), &unique_keys)); + + Tensor* unique_counts; + OP_REQUIRES_OK(ctx, ctx->allocate_output("unique_counts", + TensorShape({int64(uniq_size)}), + &unique_counts)); + + Dump(device, data_p_with_type(unique_keys), + data_p_with_type(unique_counts), + data_p_with_type(keys_storage), + data_p_with_type(vals_storage), + data_p_with_type(counts_storage), 0, capacity, + data_p_with_type(dump_counter), empty_key); + } + + private: + cudaEvent_t memcpy_event_; +}; + +REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV3") + .Device(DEVICE_GPU) + .TypeConstraint("KeyType") + .TypeConstraint("CounterType"), + UniqueWithCountsV3); + +REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV3") + .Device(DEVICE_GPU) + .TypeConstraint("KeyType") + .TypeConstraint("CounterType"), + UniqueWithCountsV3); + +REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV3") + .Device(DEVICE_GPU) + .TypeConstraint("KeyType") + .TypeConstraint("CounterType"), + UniqueWithCountsV3); + +REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV3") + .Device(DEVICE_GPU) + .TypeConstraint("KeyType") + .TypeConstraint("CounterType"), + UniqueWithCountsV3); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA \ No newline at end of file diff --git a/tensorflow/core/ops/fused_embedding_ops.cc b/tensorflow/core/ops/fused_embedding_ops.cc index 38cc438372d..f7cfa04bdc6 100644 --- a/tensorflow/core/ops/fused_embedding_ops.cc +++ b/tensorflow/core/ops/fused_embedding_ops.cc @@ -1,5 +1,3 @@ -#include - #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -9,59 +7,6 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -REGISTER_OP("FusedEmbeddingLocalSparseLookUp") - .Attr("T: {float32}") - .Attr("combiner: {'sqrtn', 'mean', 'sum'}") - .Attr("max_norm: float = -1.0") - .Input("sp_values: int64") - .Input("sp_indices: int64") - .Input("sp_dense_shape: int64") - .Input("emb_variable: T") - .Output("emb_vectors: T") - .Output("sp_values_offset: int32") - .SetShapeFn([](InferenceContext* ctx) { - ShapeHandle temp; - TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &temp)); - TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(1), 2, &temp)); - TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(2), 1, &temp)); - ShapeHandle emb_var_shape; - TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(3), 2, &emb_var_shape)); - - DimensionHandle emb_vec_size_dim = ctx->Dim(emb_var_shape, 1); - DimensionHandle batch_dim = ctx->UnknownDim(); - - ShapeHandle output_shape = ctx->MakeShape({batch_dim, emb_vec_size_dim}); - ctx->set_output(0, output_shape); - - return Status::OK(); - }); -// .Doc(R"doc( -// FusedEmbedding ops that performs a local embedding lookup. The process will perform embedding vector copying from emb_variable. -// The input is usually a SparseTensor. The output sp_values_offset is reserved for gradient calculation. 
-// )doc"); - -REGISTER_OP("FusedEmbeddingLocalSparseLookUpGrad") - .Attr("T: {float32}") - .Attr("combiner: {'sqrtn', 'mean', 'sum'}") - .Attr("max_norm: float = -1.0") - .Input("top_grad: T") - .Input("emb_variable: T") - .Input("sp_values: int64") - .Input("sp_values_offset: int32") - .Output("grad_emb_weight_sp_values: T") - .SetShapeFn([](InferenceContext* ctx) { - ShapeHandle top_grad_shape; - TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 2, &top_grad_shape)); - DimensionHandle emb_vec_size_dim = ctx->Dim(top_grad_shape, 1); - ctx->set_output(0, ctx->MakeShape({ctx->UnknownDim(), emb_vec_size_dim})); - return Status::OK(); - }); - -// .Doc(R"doc( -// The gradient ops for FusedEmbeddingLocalSparseLookUp. sp_values_offset from the forward op -// need to be passed to this grad op as input. -// )doc"); - REGISTER_OP("FusedEmbeddingSparsePreLookUp") .Attr("num_partitions: int >= 1 = 1") .Attr("partition_axis: int >= 0 = 0") // for now only support = 0, @@ -118,12 +63,13 @@ REGISTER_OP("FusedEmbeddingSparsePreLookUp") return Status::OK(); }); // .Doc(R"doc( -// A fused embedding op, usually using for partitioned and distriuted embedding variables. -// FusedEmbeddingSparsePreLookUp, FusedEmbeddingSparsePostLookUp should be used together. -// This op will first read the partition pattern of embedding variables through partition_shapes, -// then sort, re-calculate and assign the embedding indices to the corresponding partition. Several Gather ops -// usually should be appended after this op to gather embedding shards from multiple partitioned embedding -// variables. This op has no gradient function. +// A fused embedding op, usually using for partitioned and distriuted embedding +// variables. FusedEmbeddingSparsePreLookUp, FusedEmbeddingSparsePostLookUp +// should be used together. This op will first read the partition pattern of +// embedding variables through partition_shapes, then sort, re-calculate and +// assign the embedding indices to the corresponding partition. Several Gather +// ops usually should be appended after this op to gather embedding shards from +// multiple partitioned embedding variables. This op has no gradient function. // )doc"); REGISTER_OP("FusedEmbeddingSparsePostLookUp") @@ -178,11 +124,12 @@ REGISTER_OP("FusedEmbeddingSparsePostLookUp") }); // .Doc(R"doc( -// A fused embedding op, usually using for partitioned and distriuted embedding variables. -// FusedEmbeddingSparsePreLookUp, FusedEmbeddingSparsePostLookUp should be used together. -// There should be several Gather ops before this op. The Gather ops gather embedding shards from -// embedding variable and this op glue them together, then apply combiner and max_morm according to -// embedding indices. +// A fused embedding op, usually using for partitioned and distriuted embedding +// variables. FusedEmbeddingSparsePreLookUp, FusedEmbeddingSparsePostLookUp +// should be used together. There should be several Gather ops before this op. +// The Gather ops gather embedding shards from embedding variable and this op +// glue them together, then apply combiner and max_morm according to embedding +// indices. 
// )doc"); REGISTER_OP("FusedEmbeddingSparsePostLookUpGrad") @@ -301,4 +248,285 @@ REGISTER_OP("FusedSafeEmbeddingLookupSparseLocalGrad") return Status::OK(); }); +REGISTER_OP("PruneInvalidAndFillEmptyRows") + .Attr("fill_empty_row: bool = false") + .Attr("prune: bool = false") + .Attr("default_id: int = -1") + .Attr("use_sparse_weights: bool = false") + .Attr("prune_sparse_weights: bool = false") + .Attr("default_weight: float = 1.0") + .Input("sp_values: int64") + .Input("sp_indices: int64") + .Input("sp_dense_shape: int64") + .Input("sp_weights_values: float") + .Output("sp_values_out: int64") + .Output("sp_indices_out: int64") + .Output("sp_weights_values_out: float") + .Output("is_row_empty: bool") + .SetShapeFn([](InferenceContext* ctx) { + ShapeHandle unused; + std::vector unused_list; + + ctx->input("sp_values", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + ctx->input("sp_indices", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 2, &unused)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(ctx->WithValue(ctx->Dim(unused, 1), 2, &unused_dim)); + + ctx->input("sp_dense_shape", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + ctx->input("sp_weights_values", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + unused_list.clear(); + unused_list.resize(1); + unused_list[0] = ctx->MakeShape({ctx->UnknownDim()}); + ctx->set_output("sp_values_out", unused_list); + ctx->set_output("is_row_empty", unused_list); + unused_list[0] = ctx->MakeShape({ctx->UnknownDim(), 2}); + ctx->set_output("sp_indices_out", unused_list); + + return Status::OK(); + }); + +REGISTER_OP("UniqueWithCountsV3") + .Attr("KeyType: {int32, int64} = DT_INT64") + .Attr("CounterType: {int32, int64} = DT_INT32") + .Input("input: KeyType") + .Output("unique_keys: KeyType") + .Output("unique_idxs: CounterType") + .Output("unique_counts: CounterType") + + .SetShapeFn([](InferenceContext* ctx) { + ShapeHandle unused; + std::vector unused_list; + + ctx->input("input", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + unused_list.clear(); + unused_list.resize(1); + unused_list[0] = ctx->MakeShape({ctx->UnknownDim()}); + + ctx->set_output("unique_keys", unused_list); + ctx->set_output("unique_idxs", unused_list); + ctx->set_output("unique_counts", unused_list); + + return Status::OK(); + }); + +REGISTER_OP("PartitionWithPermutation") + .Attr("num_partitions: int >= 2 = 2") + .Attr("partition_axis: int >= 0 = 0") + .Attr("partition_strategy : {'div', 'mod', 'mod_ev'}") + .Input("input: int64") + .Input("partition_shapes: num_partitions * int64") + .Output("partitioned_values: num_partitions * int64") + .Output("partition_permutation: int32") + .SetShapeFn([](InferenceContext* ctx) { + ShapeHandle unused; + std::vector unused_list; + DimensionHandle unused_dim; + + int num_partitions; + TF_RETURN_IF_ERROR(ctx->GetAttr("num_partitions", &num_partitions)); + int partition_axis; + TF_RETURN_IF_ERROR(ctx->GetAttr("partition_axis", &partition_axis)); + + ctx->input("input", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + ctx->input("partition_shapes", &unused_list); + for (int i = 0; i < num_partitions; i++) { + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[i], 1, &unused)); + TF_RETURN_IF_ERROR( + ctx->WithValue(ctx->NumElements(unused), 2, &unused_dim)); + } + + unused_list.clear(); + unused_list.resize(num_partitions); + for (int i = 0; i < 
num_partitions; i++) { + unused_list[i] = ctx->MakeShape({ctx->UnknownDim()}); + } + ctx->set_output("partitioned_values", unused_list); + + unused_list.clear(); + unused_list.resize(1); + unused_list[0] = ctx->MakeShape({ctx->UnknownDim(), 2}); + ctx->set_output("partition_permutation", unused_list); + + return Status::OK(); + }); + +REGISTER_OP("FusedEmbeddingSparsePostLookUpV2") + .Attr("T : {float32}") + .Attr("num_partitions: int >= 1 = 1") + .Attr("fill_empty_row: bool = false") + .Attr("default_id: int = -1") + .Attr("partition_axis: int >= 0 = 0") + .Attr("combiner: {'sqrtn', 'mean', 'sum'}") + .Attr("max_norm: float = -1.0") + .Attr("use_sparse_weights: bool = false") + .Input("emb_shards: num_partitions * T") + .Input("partition_permutation: int32") + .Input("sp_dense_shape: int64") + .Input("indices_before_unique: int64") + .Input("is_row_empty: bool") + .Input("unique_idxs: int32") + .Input("sp_weights_values: float") + .Output("emb_vectors: T") + .Output("feature_nums: int32") + .Output("emb_shard_ptrs: uint64") + .SetShapeFn([](InferenceContext* ctx) { + int num_partitions; + TF_RETURN_IF_ERROR(ctx->GetAttr("num_partitions", &num_partitions)); + + std::vector unused_list; + ShapeHandle unused; + DimensionHandle unused_dim; + + // emb_shards + ctx->input("emb_shards", &unused_list); + ShapeHandle first_emb_shard_shape; + TF_RETURN_IF_ERROR( + ctx->WithRank(unused_list[0], 2, &first_emb_shard_shape)); + DimensionHandle emb_vec_size_dim = ctx->Dim(first_emb_shard_shape, 1); + int64 emb_vec_size = ctx->Value(emb_vec_size_dim); + + for (int i = 0; i < num_partitions; i++) { + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[i], 2, &unused)); + TF_RETURN_IF_ERROR( + ctx->WithValue(ctx->Dim(unused, 1), emb_vec_size, &unused_dim)); + } + + // partition_permutation + ctx->input("partition_permutation", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 2, &unused)); + + // sp_dense_shape + ctx->input("sp_dense_shape", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + // indices_before_unique + ctx->input("indices_before_unique", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 2, &unused)); + + // is_row_empty + ctx->input("is_row_empty", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + // unique_counts + ctx->input("unique_idxs", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + ctx->input("sp_weights_values", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + // emb_vectors + unused_list.clear(); + unused_list.resize(1); + unused_list[0] = ctx->MakeShape({ctx->UnknownDim(), emb_vec_size_dim}); + ctx->set_output("emb_vectors", unused_list); + + // feature_nums + unused_list[0] = ctx->MakeShape({ctx->UnknownDim()}); + ctx->set_output("feature_nums", unused_list); + + // emb_shard_ptrs + unused_list[0] = ctx->MakeShape({num_partitions}); + ctx->set_output("emb_shard_ptrs", unused_list); + return Status::OK(); + }); + +// .Doc(R"doc( +// A fused embedding op, usually using for partitioned and distriuted embedding +// variables. FusedEmbeddingSparse`LookUp, FusedEmbeddingSparsePostLookUp +// should be used together. There should be several Gather ops before this op. +// The Gather ops gather embedding shards from embedding variable and this op +// glue them together, then apply combiner and max_morm according to embedding +// indices. 
+// )doc"); + +REGISTER_OP("FusedEmbeddingSparsePostLookUpV2Grad") + .Attr("T : {float32}") + .Attr("num_partitions: int >= 1 = 1") + .Attr("fill_empty_row: bool = false") + .Attr("partition_axis: int >= 0 = 0") // for now only support = 0, + // will consider support = 1 + // if necessary + .Attr("default_id: int = -1") + .Attr("combiner: {'sqrtn', 'mean', 'sum'}") + .Attr("max_norm: float = -1.0") + .Attr("use_sparse_weights: bool = false") + .Input("top_grad: T") + .Input("emb_shards: num_partitions * T") + .Input("emb_shard_ptrs: uint64") + .Input("partition_permutation: int32") + .Input("feature_nums: int32") + .Input("indices_before_unique: int64") + .Input("unique_idxs: int32") + .Input("is_row_empty: bool") + .Input("sp_weights_values: float") + .Output("grad_shards: num_partitions * T") + .SetShapeFn([](InferenceContext* ctx) { + int num_partitions; + TF_RETURN_IF_ERROR(ctx->GetAttr("num_partitions", &num_partitions)); + + std::vector unused_list; + ShapeHandle unused; + DimensionHandle unused_dim; + + ctx->input("top_grad", &unused_list); + // top_grad + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 2, &unused)); + DimensionHandle emb_vec_size_dim = ctx->Dim(unused, 1); + + // emb_shards + ctx->input("emb_shards", &unused_list); + for (int i = 0; i < num_partitions; i++) { + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[i], 2, &unused)); + } + + // emb_shard_ptrs + ctx->input("emb_shard_ptrs", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + TF_RETURN_IF_ERROR( + ctx->WithValue(ctx->Dim(unused, 0), num_partitions, &unused_dim)); + + // partition_permutation + ctx->input("partition_permutation", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 2, &unused)); + + // feature_nums + ctx->input("feature_nums", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + // indices_before_unique + ctx->input("indices_before_unique", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 2, &unused)); + + // unique_idxs + ctx->input("unique_idxs", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + // is_row_empty + ctx->input("is_row_empty", &unused_list); + TF_RETURN_IF_ERROR(ctx->WithRank(unused_list[0], 1, &unused)); + + // grad_shards + unused_list.clear(); + unused_list.resize(1); + unused_list[0] = ctx->MakeShape({ctx->UnknownDim(), emb_vec_size_dim}); + ctx->set_output("grad_shards", unused_list); + + return Status::OK(); + }); + +// .Doc(R"doc( +// Calculate gradient of FusedEmbeddingSparsePostLookUp +// )doc"); + } // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c9b25a46e75..30cebd65c84 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2087,7 +2087,8 @@ tf_gen_op_wrapper_private_py( py_library( name = "fused_embedding_ops", - srcs = ["ops/fused_embedding_ops.py"], + srcs = ["ops/fused_embedding_ops.py", + "ops/fused_embedding_ops_v2.py"], srcs_version = "PY2AND3", deps = [ ":fused_embedding_ops_gen", diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 1553d99400a..1d5687e8e64 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -857,7 +857,7 @@ def embedding_column(categorical_column, max_norm=None, trainable=True, coalesced_scope=None, - do_fusion=False): + do_fusion=None): """`DenseColumn` that converts from sparse, categorical input. 
Use this when your inputs are sparse, but you want to convert them to a dense @@ -4182,7 +4182,7 @@ def __new__( max_norm, trainable, coalesced_scope=None, - do_fusion=False): + do_fusion=None): """Create feature column in compatible way.""" return super(EmbeddingColumn, cls).__new__( cls, categorical_column, dimension, combiner, initializer, @@ -4277,7 +4277,8 @@ def _get_dense_tensor_internal_helper(self, sparse_tensors, sparse_weights=sparse_weights, combiner=self.combiner, name='%s_weights' % self.name, - max_norm=self.max_norm) + max_norm=self.max_norm, + fusion_version=self.do_fusion) else: return embedding_ops.safe_embedding_lookup_sparse( embedding_weights=embedding_weights, diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 0fc33f70957..f92c7c09a7a 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -8514,7 +8514,7 @@ def test_serialization_with_default_initializer(self): 'tensor_name_in_ckpt': None, 'trainable': True, 'coalesced_scope': None, - 'do_fusion': False, + 'do_fusion': None, }, config) custom_objects = {'TruncatedNormal': init_ops.TruncatedNormal} @@ -8569,7 +8569,7 @@ def _initializer(shape, dtype, partition_info=None): 'tensor_name_in_ckpt': None, 'trainable': True, 'coalesced_scope': None, - 'do_fusion': False, + 'do_fusion': None, }, config) custom_objects = { diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 1199b9c86da..0659195738e 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.ops import fused_embedding_ops +from tensorflow.python.ops import fused_embedding_ops_v2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -652,6 +653,7 @@ def embedding_lookup_sparse(params, else: assert False, "Unrecognized combiner" else: + assert idx is not None if combiner == "sum": embeddings = math_ops.sparse_segment_sum( @@ -1351,11 +1353,12 @@ def fused_safe_embedding_lookup_sparse(embedding_weights, name=None, partition_strategy="div", max_norm=None, - prune=True): + prune=True, + fusion_version='v2'): """Functionally the same as safe_embedding_lookup_sparse but using fused embedding lookup ops in this method. """ - logging.info("Is using fused embedding lookup for this scope {}".format(name)) + logging.info("Is using fused embedding lookup {} for this scope {}".format(fusion_version, name)) if embedding_weights is None: raise ValueError("Missing embedding_weights %s." 
% embedding_weights) @@ -1395,17 +1398,31 @@ def fused_safe_embedding_lookup_sparse(embedding_weights, sparse_weights.values, sparse_ids.dense_shape) - result = fused_embedding_ops.fused_embedding_lookup_sparse( - embedding_weights, - sparse_ids, - sparse_weights=sparse_weights, - combiner=combiner, - partition_strategy=partition_strategy, - name=None if default_id is None else scope, - max_norm=max_norm, - default_id=default_id, - prune_invalid_ids=True - ) + assert(fusion_version in ['v1', 'v2']) + if fusion_version == 'v1': + result = fused_embedding_ops.fused_embedding_lookup_sparse( + embedding_weights, + sparse_ids, + sparse_weights=sparse_weights, + partition_strategy=partition_strategy, + name=None if default_id is None else scope, + combiner=combiner, + max_norm=max_norm, + default_id=default_id, + prune_invalid_ids=True, + ) + else: + result = fused_embedding_ops_v2.fused_embedding_lookup_sparse_v2( + embedding_weights, + sparse_ids, + sparse_weights=sparse_weights, + partition_strategy=partition_strategy, + name=None if default_id is None else scope, + combiner=combiner, + max_norm=max_norm, + default_id=default_id, + prune=True, + ) # Reshape back from linear ids back into higher-dimensional dense result. final_result = array_ops.reshape( diff --git a/tensorflow/python/ops/fused_embedding_ops_v2.py b/tensorflow/python/ops/fused_embedding_ops_v2.py new file mode 100644 index 00000000000..de860fcc868 --- /dev/null +++ b/tensorflow/python/ops/fused_embedding_ops_v2.py @@ -0,0 +1,158 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops.kv_variable_ops import EmbeddingVariable +from tensorflow.python.ops.gen_fused_embedding_ops import prune_invalid_and_fill_empty_rows +from tensorflow.python.ops.gen_fused_embedding_ops import unique_with_counts_v3 +from tensorflow.python.ops.gen_fused_embedding_ops import partition_with_permutation +from tensorflow.python.ops.gen_fused_embedding_ops import fused_embedding_sparse_post_look_up_v2 +from tensorflow.python.ops.gen_fused_embedding_ops import fused_embedding_sparse_post_look_up_v2_grad +from tensorflow.python.util.tf_export import tf_export + + +@tf_export(v1=["nn.fused_embedding_lookup_sparse_v2"]) +def fused_embedding_lookup_sparse_v2(params, + sp_ids, + sparse_weights=None, + partition_strategy=None, + name=None, + combiner=None, + max_norm=None, + default_id=None, + prune=False, + fill_empty_row=True, + blocknums=None): + + if sparse_weights is not None: + if type(sparse_weights) not in [ops.Tensor, sparse_tensor.SparseTensor]: + raise ValueError("sparse_weights must be Tensor or SparseTensor") + if type(sparse_weights) is sparse_tensor.SparseTensor: + sp_weights_values = sparse_weights.values + else: + sp_weights_values = sparse_weights + use_sparse_weights = True + if combiner != "sum": + prune_sparse_weights = True + else: + use_sparse_weights = False + prune_sparse_weights = False + # dummy + sp_weights_values = constant([1], dtype=dtypes.float32) + + valid_partition_strategy = ['div', 'mod', 'mod_ev'] + if partition_strategy not in valid_partition_strategy: + raise ValueError("{} is not supported yet. 
Currently only support {}".format( + partition_strategy, valid_partition_strategy)) + + if blocknums is not None: + raise ValueError("Using blocknums for DynamicEmbeddingVariable is not supported yet") + + if default_id is not None and type(default_id) is not int: + raise ValueError("default_id must be a integer!") + + params_white_list = [EmbeddingVariable, ops.Tensor] + if any([type(param) not in params_white_list for param in params]): + raise ValueError("Currently fused embedding only support: {}".format(params_white_list)) + + partition_nums = len(params) + + if type(params[0]) is EmbeddingVariable: + partition_strategy = 'mod_ev' + partition_shapes = [constant([1, 1], dtype=dtypes.int64) for _ in params] # dummy + else: + partition_shapes = [w.shape for w in params] + + with ops.name_scope(name, "fused_embedding_lookup_sparse", + params + [sp_ids]) as name: + + sp_values = sp_ids.values + sp_indices = sp_ids.indices + sp_dense_shape = sp_ids.dense_shape + + if prune or fill_empty_row: + sp_values, sp_indices, sp_weights_values, is_row_empty = prune_invalid_and_fill_empty_rows( + fill_empty_row=fill_empty_row, + prune=prune, + default_id=default_id, + use_sparse_weights=use_sparse_weights, + prune_sparse_weights=prune_sparse_weights, + sp_values=sp_values, + sp_indices=sp_indices, + sp_dense_shape=sp_dense_shape, + sp_weights_values=sp_weights_values) + else: + is_row_empty = constant(False, shape=(1, ), dtype=dtypes.bool) # dummy + + unique_keys, unique_idxs, unique_counts = unique_with_counts_v3( + input=sp_values, + CounterType=dtypes.int32, + ) + + if partition_nums > 1: + partitioned_values, partition_permutation = partition_with_permutation( + partition_strategy=partition_strategy, + input=unique_keys, + partition_shapes=partition_shapes + ) + else: + partitioned_values = [unique_keys] + partition_permutation = constant(0, shape=(1, 1), dtype=dtypes.int32) # dummy + + emb_shards = [] + for i in range(partition_nums): + with ops.colocate_with(params[i]): + shard = array_ops.gather(params[i], partitioned_values[i], counts=unique_counts) + emb_shards.append(shard) + + emb_vectors, _, _ = fused_embedding_sparse_post_look_up_v2( + fill_empty_row=fill_empty_row, default_id=default_id, + combiner=combiner, max_norm=max_norm, + use_sparse_weights=use_sparse_weights, + emb_shards=emb_shards, partition_permutation=partition_permutation, + sp_dense_shape=sp_dense_shape, + indices_before_unique=sp_indices, + is_row_empty=is_row_empty, + unique_idxs=unique_idxs, + sp_weights_values=sp_weights_values + ) + + return emb_vectors + + +@ops.RegisterGradient("FusedEmbeddingSparsePostLookUpV2") +def fused_embedding_sparse_post_look_up_v2_gradient(op, top_grad, + uesless_grad_1, uesless_grad_2): + num_partitions = op.get_attr("num_partitions") + combiner = op.get_attr("combiner") + max_norm = op.get_attr("max_norm") + fill_empty_row = op.get_attr("fill_empty_row") + default_id = op.get_attr("default_id") + use_sparse_weights = op.get_attr("use_sparse_weights") + + emb_shards = [op.inputs[i] for i in range(0, num_partitions)] + partition_permutation = op.inputs[num_partitions] + # sp_dense_shape = op.inputs[num_partitions + 1] + indices_before_unique = op.inputs[num_partitions + 2] + is_row_empty = op.inputs[num_partitions + 3] + unique_idxs = op.inputs[num_partitions + 4] + sp_weights_values = op.inputs[num_partitions + 5] + + feature_nums = op.outputs[1] + emb_shard_ptrs = op.outputs[2] + + grad_shards = fused_embedding_sparse_post_look_up_v2_grad( + fill_empty_row=fill_empty_row, 
default_id=default_id, + combiner=combiner, max_norm=max_norm, use_sparse_weights=use_sparse_weights, + top_grad=top_grad, emb_shards=emb_shards, + emb_shard_ptrs=emb_shard_ptrs, + partition_permutation=partition_permutation, + feature_nums=feature_nums, indices_before_unique=indices_before_unique, + unique_idxs=unique_idxs, is_row_empty=is_row_empty, sp_weights_values=sp_weights_values) + + return grad_shards + [None for _ in range(len(op.inputs) - num_partitions)] diff --git a/tensorflow/python/tpu/feature_column_v2.py b/tensorflow/python/tpu/feature_column_v2.py index 0e52ff84c06..1ec2386c233 100644 --- a/tensorflow/python/tpu/feature_column_v2.py +++ b/tensorflow/python/tpu/feature_column_v2.py @@ -290,7 +290,7 @@ def __new__(cls, max_norm=None, trainable=True, coalesced_scope=None, - do_fusion=False) + do_fusion=None) def __init__(self, categorical_column, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt index 72b386f42c3..af881226ad5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt @@ -62,7 +62,7 @@ tf_module { } member_method { name: "embedding_column" - argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\', \'coalesced_scope\', \'do_fusion\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'False\'], " + argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\', \'coalesced_scope\', \'do_fusion\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " } member_method { name: "hash_table_column" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt index 0c5946d9d49..928ffc915f3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt @@ -224,6 +224,10 @@ tf_module { name: "fused_embedding_lookup_sparse" argspec: "args=[\'params\', \'sp_ids\', \'sparse_weights\', \'partition_strategy\', \'name\', \'combiner\', \'max_norm\', \'default_id\', \'prune_invalid_ids\', \'blocknums\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], " } + member_method { + name: "fused_embedding_lookup_sparse_v2" + argspec: "args=[\'params\', \'sp_ids\', \'sparse_weights\', \'partition_strategy\', \'name\', \'combiner\', \'max_norm\', \'default_id\', \'prune\', \'fill_empty_row\', \'blocknums\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'True\', \'None\'], " + } member_method { name: "gelu" argspec: "args=[\'features\', \'approximate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt index 806b88828bc..33b8ac736a8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt @@ -62,7 +62,7 @@ tf_module { } member_method { name: "embedding_column" - argspec: 
"args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\', \'coalesced_scope\', \'do_fusion\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'False\'], " + argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\', \'coalesced_scope\', \'do_fusion\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " } member_method { name: "hash_table_column"