Skip to content

Commit

Permalink
fix v0.2.0 bug and update released note
Browse files Browse the repository at this point in the history
  • Loading branch information
hellowaywewe committed Jul 15, 2021
1 parent a1a14f6 commit cada54c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 6 deletions.
8 changes: 8 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# TinyMS Release Notes

## 0.2.1

Released 2021-07-15.

### Major Features and Improvements

* Fix `load_checkpoint` interface bug in TinyMS 0.2.0 hub module. [#96](https://github.com/tinyms-ai/tinyms/pull/96)

## 0.2.0

Released 2021-06-07.
Expand Down
11 changes: 11 additions & 0 deletions docker/tinyms/0.2.1-jupyter/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
ARG BASE_CONTAINER=jupyter/scipy-notebook:ubuntu-18.04
FROM $BASE_CONTAINER

LABEL MAINTAINER="TinyMS Authors"

# Set the default jupyter token with "tinyms"
RUN sh -c '/bin/echo -e "tinyms\ntinyms\n" | jupyter notebook password'

# Install TinyMS cpu whl package
RUN pip install --no-cache-dir numpy==1.17.5 tinyms==0.2.1 && \
fix-permissions "${CONDA_DIR}"
10 changes: 10 additions & 0 deletions docker/tinyms/0.2.1/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
ARG BASE_CONTAINER=swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu:1.2.0
FROM $BASE_CONTAINER

LABEL MAINTAINER="TinyMS Authors"

# Install base tools
RUN apt-get update

# Install TinyMS cpu whl package
RUN pip install --no-cache-dir tinyms==0.2.1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import setuptools

package_name = 'tinyms'
version_tag = '0.2.0'
version_tag = '0.2.1'
pwd = os.path.dirname(os.path.realpath(__file__))


Expand Down
20 changes: 15 additions & 5 deletions tinyms/hub/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def load(uid, pretrained=True, **kwargs):
return net


def load_checkpoint(uid, dst):
def load_checkpoint(uid, dst, pretrained=True, **kwargs):
'''
Load model checkpoint file from remote TinyMS Hub.
Expand All @@ -106,16 +106,26 @@ def load_checkpoint(uid, dst):
the official example: `tinyms/0.2/lenet5_v1_mnist`.
dst (str): Full path of filename where the checkpoint file
will be loaded, e.g. `/tmp/lenet5.ckpt`.
pretrained (bool): Specified if to load pretrained weight ckpt file. Default: True.
kwargs (dict, optional): Keyword arguments for network initialization.
Examples:
>>> from tinyms import hub
>>>
>>> hub.load_checkpoint('tinyms/0.2/lenet5_v1_mnist', '/tmp/lenet5.ckpt')
>>> hub.load_checkpoint('tinyms/0.2/lenet5_v1_mnist', '/tmp/lenet5.ckpt', class_num=10)
'''
uid_info = UidInfo(uid)
asset_path = _get_model_asset_path(uid_info)
weights = _load_weights(asset_path)
_save_checkpoint(weights, dst)
model_info = uid_info.model_name + '_' + uid_info.model_version
net_func = MODEL_HUB.get(model_info)
if net_func is None:
raise ValueError("Currently model_name only supports " + str(list(MODEL_HUB.keys())) + "!")

net = net_func(**kwargs)
if pretrained is True:
asset_path = _get_model_asset_path(uid_info)
ckpt_params = _load_weights(asset_path)
load_param_into_net(net, ckpt_params)
_save_checkpoint(net, dst)


def load_weights(uid):
Expand Down

0 comments on commit cada54c

Please sign in to comment.