Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add trt args and trt infer #180

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion Classification/cnns/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,29 @@ def str2bool(v):
type=str2bool,
nargs='?',
const=True,
help='Whether to use use xla'
help='Whether to use xla'
)
parser.add_argument(
'--use_tensorrt',
type=str2bool,
nargs='?',
default=False,
help='Whether to use tensorrt'
)
parser.add_argument(
'--use_int8_online',
type=str2bool,
nargs='?',
default=False,
help='Whether to use online int8 calibration'
)
parser.add_argument(
'--use_int8_offline',
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的缩进与其他 add_argument 调用不一致,请保持统一。

type=str2bool,
nargs='?',
default=False, help='Whether to use online int8 calibration'
)

parser.add_argument(
'--channel_last',
type=str2bool,
Expand Down
15 changes: 10 additions & 5 deletions Classification/cnns/evaluate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@ DATA_ROOT=/dataset/ImageNet/ofrecord
# Set up model path, e.g. : vgg16_of_best_model_val_top1_721 alexnet_of_best_model_val_top1_54762
MODEL_LOAD_DIR="resnet_v15_of_best_model_val_top1_77318"

python3 of_cnn_evaluate.py \
--num_epochs=3 \
PYTHONPATH=/home/dev/files/repos/oneflow6/build-release/python_scripts \
python3 of_cnn_evaluate.py \
--num_epochs=1 \
--num_val_examples=50000 \
--model_load_dir=$MODEL_LOAD_DIR \
--val_data_dir=$DATA_ROOT/validation \
--val_data_part_num=256 \
--num_nodes=1 \
--node_ips='127.0.0.1' \
--gpu_num_per_node=4 \
--val_batch_size_per_device=64 \
--model="resnet50"
--gpu_num_per_node=1 \
--val_batch_size_per_device=10 \
--model="resnet50" \
--use_tensorrt=True \
--use_int8_online=False \
--use_int8_offline=True \
|& tee rn50-offline-int8.log
11 changes: 11 additions & 0 deletions Classification/cnns/job_function_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ def _default_config(args):
if args.use_xla:
config.use_xla_jit(True)
config.enable_fuse_add_to_output(True)
if args.use_tensorrt:
config.use_tensorrt(True)
if args.use_int8_online or args.use_int8_offline:
config.tensorrt.use_int8()
elif args.use_int8_online or args.use_int8_offline:
raise Exception("You can set use_int8_online or use_int8_offline only after use_tensorrt is True!")
if args.use_int8_offline:
int8_calibration_path = "./int8_calibration"
config.tensorrt.int8_calibration(int8_calibration_path)
if args.use_int8_offline and args.use_int8_online:
raise ValueError("You cannot use use_int8_offline or use_int8_online at the same time!")
return config


Expand Down
13 changes: 11 additions & 2 deletions Classification/cnns/of_cnn_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,22 @@ def main():
print("Restoring model from {}.".format(args.model_load_dir))
checkpoint = flow.train.CheckPoint()
checkpoint.load(args.model_load_dir)


if args.use_int8_online:
for j in range(10):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

校准数据集的 batch 数不要写死为 10,可以用一个参数来控制,例如 --calibration_batch_num。

InferenceNet().get()
flow.tensorrt.cache_int8_calibration()

warmup = 2
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warmup 的迭代次数也需要用一个参数来控制,不要写死。

for j in range(warmup):
InferenceNet().get()

metric = Metric(desc='validation', calculate_batches=num_val_steps, summary=summary,
save_summary_steps=num_val_steps, batch_size=val_batch_size)

for i in range(args.num_epochs):
for j in range(num_val_steps):
InferenceNet().async_get(metric.metric_cb(0, j))


if __name__ == "__main__":
main()
12 changes: 9 additions & 3 deletions Classification/cnns/of_cnn_train_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
epoch_size = math.ceil(args.num_examples / train_batch_size)
num_val_steps = int(args.num_val_examples / val_batch_size)


model_dict = {
"resnet50": resnet_model.resnet50,
"vgg": vgg_model.vgg16bn,
Expand Down Expand Up @@ -126,12 +125,19 @@ def main():
batch_size=train_batch_size, loss_key='loss')
for i in range(epoch_size):
TrainNet().async_get(metric.metric_cb(epoch, i))

# flow.tensorrt.write_int8_calibration("./int8_calibration") # mkdir int8_calibration
if args.val_data_dir:
metric = Metric(desc='validation', calculate_batches=num_val_steps, summary=summary,
save_summary_steps=num_val_steps, batch_size=val_batch_size)
for i in range(num_val_steps):
for i in range(val_batch_size):
# if i<=10:
# InferenceNet().get()
# if i ==10:
# flow.tensorrt.cache_int8_calibration()
# else:
# InferenceNet().async_get(metric.metric_cb(epoch, i))
InferenceNet().async_get(metric.metric_cb(epoch, i))

snapshot.save('epoch_{}'.format(epoch))


Expand Down
10 changes: 5 additions & 5 deletions Classification/cnns/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ echo NUM_EPOCH=$NUM_EPOCH
if [ -n "$2" ]; then
DATA_ROOT=$2
else
DATA_ROOT=/data/imagenet/ofrecord
DATA_ROOT=/dataset/ImageNet/ofrecord
fi
echo DATA_ROOT=$DATA_ROOT

LOG_FOLDER=../logs
LOG_FOLDER=./logs
mkdir -p $LOG_FOLDER
LOGFILE=$LOG_FOLDER/resnet_training.log

Expand All @@ -26,13 +26,13 @@ python3 of_cnn_train_val.py \
--val_data_dir=$DATA_ROOT/validation \
--val_data_part_num=256 \
--num_nodes=1 \
--gpu_num_per_node=8 \
--gpu_num_per_node=2 \
--optimizer="sgd" \
--momentum=0.875 \
--label_smoothing=0.1 \
--learning_rate=1.024 \
--loss_print_every_n_iter=100 \
--batch_size_per_device=128 \
--loss_print_every_n_iter=10 \
--batch_size_per_device=64 \
--val_batch_size_per_device=50 \
--num_epoch=$NUM_EPOCH \
--model="resnet50" 2>&1 | tee ${LOGFILE}
Expand Down
6 changes: 3 additions & 3 deletions Classification/cnns/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def callback(outputs):
self.summary.scalar(self.desc + "_top_{}".format(self.top_k),
top_k_accuracy, epoch, step)

if self.save_summary:
if (step + 1) % self.save_summary_steps == 0:
self.summary.save()
# if self.save_summary:
# if (step + 1) % self.save_summary_steps == 0:
# self.summary.save()

return callback

Expand Down