ResNet模型导出是在Pytorch模型的生产环境下进行的,需提前安装好Pytorch环境。
Pytorch模型在编译前要经过torch.jit.trace
,trace后的模型才能使用tpu-nntc编译BModel。trace的方法和原理可参考torch.jit.trace参考文档。
如果使用tpu-mlir编译模型,则必须先将Pytorch模型导出为onnx模型。下面以导出1 batch的onnx模型为例进行演示:
import torch
import torch.onnx
from torchvision.models import resnet50
if __name__ == '__main__':
input = torch.randn(1, 3, 224, 224) # [1,3,224,224]分别对应[B,C,H,W]
model = resnet50() # 载入模型框架
model.load_state_dict(torch.load("xxx.pth")) # xxx.pth表示.pth文件, 这一步载入模型权重
model.eval() # 设置模型为推理模式
torch.onnx.export(model, input, "xxx.onnx", training=torch.onnx.TrainingMode.TRAINING) # xxx.onnx表示.onnx文件, 这一步导出为onnx模型, 并不做任何算子融合操作。