Skip to content

Commit

Permalink
support export blade model for Stgcn (#299)
Browse files Browse the repository at this point in the history
* support blade for stgcn and add unittest
  • Loading branch information
Cathy0908 authored Mar 6, 2023
1 parent c062b01 commit 5c33d9e
Show file tree
Hide file tree
Showing 10 changed files with 294 additions and 137 deletions.
3 changes: 2 additions & 1 deletion configs/pose/hrnet_w48_coco_256x192_udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@
evaluators=[dict(type='CoCoPoseTopDownEvaluator', **evaluator_args)])
]
checkpoint_sync_export = True
export = dict(use_jit=False)
export = dict(type='raw')
# export = dict(type='jit')
# export = dict(
# type='blade',
# blade_config=dict(
Expand Down
12 changes: 12 additions & 0 deletions configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,15 @@

log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
checkpoint_config = dict(interval=1)

export = dict(type='raw')
# export = dict(type='jit')
# export = dict(
# type='blade',
# blade_config=dict(
# enable_fp16=True,
# fp16_fallback_op_ratio=0.0,
# customize_op_black_list=[
# 'aten::select', 'aten::index', 'aten::slice', 'aten::view',
# 'aten::upsample', 'aten::clamp', 'aten::clone'
# ]))
174 changes: 78 additions & 96 deletions easycv/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from easycv.file import io
from easycv.framework.errors import NotImplementedError, ValueError
from easycv.models import (DINO, MOCO, SWAV, YOLOX, BEVFormer, Classification,
MoBY, TopDown, build_model)
MoBY, SkeletonGCN, TopDown, build_model)
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.misc import encode_str_to_tensor

Expand Down Expand Up @@ -68,6 +68,8 @@ def export(cfg, ckpt_path, filename, model=None, **kwargs):
_export_bevformer(model, cfg, filename, **kwargs)
elif isinstance(model, TopDown):
_export_pose_topdown(model, cfg, filename, **kwargs)
elif isinstance(model, SkeletonGCN):
_export_stgcn(model, cfg, filename, **kwargs)
elif hasattr(cfg, 'export') and getattr(cfg.export, 'use_jit', False):
export_jit_model(model, cfg, filename, **kwargs)
return
Expand Down Expand Up @@ -98,6 +100,63 @@ def _export_common(model, cfg, filename):
torch.save(checkpoint, ofile)


def _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=False):

def _trace_model():
with torch.no_grad():
if hasattr(model, 'forward_export'):
model.forward = model.forward_export
else:
model.forward = model.forward_test
trace_model = torch.jit.trace(
model,
copy.deepcopy(dummy_inputs),
strict=False,
check_trace=False)
return trace_model

export_type = cfg.export.get('type')
if export_type in ['jit', 'blade']:
if fp16:
with torch.cuda.amp.autocast():
trace_model = _trace_model()
else:
trace_model = _trace_model()
torch.jit.save(trace_model, filename + '.jit')
else:
raise NotImplementedError(f'Not support export type {export_type}!')

if export_type == 'jit':
return

blade_config = cfg.export.get('blade_config')

from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()

def _get_blade_model():
blade_model = blade_optimize(
speed_test_model=model,
model=trace_model,
inputs=copy.deepcopy(dummy_inputs),
blade_config=blade_config,
static_opt=False,
min_num_nodes=None,
check_inputs=False,
fp16=fp16)
return blade_model

# optimize model with blade
if fp16:
with torch.cuda.amp.autocast():
blade_model = _get_blade_model()
else:
blade_model = _get_blade_model()

with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(blade_model, ofile)


