OrinX 编译 ResNet50
Published:
本文介绍 torch model 的定义、ONNX 导出、TensorRT 的编译运行等。
torch model 定义
# /**************************************************************
# * @Copyright: 2021-2022 Copyright
# * @Author: lix
# * @Date: 2023-03-03 11:09:48
# * @Last Modified by: lix
# * @Last Modified time: 2023-03-03 11:09:48
# **************************************************************/
# ref https://blog.csdn.net/hlld__/article/details/113755368
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU, the basic building unit."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, groups=1, activation=True):
        super(Conv, self).__init__()
        # Default to "same"-style padding for odd kernel sizes.
        if padding is None:
            padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch-norm and activation in sequence."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    A 1x1 projection shortcut is used whenever the spatial size or the
    channel count changes; otherwise the identity is used.
    """

    def __init__(self, in_channels, out_channels, down_sample=False, groups=1):
        super(Bottleneck, self).__init__()
        stride = 2 if down_sample else 1
        mid_channels = out_channels // 4
        # Fix: the original projected only on a channel mismatch, which
        # breaks when down_sample=True and in_channels == out_channels
        # (the identity branch would have the wrong spatial size).
        # Project on either a stride or a channel change.
        if down_sample or in_channels != out_channels:
            self.shortcut = Conv(in_channels, out_channels, kernel_size=1,
                                 stride=stride, activation=False)
        else:
            self.shortcut = nn.Identity()
        self.conv = nn.Sequential(
            Conv(in_channels, mid_channels, kernel_size=1, stride=1),
            Conv(mid_channels, mid_channels, kernel_size=3, stride=stride,
                 groups=groups),
            Conv(mid_channels, out_channels, kernel_size=1, stride=1,
                 activation=False),
        )

    def forward(self, x):
        # Residual addition followed by the block's final ReLU.
        y = self.conv(x) + self.shortcut(x)
        return F.relu(y, inplace=True)
class ResNet50(nn.Module):
    """ResNet-50 classifier: stem -> 4 bottleneck stages -> pooled linear head.

    Args:
        num_classes: size of the final classification layer.
        sz: nominal input resolution; kept for backward compatibility.
            Pooling is now adaptive, so any input that survives the /32
            down-sampling works, not just sz x sz.
    """

    def __init__(self, num_classes, sz=224):
        super(ResNet50, self).__init__()
        # Stem: 7x7/2 conv then 3x3/2 max-pool -> overall /4.
        self.stem = nn.Sequential(
            Conv(3, 64, kernel_size=7, stride=2),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Classic 3/4/6/3 block counts; stages 2-4 halve the resolution -> /32.
        self.stages = nn.Sequential(
            self._make_stage(64, 256, down_sample=False, num_blocks=3),
            self._make_stage(256, 512, down_sample=True, num_blocks=4),
            self._make_stage(512, 1024, down_sample=True, num_blocks=6),
            self._make_stage(1024, 2048, down_sample=True, num_blocks=3),
        )
        # AdaptiveAvgPool2d(1) equals AvgPool2d(sz // 32) at the nominal
        # resolution but also generalizes to other input sizes.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(2048, num_classes),
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, down_sample, num_blocks):
        """Build one stage: a (possibly down-sampling) block + identity blocks."""
        layers = [Bottleneck(in_channels, out_channels, down_sample=down_sample)]
        for _ in range(1, num_blocks):
            layers.append(Bottleneck(out_channels, out_channels, down_sample=False))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        return self.head(self.stages(self.stem(x)))
if __name__ == "__main__":
    # Smoke test: fall back to CPU so the script also runs without a GPU
    # (the original unconditional .cuda() crashes on CPU-only machines).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = torch.rand((1, 3, 224, 224), device=device)
    model = ResNet50(num_classes=1000, sz=224).to(device).eval()
    with torch.no_grad():
        outputs = model(inputs)
ONNX 导出
# /**************************************************************
# * @Copyright: 2021-2022 Copyright
# * @Author: lix
# * @Date: 2023-03-03 11:09:48
# * @Last Modified by: lix
# * @Last Modified time: 2023-03-03 11:09:48
# **************************************************************/
import torch
from onnxsim import simplify
import onnx, os
from loguru import logger
def export_onnx(model, inputs, onnx_name):
    """Export *model* to ONNX, simplify it, and save the result.

    Args:
        model: torch module to export (switched to eval mode here).
        inputs: example input tensor(s) used for tracing.
        onnx_name: output path; the simplified model overwrites it.

    Raises:
        AssertionError: if the simplified model fails validation.
    """
    model.eval()
    with torch.no_grad():
        torch.onnx.export(model, inputs, onnx_name, opset_version=13)
    # Re-load for validation/simplification under a new name instead of
    # shadowing the `model` parameter as the original did.
    onnx_model = onnx.load(onnx_name)
    onnx.checker.check_model(onnx_model)
    model_simp, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
    onnx.checker.check_model(model_simp)
    # Only overwrite the exported file once simplification succeeded; the
    # original deleted it up front and lost the export on any failure.
    onnx.save(model_simp, onnx_name)
    logger.info("simplify onnx done !")
from resnet50 import ResNet50

# Export at the nominal 224x224 resolution; use the GPU when available so
# the script also works on CPU-only machines (the original required CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = torch.rand((1, 3, 224, 224), device=device)
model = ResNet50(num_classes=1000, sz=224).to(device).eval()
export_onnx(model, inputs, "ResNet50.onnx")
TensorRT 的编译运行
# Build and benchmark a TensorRT engine from the simplified ONNX model.
#   --fp16                    : allow FP16 kernels during engine building
#   --saveEngine              : path for the serialized engine (.plan)
#   --dumpProfile/--dumpLayerInfo/--separateProfileRun : per-layer timing report
#   --useCudaGraph            : capture inference into a CUDA graph
#   --inputIOFormats/--outputIOFormats=fp16:chw : FP16 CHW I/O bindings
#   --noDataTransfers --useSpinWait : steadier latency measurements
/usr/src/tensorrt/bin/trtexec --onnx=ResNet50.onnx --fp16 --verbose \
--saveEngine=ResNet50.plan \
--dumpProfile --dumpLayerInfo --separateProfileRun --useCudaGraph \
--inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --noDataTransfers --useSpinWait
