OrinX编译resnet50

2 minute read

Published:

本文介绍 PyTorch 模型的定义、ONNX 导出以及 TensorRT 的编译运行等。

torch model 定义

# /**************************************************************
#  * @Copyright: 2021-2022 Copyright
#  * @Author: lix
#  * @Date: 2023-03-03 11:09:48
#  * @Last Modified by: lix
#  * @Last Modified time: 2023-03-03 11:09:48
#  **************************************************************/

# ref https://blog.csdn.net/hlld__/article/details/113755368

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
    """Convolution followed by batch normalization and an optional ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Stride of the convolution.
        padding: Zero padding added on each side. ``None`` selects
            ``kernel_size // 2``.
        groups: Number of groups forwarded to ``nn.Conv2d``.
        activation: When true, an in-place ReLU follows the batch norm;
            otherwise the normalized tensor is returned unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, groups=1, activation=True):
        super().__init__()
        if padding is None:
            padding = kernel_size // 2
        # The convolution is created without a bias term; it is always
        # followed by the batch norm layer below.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and the activation to ``x``."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The main branch reduces the channel count to ``out_channels // 4`` with a
    1x1 convolution, applies a 3x3 convolution (carrying the stride), and
    expands back to ``out_channels`` with a final 1x1 convolution that has no
    activation. The shortcut output is added and a ReLU is applied to the sum.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the block output.
        down_sample: When true, the 3x3 convolution and the projection
            shortcut use stride 2; otherwise stride 1.
        groups: Number of groups of the 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels, down_sample=False, groups=1):
        super(Bottleneck, self).__init__()
        stride = 2 if down_sample else 1
        mid_channels = out_channels // 4
        # The shortcut must produce the same shape as the main branch, so a
        # projection is required when the channel count changes OR when the
        # block is strided. Checking only the channel count would leave an
        # identity shortcut on a strided block with equal channels, and the
        # residual addition in forward() would fail on mismatched sizes.
        needs_projection = down_sample or in_channels != out_channels
        if needs_projection:
            self.shortcut = Conv(in_channels, out_channels, kernel_size=1,
                                 stride=stride, activation=False)
        else:
            self.shortcut = nn.Identity()
        self.conv = nn.Sequential(*[
            Conv(in_channels, mid_channels, kernel_size=1, stride=1),
            Conv(mid_channels, mid_channels, kernel_size=3, stride=stride, groups=groups),
            # No activation here: ReLU is applied after the residual addition.
            Conv(mid_channels, out_channels, kernel_size=1, stride=1, activation=False)
        ])

    def forward(self, x):
        """Return ``relu(conv(x) + shortcut(x))``."""
        y = self.conv(x) + self.shortcut(x)
        return F.relu(y, inplace=True)

class ResNet50(nn.Module):
    """ResNet-50 style classifier assembled from ``Conv`` and ``Bottleneck``.

    The network is a stem (7x7 stride-2 convolution and a stride-2 max pool),
    four stages of bottleneck blocks (3, 4, 6 and 3 blocks), and a head
    (average pool, flatten, linear layer).

    Args:
        num_classes: Number of output features of the final linear layer.
        sz: Input side length the head is built for. The average pool uses a
            kernel of ``sz // 32``, matching the three stride-2 stages after
            the stride-4 stem (e.g. 224 / 32 = 7).
    """

    def __init__(self, num_classes, sz=224):
        super().__init__()

        # Stem: both layers halve the spatial resolution.
        stem_conv = Conv(3, 64, kernel_size=7, stride=2)
        stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stem = nn.Sequential(stem_conv, stem_pool)

        # One row per stage: (in_channels, out_channels, down_sample, num_blocks).
        # Every stage with down_sample=True halves the resolution again.
        stage_specs = (
            (64, 256, False, 3),
            (256, 512, True, 4),
            (512, 1024, True, 6),
            (1024, 2048, True, 3),
        )
        stage_modules = [
            self._make_stage(c_in, c_out, down_sample=strided, num_blocks=count)
            for c_in, c_out, strided, count in stage_specs
        ]
        self.stages = nn.Sequential(*stage_modules)

        # Head: pool the final feature map with a window of sz // 32.
        pool_size = sz // 32
        self.head = nn.Sequential(
            nn.AvgPool2d(kernel_size=pool_size, stride=1, padding=0),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(2048, num_classes),
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, down_sample, num_blocks):
        """Build one stage of ``num_blocks`` bottleneck blocks.

        Only the first block changes the channel count and, when
        ``down_sample`` is true, the resolution; the remaining blocks map
        ``out_channels`` to ``out_channels`` without down-sampling.
        """
        first_block = Bottleneck(in_channels, out_channels, down_sample=down_sample)
        remaining = [
            Bottleneck(out_channels, out_channels, down_sample=False)
            for _ in range(num_blocks - 1)
        ]
        return nn.Sequential(first_block, *remaining)

    def forward(self, x):
        """Run ``x`` through the stem, the stages and the head, in that order."""
        features = self.stem(x)
        features = self.stages(features)
        return self.head(features)

if __name__ == "__main__":
    # Smoke test: push one random 1x3x224x224 batch through the network on
    # the GPU with the model in eval mode.
    inputs = torch.rand(1, 3, 224, 224).to("cuda")
    model = ResNet50(num_classes=1000, sz=224)
    model = model.to("cuda").eval()
    outputs = model(inputs)


ONNX 导出

# /**************************************************************
#  * @Copyright: 2021-2022 Copyright
#  * @Author: lix
#  * @Date: 2023-03-03 11:09:48
#  * @Last Modified by: lix
#  * @Last Modified time: 2023-03-03 11:09:48
#  **************************************************************/

import torch
from onnxsim import simplify
import onnx, os
from loguru import logger

def export_onnx(model, inputs, onnx_name):
    """Export ``model`` to ONNX, simplify it and save the result to ``onnx_name``.

    The model is switched to eval mode and exported with opset 13 under
    ``torch.no_grad()``. The exported file is loaded back, checked with
    ``onnx.checker``, simplified with onnx-simplifier, checked again and
    finally written over the raw export.

    The raw export stays on disk until the simplified model has passed every
    check, so a failure in checking or simplification does not leave the
    caller without any model file.

    Args:
        model: The ``torch.nn.Module`` to export.
        inputs: Example input(s) passed to ``torch.onnx.export``.
        onnx_name: Path of the ONNX file to write.

    Raises:
        AssertionError: If onnx-simplifier reports that the simplified model
            could not be validated.
    """
    model.eval()
    with torch.no_grad():
        torch.onnx.export(model, inputs, onnx_name, opset_version=13)

    # Use a separate name so the torch module argument is not shadowed.
    onnx_model = onnx.load(onnx_name)
    onnx.checker.check_model(onnx_model)

    simplified, ok = simplify(onnx_model)
    # Raised explicitly (instead of using an `assert` statement) so the check
    # still runs under `python -O`; the exception type is kept for callers.
    if not ok:
        raise AssertionError("Simplified ONNX model could not be validated")
    onnx.checker.check_model(simplified)

    # Overwrite the raw export only now that the simplified model is valid.
    onnx.save(simplified, onnx_name)
    logger.info("simplify onnx done !")


from resnet50 import ResNet50
# Export ResNet50 (1000 classes, 224x224 input) to "ResNet50.onnx" using a
# single random 1x3x224x224 batch on the GPU as the example input.
inputs = torch.rand(1, 3, 224, 224).to("cuda")
model = ResNet50(num_classes=1000, sz=224)
model = model.to("cuda").eval()
export_onnx(model, inputs, "ResNet50.onnx")

TensorRT 的编译运行

# Build a TensorRT engine from ResNet50.onnx with trtexec in FP16 mode and
# save it as ResNet50.plan. Input and output bindings are requested as
# fp16 in chw layout; the remaining flags enable verbose logging and the
# profile / layer-info dumps (see `trtexec --help` for each flag's semantics).
/usr/src/tensorrt/bin/trtexec --onnx=ResNet50.onnx --fp16 --verbose \
--saveEngine=ResNet50.plan \
--dumpProfile --dumpLayerInfo --separateProfileRun --useCudaGraph \
--inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --noDataTransfers --useSpinWait

build log