def _export_cls(model, cfg, filename):
""" export cls (cls & metric learning)model and preprocess config
Expand Down Expand Up @@ -540,7 +599,7 @@ def export_jit_model(model, cfg, filename):
torch.jit.save(model_jit, ofile)


def _export_bevformer(model, cfg, filename, fp16=False):
def _export_bevformer(model, cfg, filename, fp16=False, dummy_inputs=None):
if not cfg.adapt_jit:
raise ValueError(
'"cfg.adapt_jit" must be True when export jit trace or blade model.'
Expand Down Expand Up @@ -578,60 +637,10 @@ def _dummy_inputs():
}
return img, img_metas

dummy_inputs = _dummy_inputs()

def _trace_model():
with torch.no_grad():
model.forward = model.forward_export
trace_model = torch.jit.trace(
model, copy.deepcopy(dummy_inputs), check_trace=False)
return trace_model

export_type = cfg.export.get('type')
if export_type in ['jit', 'blade']:
if fp16:
with torch.cuda.amp.autocast():
trace_model = _trace_model()
else:
trace_model = _trace_model()
torch.jit.save(trace_model, filename + '.jit')
else:
raise NotImplementedError(f'Not support export type {export_type}!')

if export_type == 'jit':
return

blade_config = cfg.export.get('blade_config')

from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()

def _get_blade_model():
blade_model = blade_optimize(
speed_test_model=model,
model=trace_model,
inputs=copy.deepcopy(dummy_inputs),
blade_config=blade_config,
static_opt=False,
min_num_nodes=None, # 50
check_inputs=False,
fp16=fp16)
return blade_model

# optimize model with blade
if fp16:
with torch.cuda.amp.autocast():
blade_model = _get_blade_model()
else:
blade_model = _get_blade_model()
if dummy_inputs is None:
dummy_inputs = _dummy_inputs()

# save blade code and graph
# with io.open(filename + '.blade.code.py', 'w') as ofile:
# ofile.write(blade_model.forward.code)
# with io.open(filename + '.blade.graph.txt', 'w') as ofile:
# ofile.write(blade_model.forward.graph)
with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(blade_model, ofile)
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)


def _export_pose_topdown(model, cfg, filename, fp16=False, dummy_inputs=None):
Expand Down Expand Up @@ -672,53 +681,26 @@ def _dummy_inputs(cfg):
if dummy_inputs is None:
dummy_inputs = _dummy_inputs(cfg)

def _trace_model():
with torch.no_grad():
model.forward = model.forward_export
trace_model = torch.jit.trace(
model, copy.deepcopy(dummy_inputs), strict=False)
return trace_model

export_type = cfg.export.get('type')
if export_type in ['jit', 'blade']:
if fp16:
with torch.cuda.amp.autocast():
trace_model = _trace_model()
else:
trace_model = _trace_model()
torch.jit.save(trace_model, filename + '.jit')
else:
raise NotImplementedError(f'Not support export type {export_type}!')
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)

if export_type == 'jit':
return

blade_config = cfg.export.get('blade_config')
def _export_stgcn(model, cfg, filename, fp16=False, dummy_inputs=None):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = copy.deepcopy(model)
model.eval()
model.to(device)

from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()
if hasattr(cfg, 'export') and getattr(cfg.export, 'type', 'raw') == 'raw':
return _export_common(model, cfg, filename)

def _get_blade_model():
blade_model = blade_optimize(
speed_test_model=model,
model=trace_model,
inputs=copy.deepcopy(dummy_inputs),
blade_config=blade_config,
static_opt=False,
min_num_nodes=None,
check_inputs=False,
fp16=fp16)
return blade_model
def _dummy_inputs(device):
keypoints = torch.randn([1, 3, 300, 17, 2]).to(device)
return (keypoints, )

# optimize model with blade
if fp16:
with torch.cuda.amp.autocast():
blade_model = _get_blade_model()
else:
blade_model = _get_blade_model()
if dummy_inputs is None:
dummy_inputs = _dummy_inputs(device)

with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(blade_model, ofile)
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)


def replace_syncbn(backbone_cfg):
Expand Down
19 changes: 11 additions & 8 deletions easycv/predictors/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,16 +339,19 @@ def __init__(self,
nms_thresh=None,
test_conf=None,
input_processor_threads=8,
mode='BGR'):
mode='BGR',
model_type=None):
self.max_det = max_det
self.use_trt_efficientnms = use_trt_efficientnms

if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
else:
self.model_type = 'raw'
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']

if self.model_type == 'blade' or self.use_trt_efficientnms:
import torch_blade
Expand Down
25 changes: 16 additions & 9 deletions easycv/predictors/pose_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ def __call__(self, inputs):
return output


# TODO: Fix when multi people are detected in each sample,
# all the people results will be passed to the pose model,
# resulting in a dynamic batch_size, which is not supported by jit script model.
@PREDICTORS.register_module()
class PoseTopDownPredictor(PredictorV2):
"""Pose topdown predictor.
Expand Down Expand Up @@ -336,22 +339,26 @@ def __init__(self,
save_results=False,
save_path=None,
mode='BGR',
model_type=None,
*args,
**kwargs):
assert batch_size == 1, 'Only support batch_size=1 now!'
self.cat_id = cat_id
self.bbox_thr = bbox_thr
self.detection_predictor_config = detection_predictor_config

if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']

super(PoseTopDownPredictor, self).__init__(
model_path,
Expand Down
44 changes: 44 additions & 0 deletions easycv/predictors/video_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.models.builder import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
remove_adapt_for_mmlab)
from easycv.utils.registry import build_from_cfg
Expand Down Expand Up @@ -262,8 +263,22 @@ def __init__(self,
pipelines=None,
input_processor_threads=8,
mode='RGB',
model_type=None,
*args,
**kwargs):
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']

super(STGCNPredictor, self).__init__(
model_path,
config_file=config_file,
Expand Down Expand Up @@ -301,6 +316,35 @@ def __init__(self,

self.label_map = [i.strip() for i in class_list]

def _build_model(self):
if self.model_type != 'raw':
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
else:
model = super()._build_model()
return model

def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
model = self._build_model()
model.to(self.device)
model.eval()
if self.model_type == 'raw':
load_checkpoint(model, self.model_path, map_location='cpu')
return model

def model_forward(self, inputs):
if self.model_type == 'raw':
return super().model_forward(inputs)
else:
with torch.no_grad():
keypoint = inputs['keypoint'].to(self.device)
result = self.model(keypoint)

return result

def get_input_processor(self):
return STGCNInputProcessor(
self.cfg,
Expand Down
Loading

0 comments on commit 5c33d9e

Please sign in to comment.