Skip to content

Commit

Permalink
605 modifying object attributes (#607)
Browse files Browse the repository at this point in the history
* Reformat.

* Support object attribute replace/delete.
  • Loading branch information
denisvmedyantsev authored Jan 12, 2024
1 parent b3a12f9 commit b13a4f7
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 19 deletions.
14 changes: 6 additions & 8 deletions samples/panoptic_driving_perception/pytorch_infer.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
"""Custom PyFunc implementation inference PyTorch model."""
from savant_rs.primitives.geometry import BBox

from savant.meta.object import ObjectMeta
from savant.utils.memory_repr_pytorch import (
pytorch_tensor_as_opencv_gpu_mat,
opencv_gpu_mat_as_pytorch_tensor,
)

import cv2
import torch
import torchvision
import torchvision.transforms as transforms
from savant_rs.primitives.geometry import BBox

from savant.deepstream.meta.frame import NvDsFrameMeta
from savant.deepstream.opencv_utils import alpha_comp, nvds_to_gpu_mat
from savant.deepstream.pyfunc import NvDsPyFuncPlugin
from savant.gstreamer import Gst
from savant.meta.object import ObjectMeta
from savant.utils.memory_repr_pytorch import (
opencv_gpu_mat_as_pytorch_tensor,
pytorch_tensor_as_opencv_gpu_mat,
)


class PyTorchInfer(NvDsPyFuncPlugin):
Expand Down
6 changes: 4 additions & 2 deletions savant/config/module_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def source_element_configurator(
"""Additional configuration steps for SourceElements."""
# if dev mode is enabled in the module parameters
# set dev mode for the ingress filter function
if (element_config.ingress_frame_filter is not None
and module_config.parameters.dev_mode):
if (
element_config.ingress_frame_filter is not None
and module_config.parameters.dev_mode
):
logger.debug(
'Setting dev mode for ingress filter of SourceElement named "%s" to True.',
element_config.name,
Expand Down
39 changes: 37 additions & 2 deletions savant/deepstream/meta/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
nvds_get_obj_draw_label,
nvds_get_obj_uid,
nvds_init_obj_draw_label,
nvds_remove_obj_attr_meta_list,
nvds_replace_obj_attr_meta_list,
nvds_set_obj_bbox,
nvds_set_obj_draw_label,
nvds_set_obj_uid,
Expand Down Expand Up @@ -127,7 +129,7 @@ def get_attr_meta(
return nvds_get_obj_attr_meta(
frame_meta=self._frame_meta,
obj_meta=self.ds_object_meta,
model_name=element_name,
element_name=element_name,
attr_name=attr_name,
)

Expand All @@ -143,7 +145,37 @@ def get_attr_meta_list(
return nvds_get_obj_attr_meta_list(
frame_meta=self._frame_meta,
obj_meta=self.ds_object_meta,
model_name=element_name,
element_name=element_name,
attr_name=attr_name,
)

def replace_attr_meta_list(
self, element_name: str, attr_name: str, value: List[AttributeMeta]
):
"""Replaces the object's specified attributes with a new list.
:param element_name: Attribute model name.
:param attr_name: Attribute name.
:param value: List of AttributeMeta.
"""
nvds_replace_obj_attr_meta_list(
frame_meta=self._frame_meta,
obj_meta=self.ds_object_meta,
element_name=element_name,
attr_name=attr_name,
value=value,
)

def remove_attr_meta_list(self, element_name: str, attr_name: str):
"""Removes the object's specified attributes.
:param element_name: Attribute model name.
:param attr_name: Attribute name.
"""
nvds_remove_obj_attr_meta_list(
frame_meta=self._frame_meta,
obj_meta=self.ds_object_meta,
element_name=element_name,
attr_name=attr_name,
)

Expand All @@ -153,13 +185,15 @@ def add_attr_meta(
name: str,
value: Any,
confidence: float = 1.0,
replace: bool = False,
):
"""Adds specified object attribute to object meta.
:param element_name: attribute model name.
:param name: attribute name.
:param value: attribute value.
:param confidence: attribute confidence.
:param replace: replace attribute if it already exists.
"""
nvds_add_attr_meta_to_obj(
frame_meta=self._frame_meta,
Expand All @@ -168,6 +202,7 @@ def add_attr_meta(
name=name,
value=value,
confidence=confidence,
replace=replace,
)

@property
Expand Down
2 changes: 2 additions & 0 deletions savant/deepstream/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
nvds_attr_meta_iterator,
nvds_get_obj_attr_meta,
nvds_get_obj_attr_meta_list,
nvds_remove_obj_attr_meta_list,
nvds_remove_obj_attrs,
nvds_replace_obj_attr_meta_list,
)
from .event import (
GST_NVEVENT_INFER_INTERVAL_UPDATE,
Expand Down
59 changes: 52 additions & 7 deletions savant/deepstream/utils/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def nvds_add_attr_meta_to_obj( # pylint: disable=too-many-arguments
name: str,
value: Any,
confidence: float = 1.0,
replace: bool = False,
):
"""Adds attribute to the object.
Expand All @@ -27,11 +28,12 @@ def nvds_add_attr_meta_to_obj( # pylint: disable=too-many-arguments
:param name: attribute name.
:param value: attribute value.
:param confidence: object confidence.
:param replace: replace existing attribute.
"""
skey = nvds_get_obj_uid(frame_meta, obj_meta)
if skey not in NVDS_OBJ_ATTR_STORAGE:
NVDS_OBJ_ATTR_STORAGE[skey] = {}
if (element_name, name) not in NVDS_OBJ_ATTR_STORAGE[skey]:
if (element_name, name) not in NVDS_OBJ_ATTR_STORAGE[skey] or replace:
NVDS_OBJ_ATTR_STORAGE[skey][(element_name, name)] = []
NVDS_OBJ_ATTR_STORAGE[skey][(element_name, name)].append(
AttributeMeta(
Expand Down Expand Up @@ -61,42 +63,85 @@ def nvds_attr_meta_iterator(
def nvds_get_obj_attr_meta_list(
frame_meta: pyds.NvDsFrameMeta,
obj_meta: pyds.NvDsObjectMeta,
model_name: str,
element_name: str,
attr_name: str,
) -> Optional[List[AttributeMeta]]:
"""Returns specified object attribute values (multi-label case).
:param frame_meta: object parent frame.
:param obj_meta: object metadata.
:param model_name: element name that created this attribute.
:param element_name: element name that created this attribute.
:param attr_name: attribute name.
:return: List of AttributeMeta/None
"""
skey = nvds_get_obj_uid(frame_meta, obj_meta)
if skey not in NVDS_OBJ_ATTR_STORAGE:
return None
return NVDS_OBJ_ATTR_STORAGE[skey].get((model_name, attr_name), None)
return NVDS_OBJ_ATTR_STORAGE[skey].get((element_name, attr_name), None)


def nvds_get_obj_attr_meta(
frame_meta: pyds.NvDsFrameMeta,
obj_meta: pyds.NvDsObjectMeta,
model_name: str,
element_name: str,
attr_name: str,
) -> Optional[AttributeMeta]:
"""Returns the first value (the first and only except in the case of a
multi-label) for specified object attribute.
:param frame_meta: object parent frame.
:param obj_meta: object metadata.
:param model_name: element name that created this attribute.
:param element_name: element name that created this attribute.
:param attr_name: attribute name.
:return: AttributeMeta/None
"""
attrs = nvds_get_obj_attr_meta_list(frame_meta, obj_meta, model_name, attr_name)
attrs = nvds_get_obj_attr_meta_list(frame_meta, obj_meta, element_name, attr_name)
return attrs[0] if attrs else None


def nvds_replace_obj_attr_meta_list(
frame_meta: pyds.NvDsFrameMeta,
obj_meta: pyds.NvDsObjectMeta,
element_name: str,
attr_name: str,
value: List[AttributeMeta],
):
"""Replaces specified object attribute values.
:param frame_meta: object parent frame.
:param obj_meta: object metadata.
:param element_name: element name that created this attribute.
:param attr_name: attribute name.
:param value: new attribute value, list.
"""
skey = nvds_get_obj_uid(frame_meta, obj_meta)
if skey not in NVDS_OBJ_ATTR_STORAGE:
NVDS_OBJ_ATTR_STORAGE[skey] = {}
for attr in value:
assert attr.element_name == element_name
assert attr.name == attr_name
NVDS_OBJ_ATTR_STORAGE[skey][(element_name, attr_name)] = value


def nvds_remove_obj_attr_meta_list(
frame_meta: pyds.NvDsFrameMeta,
obj_meta: pyds.NvDsObjectMeta,
element_name: str,
attr_name: str,
):
"""Removes specified object attribute values.
:param frame_meta: object parent frame.
:param obj_meta: object metadata.
:param element_name: element name that created this attribute.
:param attr_name: attribute name.
"""
skey = nvds_get_obj_uid(frame_meta, obj_meta)
if skey in NVDS_OBJ_ATTR_STORAGE:
if (element_name, attr_name) in NVDS_OBJ_ATTR_STORAGE[skey]:
del NVDS_OBJ_ATTR_STORAGE[skey][(element_name, attr_name)]


def nvds_remove_obj_attrs(
frame_meta: pyds.NvDsFrameMeta,
obj_meta: pyds.NvDsObjectMeta,
Expand Down

0 comments on commit b13a4f7

Please sign in to comment.