Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YOLOV8 port to keras-hub #1899

Open
wants to merge 61 commits into
base: master
Choose a base branch
from
Open

YOLOV8 port to keras-hub #1899

wants to merge 61 commits into from

Conversation

oarriaga
Copy link

@oarriaga oarriaga commented Oct 1, 2024

This PR ports YOLOV8 from keras-cv to keras-hub (#176). All necessary YOLOV8 functions are now found inside keras-hub:

  • Add CIOU loss.
  • Add missing masking functionality in the bounding_boxes module.
  • Add multibackend non maximum supression layer.
  • Add label encoder.
  • Build basic abstract object detector task class.
  • Add YOLOV8 backbone and detector.

Missing steps include:

  • Upload previous presets to Kaggle.
  • Remove skipping tests with presets.
  • Add colab with basic functionality.
  • Add weight transfer script from keras-cv to keras-hub.
  • Add training script.

Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @oarriaga left some initial comments

keras_hub/src/layers/modeling/non_max_suppression.py Outdated Show resolved Hide resolved
keras_hub/src/layers/modeling/non_max_suppression.py Outdated Show resolved Hide resolved
keras_hub/src/layers/modeling/non_max_suppression.py Outdated Show resolved Hide resolved
keras_hub/src/layers/modeling/non_max_suppression.py Outdated Show resolved Hide resolved
keras_hub/src/layers/modeling/non_max_suppression.py Outdated Show resolved Hide resolved
label_encoder=None,
prediction_decoder=None,
**kwargs,
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

restructure the code to define all teh layers first, functional model next and config last with these comments
=== Layers ===
.
.
=== Functional model ===
.
.
=== Config ===
.
.
example : https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/bert/bert_backbone.py#L92

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Divya, the current model applies multiple blocks of layers. That will imply that we would need to initialize many layers in the constructor. Moreover, the connections between those layers are not so straightforward as in Bert. What would you suggest? Shall we still initialize all layers in the layer block and connect them in the functional block?

def predict_step(self, *args):
outputs = super().predict_step(*args)
if isinstance(outputs, tuple):
return self.decode_predictions(outputs[0], args[-1]), outputs[1]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will model.fit work ?

@@ -0,0 +1,318 @@
import os

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you will need to add preprocessor flow - example follow resnet - https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/models/resnet


@pytest.mark.large # Saving is slow, so mark these large.
def test_saved_model(self):
model = keras_hub.models.YOLOV8Detector(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# TODO(tirthasheshpatel): Support updating prediction decoder in Keras Core.
@pytest.mark.skip(reason="Missing presets")
@pytest.mark.tf_keras_only
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no tf_keras_only in KerasHub

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates! What's the current progress of the PR? Is it nearly ready to merge?

raise ValueError(
"`bounding_box.mask_invalid_detections()` requires inputs to be "
"Dense tensors. Please call "
"`bounding_box.to_dense(bounding_boxes)` before passing your boxes "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is to_dense actually located?

Copy link
Author

@oarriaga oarriaga Nov 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In keras-cv to_dense was located outside of the training loop when building a tf.data pipeline.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the bounding box code is now moved to keras repo


Example:
```python
images = tf.ones(shape=(1, 512, 512, 3))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make sure code examples don't have any TF references; use keras.ops or np and the like.

classification_loss='binary_crossentropy',
box_loss='ciou',
optimizer=tf.optimizers.SGD(global_clipnorm=10.0),
jit_compile=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No compilation support?

Copy link
Author

@oarriaga oarriaga Nov 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One can set the flag to True; however, the latest training runs with keras-cv were not converging when padded boxes were added. I am going through the loss function to find out what could be the issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it work with compilation now?

"""
Encodes ground truth boxes to target boxes and class labels for training a
YOLOV8 model. This is an implementation of the Task-aligned sample
assignment scheme proposed in https://arxiv.org/abs/2108.07755.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use markdown-formatted links.

"""Computes target boxes and classes for anchors.

Args:
scores: a Float Tensor of shape (batch_size, num_anchors,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use backticks around code keywords, such as shape tuples. All docstrings are rendered as markdown.

@@ -0,0 +1,257 @@
import keras
import tensorflow as tf
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Long-term, we cannot depend on TF, please remove this import

truth box. Anchors that didn't match with a ground truth
box should be excluded from both class and box losses.
"""
if isinstance(gt_bboxes, tf.RaggedTensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid importing TF you can use

def is_tensorflow_ragged(value):
    if hasattr(value, "__class__"):
        return (
            value.__class__.__name__ == "RaggedTensor"
            and "tensorflow.python." in str(value.__class__.__module__)
        )
    return False

@oarriaga
Copy link
Author

oarriaga commented Nov 8, 2024

Hi, thank you! I think it’s nearly ready. Right now, I’m validating the model’s expected convergence with PASCAL, which has turned out to be more challenging than anticipated. After that, the only remaining step will be to go through your comments and incorporate any additional input from Divya.

@divyashreepathihalli
Copy link
Collaborator

@oarriaga can you please add a demo notebook to the PR to verify outputs. What is the inference time of this implementation versus the original implementation?

boxes, iou_threshold, output_size, tile_arg, tile_size
)

selected_boxes, _, output_size, _ = ops.while_loop(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be vectorized?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a potential performance bottleneck. This part of the code needs to be vectorized.

Copy link
Collaborator

@sineeli sineeli Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. This works with all backends and I used the same one for retinanet but for now if we can include as this is detachable from model easily and not involved in trained and only while predictions.
  2. We have to check if its bottleneck compared to torch and then we can make some changes as they use loop based approach rather than ops.while_loop.
  3. We have to check the model convertions to onxx and then to tensorrt as well because later in the progression of model those are the important aspects we may have to look into.

Reference: https://github.com/pytorch/vision/blob/acbfd8d94d10f989f4540252e92e8855c19f7ff7/torchvision/models/detection/retinanet.py#L518

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @divyashreepathihalli from what I understand NMS seems to be programmed sequentially because of the necessity of carrying a state of unprocessed boxes. From what I have seen, NMS seems to be usually implemented using a while / for loop since boxes are greedily chosen and removed. I don't know if vectorizing this would decrease the computation time given the sequential nature of NMS. Moreover, pytorch's NMS implementation is done in c++ using a double for loop across boxes. Do let me know how you would like me to proceed.

raise ValueError(
"`bounding_box.mask_invalid_detections()` requires inputs to be "
"Dense tensors. Please call "
"`bounding_box.to_dense(bounding_boxes)` before passing your boxes "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the bounding box code is now moved to keras repo

classification_loss='binary_crossentropy',
box_loss='ciou',
optimizer=tf.optimizers.SGD(global_clipnorm=10.0),
jit_compile=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it work with compilation now?

return xs, {"boxes": ys, "classes": y_classes}


class YOLOV8DetectorTest(TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add generic task test - self.run_task_test

boxes.
Returns:
bounding boxes with proper masking of the boxes according to
`num_detections`. This allows proper interop with non-max supression.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: suppression

from keras_hub.src.tests.test_case import TestCase


class NonMaxSupressionTest(TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT : Suppression


Returns:
iou_suppressed: a tensor of shape [batch_size, num_boxes_with_padding].
iou_diff: a scalar tensor representing whether any box is supressed in
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suppressed update spelling everywhere

alignment score of an anchor box. This is the beta parameter in
equation 9 of https://arxiv.org/pdf/2108.07755.pdf.
epsilon: float, a small number used for numerical stability in division
(to avoid diving by zero), and used as a threshold to eliminate very
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dividing by zero

bounding_box_format: a case-insensitive string (for example, "xyxy").
Each bounding box is defined by these 4 values. For detailed
information on the supported formats, see the [KerasCV bounding box
documentation](https://keras.io/api/keras_cv/bounding_box/formats/).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add link to keras instead

Copy link
Author

@oarriaga oarriaga Nov 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @divyashreepathihalli I was unable to find in the keras documentation the available box formats. I am linking to the keras source code. We can update this once the documentation is available.

low=0,
high=10)
loss = keras_hub.src.models.yolo_v8.ciou_loss.CIoULoss("xyxy")
loss(y_true, y_pred).numpy()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the y_true and y_pred have different shapes - would this not result in a n error?

self.bounding_box_format = bounding_box_format

def call(self, y_true, y_pred):
y_pred = ops.convert_to_tensor(y_pred)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking if y_pred is a tensor and the dtype before converting could improve efficiency

f"y_true={y_true.shape[-2]} and number of boxes in "
f"y_pred={y_pred.shape[-2]}."
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise error for unsupported bbox format

[4, 5, 5, 6],
[2, 1, 3, 3],
]
expected_loss = 1.03202
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how was this value calculated?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the original value provided by the KerasCV tests.

boxes, iou_threshold, output_size, tile_arg, tile_size
)

selected_boxes, _, output_size, _ = ops.while_loop(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a potential performance bottleneck. This part of the code needs to be vectorized.

Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458

Args:
boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4].
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the shape always be [batch_size, num_boxes, 4]?

boxes = ops.pad(ops.cast(boxes, "float32"), [[0, 0], [0, pad], [0, 0]])
scores = ops.pad(ops.cast(scores, "float32"), [[0, 0], [0, pad]])
num_boxes_after_padding = num_boxes + pad
num_iterations = num_boxes_after_padding // tile_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a check here is verify the num_boxes_after_padding does not exceed max_output_size +pad

return x


def build_block(x, block_arg, channels, depth, block_depth, activation):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the function names here are not readable as we have one build_block and another build_blocks. maybe rename this to yolo_block or whatever you think is suitable and rename build_blocks to stackwise_blocks
Also build might lead to more confusion.

label_encoder: Optional. A `YOLOV8LabelEncoder` that is
responsible for transforming input boxes into trainable labels for
YOLOV8Detector. If not provided, a default is provided.
prediction_decoder: Optional. A `keras.layers.Layer` that is
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document default values in docstring. here and everywhere.

# Only anchors which are inside of relevant GT boxes are considered
# for assignment.
# This is a boolean tensor of shape (B, num_gt_boxes, num_anchors)
matching_anchors_in_gt_boxes = is_anchor_center_within_box(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename all gt_ to ground_truth

@@ -0,0 +1,80 @@
import numpy as np
import pytest
import tensorflow as tf
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no tf code in Keras_hub

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants