Skip to content

Commit

Permalink
TinyMS v0.3.0 adapts for MindSpore 1.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
hellowaywewe committed Dec 30, 2021
1 parent 206ed0e commit 13c9cf8
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 53 deletions.
4 changes: 2 additions & 2 deletions docs/en/source/design/concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ input = image_np.copy()
# 4.Detect the input image
detection_bbox_data = object_detection_predict(input, detector, is_training=False)

# 5.Draw the box for the input image and visualize in the opencv window using OpenCV.
# 5.Draw the box for the input image and view it using OpenCV.
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)
Expand Down Expand Up @@ -298,7 +298,7 @@ while True:
# 4.Detect the input frame image
detection_bbox_data = object_detection_predict(input, detector, is_training=False)

# 5.Draw the box for the input frame image and visualize in the opencv window using OpenCV.
# 5.Draw the box for the input frame image and view it using OpenCV.
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _write_version(file):
'scipy>=1.5.2,<1.8.0',
'matplotlib>=3.1.3',
'Pillow>=6.2.0',
'mindspore==1.3.0',
'mindspore==1.5.0',
'requests>=2.22.0',
'flask>=1.1.1',
'python-Levenshtein>=0.10.2',
Expand Down
4 changes: 2 additions & 2 deletions tests/st/app/object_detection/opencv_camera_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@

cap = cv2.VideoCapture(0)
while True:
# 3.Read the frame image from the camera
# 3.Read the frame image from the camera using OpenCV
ret, image_np = cap.read()
input = image_np.copy()

# 4.Detect the input frame image
detection_bbox_data = object_detection_predict(input, detector, is_training=False)

# 5.Draw the box for the input frame image and visualize in the opencv window.
# 5.Draw the box for the input frame image and view it using OpenCV.
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/st/app/object_detection/opencv_image_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def parse_args():
# 2.Generate the instance of ObjectDetector
detector = ObjectDetector(config=config)

# 3.Read the input image
# 3.Read the input image using OpenCV
image_np = cv2.imread(args_opt.img_path)
input = image_np.copy()

# 4.Detect the input image
detection_bbox_data = object_detection_predict(input, detector, is_training=False)

# 5.Draw the box for the input image and visualize in the opencv window.
# 5.Draw the box for the input image and view it using OpenCV.
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)
Expand Down
15 changes: 13 additions & 2 deletions tinyms/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
This module supports vision visualization with OpenCV, which helps
developers use pre-trained models to predict and quickly display the inference image.
Currently it only supports object detection models.
"""
from . import object_detection
from .object_detection.object_detector import object_detection_predict, ObjectDetector
from .object_detection.utils.view_util import visualize_boxes_on_image, draw_boxes_on_image, save_image
from .object_detection.utils.config_util import load_and_parse_config


object_detection_utils = ['visualize_boxes_on_image', 'draw_boxes_on_image', 'save_image', 'load_and_parse_config']

__all__ = []
__all__ = ['ObjectDetector', 'object_detection_predict']
__all__.extend(object_detection_utils)
__all__.extend(object_detection.__all__)
40 changes: 20 additions & 20 deletions tinyms/app/object_detection/object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@


class ObjectDetector():
r'''
r"""
ObjectDetector is a high-level class defined for building the model, preprocessing the input image,
predicting and postprocessing the prediction output data.
Args:
config (dict): model config parsed from the json file under the app/object_detection/configs dir.
'''
"""
def __init__(self, config=None):
self.config = config

def data_preprocess(self, input):
r'''
r"""
Preprocess the input image.
Args:
Expand All @@ -53,7 +53,7 @@ def data_preprocess(self, input):
Returns:
list, the preprocess image shape.
numpy.ndarray, the preprocess image result.
'''
"""
if not isinstance(input, np.ndarray):
err_msg = 'The input type should be numpy.ndarray, got {}.'.format(type(input))
raise TypeError(err_msg)
Expand All @@ -69,31 +69,31 @@ def data_preprocess(self, input):
return image_shape, transform_input

def convert2tensor(self, transform_input):
r'''
r"""
Convert the numpy data to the tensor format.
Args:
transform_input (numpy.ndarray): the preprocessed image.
transform_input (numpy.ndarray): the preprocessed image.
Returns:
Tensor, the converted image.
'''
"""
if not isinstance(transform_input, np.ndarray):
err_msg = 'The transform_input type should be numpy.ndarray, got {}.'.format(type(transform_input))
raise TypeError(err_msg)
input_tensor = ts.expand_dims(ts.array(list(transform_input)), 0)
return input_tensor

def model_build(self, is_training=False):
r'''
r"""
Build the object detection model to predict the image.
Args:
is_training (bool): default: False.
Returns:
model.Model, generated object detection model.
'''
"""
model_net = model_checker.get(self.config.get('model_net'))
if not model_net:
err_msg = 'Currently model_net only supports {}!'.format(str(list(model_checker.keys())))
Expand All @@ -109,17 +109,17 @@ def model_build(self, is_training=False):
return serve_model

def model_load_and_predict(self, serve_model, input_tensor):
r'''
r"""
Load the object detection model to predict the image.
Args:
serve_model (model.Model): object detection model.
input_tensor(Tensor): the converted input image
input_tensor (Tensor): the converted input image.
Returns:
model.Model, object detection model loaded the checkpoint file.
list, predictions output result.
'''
"""
ckpt_path = self.config.get('checkpoint_path')
if not ckpt_path:
err_msg = 'The ckpt_path {} can not be none.'.format(ckpt_path)
Expand All @@ -139,16 +139,16 @@ def model_load_and_predict(self, serve_model, input_tensor):
return serve_model, predictions_output

def data_postprocess(self, predictions_output, image_shape):
r'''
r"""
Postprocessing the predictions output data.
Args:
predictions_output (list): predictions output data.
image_shape(list): the shapr of the input image.
image_shape (list): the shape of the input image.
Returns:
dict, the postprocess result.
'''
dict, the postprocessing result.
"""
output_np = (ts.concatenate((predictions_output[0], predictions_output[1]), axis=-1).asnumpy())
transform_func = transform_checker.get(self.config.get('dataset'))
if not transform_func:
Expand All @@ -158,17 +158,17 @@ def data_postprocess(self, predictions_output, image_shape):


def object_detection_predict(input, object_detector, is_training=False):
r'''
r"""
An easy object detection model predicting method for beginning developers to use.
Args:
input (numpy.ndarray): the input image.
object_detector (ObjectDetector): the instance of the ObjectDetector class
object_detector (ObjectDetector): the instance of the ObjectDetector class.
is_training (bool): default: False.
Returns:
dict, the postprocess result.
'''
dict, the postprocessing result.
"""
if not isinstance(object_detector, ObjectDetector):
err_msg = 'The object_detector is not the instance of ObjectDetector'
raise TypeError(err_msg)
Expand Down
4 changes: 2 additions & 2 deletions tinyms/app/object_detection/utils/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def _download_ckeckpoint(checkpoint_url, sha256, checkpoint_path):


def load_and_parse_config(config_path):
r'''
r"""
Load and parse the json config file of the object detection model.
Args:
config_path (str): the config json file path.
Returns:
dict, the model configuration.
'''
"""
# Check if config_path existed
if not os.path.exists(config_path):
raise FileNotFoundError("The config file path {} does not exist!".format(config_path))
Expand Down
42 changes: 20 additions & 22 deletions tinyms/app/object_detection/utils/view_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@


def save_image(img, save_dir='./', img_name='no_name', img_format='jpg'):
r'''
r"""
Save the prediction image.
Args:
img (numpy.ndarray): the input image.
save_dir (str): the dir to save the prediction image.
img_name (str): the name of the prediction image.
img_format (str): the format of the prediction image.
'''
img_name (str): the name of the prediction image. Default: 'no_name'.
img_format (str): the format of the prediction image. Default: 'jpg'.
"""
if img_format.lower() not in IMG_FORMAT:
raise Exception("当前图片格式仅支持", IMG_FORMAT)
output_image = os.path.join(save_dir, '{}.{}'.format(img_name, img_format))
Expand All @@ -39,25 +39,23 @@ def save_image(img, save_dir='./', img_name='no_name', img_format='jpg'):
def draw_boxes_on_image(img, boxes, box_scores, box_classes, box_color=(0, 255, 0),
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True):
r'''
r"""
Draw the prediction box for the input image.
Args:
img (numpy.ndarray): the input image.
boxes (list): the box coordinates.
box_scores (int): the prediction score.
box_classes: the prediction category.
box_color (list): the box color.
box_thickness (int): box thickness.
text_font (Enum): text font.
font_scale (int): font scale.
text_color (list): text color.
font_size (int): font size.
box_color (list): the box color. Default: (0, 255, 0).
box_thickness (int): box thickness. Default: 3.
text_font (Enum): text font. Default: cv2.FONT_HERSHEY_PLAIN.
font_scale (int): font scale. Default: 3.
text_color (list): text color. Default: (0, 0, 255).
font_size (int): font size. Default: 3.
show_scores (bool): whether to show scores. Default: True.
Returns:
numpy.ndarray, the output image with the prediction box drawn on it.
'''
"""
x = int(boxes[0])
y = int(boxes[1])
w = int(boxes[2])
Expand All @@ -71,23 +69,23 @@ def draw_boxes_on_image(img, boxes, box_scores, box_classes, box_color=(0, 255,
def visualize_boxes_on_image(img, bbox_data, box_color=(0, 255, 0),
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True):
r'''
r"""
Visualize the prediction image.
Args:
img (numpy.ndarray): the input image.
bbox_data (dict): the predictions box data.
box_color (list): the box color.
box_thickness (int): box thickness.
text_font (Enum): text font.
font_scale (int): font scale.
text_color (list): text color.
font_size (int): font size.
box_color (list): the box color. Default: (0, 255, 0).
box_thickness (int): box thickness. Default: 3.
text_font (Enum): text font. Default: cv2.FONT_HERSHEY_PLAIN.
font_scale (int): font scale. Default: 3.
text_color (list): text color. Default: (0, 0, 255).
font_size (int): font size. Default: 3.
show_scores (bool): whether to show scores. Default: True.
Returns:
numpy.ndarray, the output image with the prediction boxes drawn on it.
'''
"""
bbox_num = len(bbox_data)
if bbox_num:
for i in range(bbox_num):
Expand Down

0 comments on commit 13c9cf8

Please sign in to comment.