ONNX 模型溢出检查

less than 1 minute read

Published:

本文介绍如何对 ONNX 模型做溢出检查：统计模型中各类算子节点，并检查参数（initializer）是否全为零、是否超出 FP16 或 INT32 的表示范围。


import onnx
from collections import defaultdict
import numpy as np
from onnx import numpy_helper

# Largest finite value representable in IEEE 754 half precision (float16): 65504.0.
FP16_MAX = float(np.finfo(np.float16).max)
# Largest value of a signed 32-bit integer: 2**31 - 1 = 2147483647.
INT32_MAX = int(np.iinfo(np.int32).max)

def collect_node_statistics(onnx_model_path):
    """Load an ONNX model and print its top-level nodes grouped by operator type.

    For each distinct ``op_type`` found in ``model.graph.node`` (in order of
    first appearance), print how many nodes use it and the list of their
    names. Nodes inside subgraphs are not visited.

    Args:
        onnx_model_path: Path to the ``.onnx`` file, passed to ``onnx.load``.

    Returns:
        The loaded model object, so callers can run further checks on it.
    """
    loaded = onnx.load(onnx_model_path)

    names_by_op = defaultdict(list)
    for graph_node in loaded.graph.node:
        names_by_op[graph_node.op_type].append(graph_node.name)

    for op, op_names in names_by_op.items():
        print(f"Node Type : {op}, Count: {len(op_names)}")
        print(f"Node Names: {op_names}\n")

    return loaded

def _report_out_of_range(name, values, low, high, label):
    """Print a warning listing the entries of ``values`` outside ``[low, high]``.

    Args:
        name: Initializer name, used only in the warning message.
        values: Numeric numpy array to check.
        low: Smallest allowed value (inclusive).
        high: Largest allowed value (inclusive).
        label: Human-readable name of the target range, e.g. ``"FP16"``.
    """
    out_of_range = values[(values > high) | (values < low)]
    if out_of_range.size > 0:
        print(
            f"Warning: Parameter '{name}' has values exceeding {label} range "
            f"[{low}, {high}]!\n {out_of_range.tolist()}"
        )


def check_zero_weights_bias(model):
    """Warn about initializers that are all zeros or outside FP16/INT32 range.

    For every tensor in ``model.graph.initializer`` a warning is printed when:
      * the tensor is non-empty and every element equals 0;
      * an element lies outside [-FP16_MAX, FP16_MAX], i.e. it would
        overflow if the model were converted to float16;
      * an element lies outside the signed 32-bit integer range.

    Args:
        model: A loaded ONNX model (as returned by ``onnx.load``).
    """
    int32_min = -INT32_MAX - 1
    for initializer in model.graph.initializer:
        name = initializer.name
        weights = numpy_helper.to_array(initializer)

        # np.all() of an empty array is True, so empty tensors would otherwise
        # be falsely reported as "all zeros".
        if weights.size > 0 and np.all(weights == 0):
            print(f"Warning: Parameter '{name}' {weights.shape} is all zeros!")

        # Range checks only make sense for numeric data: ordering comparisons
        # on string/object tensors raise, and booleans can never overflow.
        if weights.dtype.kind in "OSUb":
            continue

        # Check both signs: large negative values overflow just as well.
        _report_out_of_range(name, weights, -FP16_MAX, FP16_MAX, "FP16")
        _report_out_of_range(name, weights, int32_min, INT32_MAX, "INT32")

if __name__ == "__main__":
    # Run only when executed as a script, so importing this module to reuse
    # the helpers does not load a model as a side effect.
    import sys

    # An optional first command-line argument overrides the default model path.
    onnx_model_path = sys.argv[1] if len(sys.argv) > 1 else "ResNet50.onnx"
    model = collect_node_statistics(onnx_model_path)
    check_zero_weights_bias(model)