add vitpose_wholebody #9283

Open · wants to merge 6 commits into base: develop
2 changes: 1 addition & 1 deletion README_cn.md
@@ -75,7 +75,7 @@ PaddleDetection is an end-to-end object detection development suite based on PaddlePaddle
* 🎨 [**Rich models with one-click invocation**](docs/paddlex/quick_start.md): integrates the **55 models** involved in general object detection, small object detection, and instance segmentation into 3 model pipelines that can be invoked through a minimal **one-click Python API**, for quickly trying out model results. The same API also supports a total of **200+ models** covering image classification, image segmentation, intelligent text-image analysis, general OCR, time-series forecasting, and more, organized into 20+ single-function modules so developers can conveniently **combine models**.
* 🚀 [**Higher efficiency, lower barrier to entry**](docs/paddlex/overview.md): provides both a **unified command line** and a **graphical interface** for simple, efficient model use, combination, and customization. Supports multiple deployment options, including **high-performance deployment, service-oriented deployment, and edge deployment**. In addition, model development can **switch seamlessly** across mainstream hardware such as **NVIDIA GPUs, Kunlunxin, Ascend, Cambricon, and Hygon**.

-* Added the SOTA instance segmentation model [**Mask-RT-DETR**](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/cv_modules/instance_segmentation.md)
+* Added the SOTA instance segmentation model [**Mask-RT-DETR**](https://paddlepaddle.github.io/PaddleX/latest/module_usage/tutorials/cv_modules/instance_segmentation.html)

**🔥 Surpassing YOLOv8: PaddlePaddle releases RT-DETR, the most accurate real-time detector!**

10 changes: 5 additions & 5 deletions deploy/python/det_keypoint_unite_infer.py
@@ -24,14 +24,15 @@
from preprocess import decode_image
from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
-from visualize import visualize_pose
+from visualize import visualize_pose, visualize_pose_point131
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images

KEYPOINT_SUPPORT_MODELS = {
    'HigherHRNet': 'keypoint_bottomup',
-    'HRNet': 'keypoint_topdown'
+    'HRNet': 'keypoint_topdown',
+    'VitPose_TopDown_WholeBody': 'keypoint_topdown_wholebody'
}


@@ -178,7 +179,7 @@ def topdown_unite_predict_video(detector,

            keypoint_res['keypoint'][0][0] = smooth_keypoints.tolist()

-        im = visualize_pose(
+        im = visualize_pose_point131(
            frame,
            keypoint_res,
            visual_thresh=FLAGS.keypoint_threshold,
@@ -329,8 +330,7 @@ def main():
        enable_mkldnn=FLAGS.enable_mkldnn,
        use_dark=FLAGS.use_dark)
    keypoint_arch = topdown_keypoint_detector.pred_config.arch
-    assert KEYPOINT_SUPPORT_MODELS[
-        keypoint_arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.'
+    assert KEYPOINT_SUPPORT_MODELS[keypoint_arch] in (
+        'keypoint_topdown', 'keypoint_topdown_wholebody'
+    ), 'Detection-Keypoint unite inference only supports topdown models.'

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
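The rewritten assertion above is the only dispatch point: an architecture qualifies for detection-keypoint unite inference exactly when its entry in `KEYPOINT_SUPPORT_MODELS` is a top-down mode. A minimal standalone sketch of that check (mirroring the dict in this file; this is not code from the PR):

```python
KEYPOINT_SUPPORT_MODELS = {
    'HigherHRNet': 'keypoint_bottomup',
    'HRNet': 'keypoint_topdown',
    'VitPose_TopDown_WholeBody': 'keypoint_topdown_wholebody',
}


def supports_unite_inference(arch):
    # Plain and wholebody top-down modes both qualify; bottom-up does not.
    return KEYPOINT_SUPPORT_MODELS.get(arch, '').startswith('keypoint_topdown')


assert supports_unite_inference('VitPose_TopDown_WholeBody')
assert not supports_unite_inference('HigherHRNet')
```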
2 changes: 1 addition & 1 deletion deploy/python/infer.py
@@ -34,7 +34,7 @@
from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
-from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
+from keypoint_preprocess import EvalAffine, TopDownEvalAffine, TopDownAffineImage, expand_crop
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid
49 changes: 48 additions & 1 deletion deploy/python/keypoint_infer.py
@@ -42,10 +42,40 @@
# Global dictionary
KEYPOINT_SUPPORT_MODELS = {
    'HigherHRNet': 'keypoint_bottomup',
-    'HRNet': 'keypoint_topdown'
+    'HRNet': 'keypoint_topdown',
+    'VitPose_TopDown_WholeBody': 'keypoint_topdown_wholebody'
}


def _box2cs(image_size, box):
    """Encode a bbox (x, y, w, h) into (center, scale).

    Args:
        image_size (list): Model input size [w, h].
        box (list): Bounding box (x, y, w, h).

    Returns:
        tuple: A tuple containing center and scale.

        - np.ndarray[float32](2,): Center of the bbox (x, y).
        - np.ndarray[float32](2,): Scale of the bbox w & h.
    """
    x, y, w, h = box[:4]
    aspect_ratio = image_size[0] / image_size[1]
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)

    # Expand the shorter side so the box matches the input aspect ratio.
    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio

    # Pixel std is 200.0; pad the box by 25%.
    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
    scale = scale * 1.25

    return center, scale


class KeyPointDetector(Detector):
"""
Args:
@@ -137,6 +167,23 @@ def postprocess(self, inputs, result):
            imshape = inputs['im_shape'][:, ::-1]
            center = np.round(imshape / 2.)
            scale = imshape / 200.
            keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark)
            kpts, scores = keypoint_postprocess(np_heatmap, center, scale)
            results['keypoint'] = kpts
            results['score'] = scores
            return results
        elif KEYPOINT_SUPPORT_MODELS[
                self.pred_config.arch] == 'keypoint_topdown_wholebody':
            results = {}
            imshape = inputs['im_shape'][:, ::-1]
            center = []
            scale = []
            for i in range(len(inputs['im_shape'])):
                # Model input size is (w, h) = image.shape[-1], image.shape[-2].
                input_wh = [np.shape(inputs['image'])[-1],
                            np.shape(inputs['image'])[-2]]
                tmp_center, tmp_scale = _box2cs(
                    input_wh,
                    [0, 0, inputs['im_shape'][i][1], inputs['im_shape'][i][0]])
                center.append(tmp_center)
                scale.append(tmp_scale)

            keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark)
            kpts, scores = keypoint_postprocess(np_heatmap, center, scale)
            results['keypoint'] = kpts
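As a sanity check on `_box2cs`, here is a small worked example (a sketch assuming the function is in scope as defined above). The numbers follow directly from the aspect-ratio expansion and the 200-pixel std with 25% padding:

```python
import numpy as np

# Model input (w, h) = (192, 256); full-image box for a 640x480 frame.
center, scale = _box2cs([192, 256], [0, 0, 640, 480])

# aspect_ratio = 192 / 256 = 0.75; since 640 > 0.75 * 480, the height is
# expanded to 640 / 0.75 = 853.33 so the box matches the input aspect ratio.
np.testing.assert_allclose(center, [320.0, 240.0])
np.testing.assert_allclose(scale, [4.0, 5.3333], rtol=1e-4)
# scale = (w / 200, h / 200) * 1.25 = (3.2, 4.2667) * 1.25
```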
77 changes: 77 additions & 0 deletions deploy/python/keypoint_preprocess.py
@@ -18,6 +18,83 @@
import numpy as np


def _box2cs(image_size, box):
    """Encode a bbox (x, y, w, h) into (center, scale).

    Args:
        image_size (list): Model input size [w, h].
        box (list): Bounding box (x, y, w, h).

    Returns:
        tuple: A tuple containing center and scale.

        - np.ndarray[float32](2,): Center of the bbox (x, y).
        - np.ndarray[float32](2,): Scale of the bbox w & h.
    """
    x, y, w, h = box[:4]
    aspect_ratio = image_size[0] / image_size[1]
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)

    # Expand the shorter side so the box matches the input aspect ratio.
    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio

    # Pixel std is 200.0; pad the box by 25%.
    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
    scale = scale * 1.25

    return center, scale


class TopDownAffineImage(object):
    """Apply an affine transform to the image and coords.

    Args:
        trainsize (list): [w, h], the standard size used to train.
        use_udp (bool): Whether to use Unbiased Data Processing.
        use_box2cs (bool): Whether to derive (center, scale) from the
            image box via _box2cs.

    Returns:
        tuple: The transformed image and the (unchanged) im_info dict.
    """

    def __init__(self, trainsize, use_udp=False, use_box2cs=True):
        self.trainsize = trainsize
        self.use_udp = use_udp
        self.use_box2cs = use_box2cs

    def __call__(self, records, im_info):
        if self.use_box2cs:
            center, scale = _box2cs(
                self.trainsize,
                [0, 0, im_info['im_shape'][1], im_info['im_shape'][0]])
        else:
            imshape = im_info['im_shape'][::-1]
            center = im_info['center'] if 'center' in im_info else imshape / 2.
            scale = im_info['scale'] if 'scale' in im_info else imshape

        image = records
        # Rotation metadata, if present, lives in im_info; records is the
        # image array itself in the deploy pipeline.
        rot = im_info.get('rotate', 0)
        if self.use_udp:
            trans = get_warp_matrix(
                rot, center * 2.0,
                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0],
                scale * 200.0)
        else:
            trans = get_affine_transform(center, scale * 200, rot,
                                         self.trainsize)
        image = cv2.warpAffine(
            image,
            trans, (int(self.trainsize[0]), int(self.trainsize[1])),
            flags=cv2.INTER_LINEAR)
        return image, im_info


class EvalAffine(object):
    def __init__(self, size, stride=64):
        super(EvalAffine, self).__init__()
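A usage sketch for the new preprocessing op (the file name `demo.jpg` is hypothetical; `im_info['im_shape']` follows the deploy pipeline's (h, w) convention):

```python
import cv2
import numpy as np
from keypoint_preprocess import TopDownAffineImage

img = cv2.imread('demo.jpg')  # HWC BGR image
im_info = {'im_shape': np.array(img.shape[:2], dtype=np.float32)}  # (h, w)

# trainsize is [w, h]; use_box2cs derives (center, scale) from the full image.
op = TopDownAffineImage(trainsize=[192, 256], use_box2cs=True)
warped, im_info = op(img, im_info)
print(warped.shape)  # (256, 192, 3), i.e. the (h, w) the model expects
```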
51 changes: 51 additions & 0 deletions deploy/python/visualize.py
@@ -20,6 +20,15 @@
import numpy as np
import PIL
from PIL import Image, ImageDraw, ImageFile
import json

from mmengine.structures import InstanceData
from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer

ImageFile.LOAD_TRUNCATED_IMAGES = True

def imagedraw_textsize_c(draw, text):
@@ -235,6 +244,48 @@ def get_color(idx):
    return color


def visualize_pose_point131(imgfile,
                            results,
                            visual_thresh=0.3,
                            save_name='pose.jpg',
                            save_dir='output',
                            returnimg=False,
                            ids=None):
    pose_local_visualizer = PoseLocalVisualizer(
        vis_backends=[{'type': 'LocalVisBackend'}],
        name='visualizer',
        radius=3,
        alpha=0.8,
        line_width=1)
    with open("deploy/python/dataset_meta.json", 'r') as f:
        meta_data = json.load(f)

    pose_local_visualizer.set_dataset_meta(meta_data, skeleton_style="mmpose")
    image = cv2.imread(imgfile) if isinstance(imgfile, str) else imgfile
    skeletons, score = results['keypoint']
    skeletons = np.array(skeletons)

    pred_instances = InstanceData()
    pred_instances.keypoints = skeletons

    pred_pose_data_sample = PoseDataSample()
    pred_pose_data_sample.pred_instances = pred_instances

    # Draw the 131-point skeleton on a blank canvas of the same size.
    blank_image = np.zeros(image.shape, dtype=np.uint8)
    pose_local_visualizer.add_datasample(
        'image',
        blank_image,
        data_sample=pred_pose_data_sample,
        draw_gt=False,
        draw_heatmap=False,
        draw_bbox=True,
        show_kpt_idx=False,
        skeleton_style='mmpose',
        show=False,
        wait_time=0,
        kpt_thr=visual_thresh)

    return pose_local_visualizer.get_image()

def visualize_pose(imgfile,
                   results,
                   visual_thresh=0.6,
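A usage sketch for the new 131-point visualizer. It assumes `mmengine`/`mmpose` are installed and `deploy/python/dataset_meta.json` exists; `keypoint_res` mimics the dict produced by `KeyPointDetector`, whose `'keypoint'` entry is `(skeletons, scores)` with `skeletons` shaped `[num_person, 131, 3]` as `(x, y, score)`; `demo.jpg` is a hypothetical input image:

```python
import os

import cv2
import numpy as np
from visualize import visualize_pose_point131

# One person, 131 wholebody keypoints, each (x, y, score); dummy values.
skeletons = np.random.rand(1, 131, 3).astype(np.float32)
skeletons[..., :2] *= 480  # spread the points over the frame
keypoint_res = {'keypoint': (skeletons.tolist(), [[0.9]])}

canvas = visualize_pose_point131('demo.jpg', keypoint_res, visual_thresh=0.3)
os.makedirs('output', exist_ok=True)
cv2.imwrite('output/pose_point131.jpg', canvas)
```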
18 changes: 9 additions & 9 deletions docs/paddlex/overview.md
@@ -9,7 +9,7 @@

## 1. Introduction to Low-Code Full-Pipeline Development

-The PaddlePaddle low-code development tool [PaddleX](https://github.com/PaddlePaddle/PaddleX/tree/release/3.0-beta1), built on PaddleDetection's advanced technology, supports **low-code full-pipeline** development for the object detection field. Low-code full-pipeline development enables simple and efficient model use, combination, and customization, which significantly **reduces the time spent on model development** and **lowers its difficulty**, greatly accelerating the application and adoption of models in industry. Highlights:
+The PaddlePaddle low-code development tool [PaddleX](https://github.com/PaddlePaddle/PaddleX), built on PaddleDetection's advanced technology, supports **low-code full-pipeline** development for the object detection field. Low-code full-pipeline development enables simple and efficient model use, combination, and customization, which significantly **reduces the time spent on model development** and **lowers its difficulty**, greatly accelerating the application and adoption of models in industry. Highlights:

* 🎨 **Rich models with one-click invocation**: integrates the **55 models** involved in general object detection, small object detection, and instance segmentation into 3 model pipelines that can be invoked through a minimal **one-click Python API**, for quickly trying out model results. The same API also supports a total of **200+ models** covering image classification, image segmentation, intelligent text-image analysis, general OCR, time-series forecasting, and more, organized into 20+ single-function modules so developers can conveniently **combine models**.

@@ -21,7 +21,7 @@

## 2. Supported Object Detection Capabilities

-All 3 object detection pipelines in PaddleX support local **quick inference**, and some pipelines offer an **online demo**, so you can quickly try the pretrained models of each pipeline. If you are satisfied with a pipeline's pretrained model, you can go straight to [high-performance deployment](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/pipeline_deploy/high_performance_deploy.md)/[service-oriented deployment](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/pipeline_deploy/service_deploy.md)/[edge deployment](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/pipeline_deploy/lite_deploy.md); if not, you can use the pipelines' **secondary development** capability to improve results. For the complete pipeline development workflow, see the [PaddleX Pipeline Usage Overview](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/pipeline_usage/pipeline_develop_guide.md) or the individual pipeline tutorials.
+All 3 object detection pipelines in PaddleX support local **quick inference**, and some pipelines offer an **online demo**, so you can quickly try the pretrained models of each pipeline. If you are satisfied with a pipeline's pretrained model, you can go straight to [high-performance inference](https://paddlepaddle.github.io/PaddleX/latest/pipeline_deploy/high_performance_inference.html)/[service-oriented deployment](https://paddlepaddle.github.io/PaddleX/latest/pipeline_deploy/service_deploy.html)/[edge deployment](https://paddlepaddle.github.io/PaddleX/latest/pipeline_deploy/edge_deploy.html); if not, you can use the pipelines' **secondary development** capability to improve results. For the complete pipeline development workflow, see the [PaddleX Pipeline Usage Overview](https://paddlepaddle.github.io/PaddleX/latest/pipeline_usage/pipeline_develop_guide.html) or the individual pipeline tutorials.

In addition, PaddleX provides developers with a full-pipeline development tool based on a [cloud-based graphical development interface](https://aistudio.baidu.com/pipeline/mine); see the tutorial [Zero-Barrier Development of Industrial-Grade AI Models](https://aistudio.baidu.com/practical/introduce/546656605663301) for details.

@@ -69,7 +69,7 @@ All 3 object detection pipelines in PaddleX support local **quick inference**
</tr>
</table>

-> ❗Note: the capabilities above all run on GPU/CPU. PaddleX can also perform quick inference and secondary development on mainstream hardware such as Kunlun, Ascend, Cambricon, and Hygon. The table below details pipeline support; for the specific lists of supported models, see [Model List (NPU)](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/support_list/model_list_npu.md) // [Model List (XPU)](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/support_list/model_list_xpu.md) // [Model List (MLU)](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/support_list/model_list_mlu.md) // [Model List (DCU)](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/support_list/model_list_dcu.md). We are also adapting more models and advancing high-performance and service-oriented deployment on mainstream hardware.
+> ❗Note: the capabilities above all run on GPU/CPU. PaddleX can also perform quick inference and secondary development on mainstream hardware such as Kunlun, Ascend, Cambricon, and Hygon. The table below details pipeline support; for the specific lists of supported models, see [Model List (NPU)](https://paddlepaddle.github.io/PaddleX/latest/support_list/model_list_npu.html) // [Model List (XPU)](https://paddlepaddle.github.io/PaddleX/latest/support_list/model_list_xpu.html) // [Model List (MLU)](https://paddlepaddle.github.io/PaddleX/latest/support_list/model_list_mlu.html) // [Model List (DCU)](https://paddlepaddle.github.io/PaddleX/latest/support_list/model_list_dcu.html). We are also adapting more models and advancing high-performance and service-oriented deployment on mainstream hardware.


**🚀 Support for Domestic Hardware**
@@ -102,16 +102,16 @@

## 3. Object Detection Pipelines: List and Tutorials

-- **General Object Detection Pipeline**: [Tutorial](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/pipeline_usage/tutorials/cv_pipelines/object_detection.md)
-- **General Instance Segmentation Pipeline**: [Tutorial](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/pipeline_usage/tutorials/cv_pipelines/instance_segmentation.md)
-- **Small Object Detection Pipeline**: [Tutorial](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/pipeline_usage/tutorials/cv_pipelines/small_object_detection.md)
+- **General Object Detection Pipeline**: [Tutorial](https://paddlepaddle.github.io/PaddleX/latest/pipeline_usage/tutorials/cv_pipelines/object_detection.html)
+- **General Instance Segmentation Pipeline**: [Tutorial](https://paddlepaddle.github.io/PaddleX/latest/pipeline_usage/tutorials/cv_pipelines/instance_segmentation.html)
+- **Small Object Detection Pipeline**: [Tutorial](https://paddlepaddle.github.io/PaddleX/latest/pipeline_usage/tutorials/cv_pipelines/small_object_detection.html)


<a name="4"></a>

## 4. Object Detection Single-Function Modules: List and Tutorials

-- **Object Detection Module**: [Tutorial](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/cv_modules/object_detection.md)
-- **Instance Segmentation Module**: [Tutorial](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/cv_modules/instance_segmentation.md)
-- **Small Object Detection Module**: [Tutorial](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/cv_modules/small_object_detection.md)
+- **Object Detection Module**: [Tutorial](https://paddlepaddle.github.io/PaddleX/latest/pipeline_usage/tutorials/cv_pipelines/object_detection.html)
+- **Instance Segmentation Module**: [Tutorial](https://paddlepaddle.github.io/PaddleX/latest/pipeline_usage/tutorials/cv_pipelines/instance_segmentation.html)
+- **Small Object Detection Module**: [Tutorial](https://paddlepaddle.github.io/PaddleX/latest/pipeline_usage/tutorials/cv_pipelines/small_object_detection.html)
