diff --git a/keras_cv/layers/preprocessing_3d/waymo/frustum_random_point_feature_noise.py b/keras_cv/layers/preprocessing_3d/waymo/frustum_random_point_feature_noise.py index bd5b959cc5..395943d2e1 100644 --- a/keras_cv/layers/preprocessing_3d/waymo/frustum_random_point_feature_noise.py +++ b/keras_cv/layers/preprocessing_3d/waymo/frustum_random_point_feature_noise.py @@ -1,6 +1,7 @@ # Copyright 2022 Waymo LLC. # # Licensed under the terms in https://github.com/keras-team/keras-cv/blob/master/keras_cv/layers/preprocessing_3d/waymo/LICENSE # noqa: E501 +import logging import tensorflow as tf @@ -98,9 +99,17 @@ def get_random_transformation(self, point_clouds, **kwargs): # frustum. valid_points = point_clouds[0, :, POINTCLOUD_LABEL_INDEX] > 0 num_valid_points = tf.math.reduce_sum(tf.cast(valid_points, tf.int32)) - randomly_select_point_index = tf.random.uniform( - (), minval=0, maxval=num_valid_points, dtype=tf.int32 - ) + try: + randomly_select_point_index = tf.random.uniform( + (), minval=0, maxval=maxval, dtype=tf.int32 + ) + except tf.errors.InvalidArgumentError: + logging.error( + "Skipping frustum random point noise augmentation: No valid " + "point is found." + ) + return {} + randomly_select_frustum_center = tf.boolean_mask( point_clouds[0], valid_points, axis=0 )[randomly_select_point_index, :POINTCLOUD_LABEL_INDEX] @@ -144,6 +153,8 @@ def get_random_transformation(self, point_clouds, **kwargs): def augment_point_clouds_bounding_boxes( self, point_clouds, bounding_boxes, transformation, **kwargs ): + if "point_noise" not in transformation: + return (point_clouds, bounding_boxes) point_noise = transformation["point_noise"] # Do not add noise to points that are protected by setting the