ONNX 溢出检查
Published:
本文介绍如何对 ONNX 模型做溢出检查：统计图中各类算子节点，并检查每个初始化参数（权重/偏置）是否全为零、是否超过 FP16 最大值（65504）或 INT32 最大值。
import onnx
from collections import defaultdict
import numpy as np
from onnx import numpy_helper
# Largest finite value representable in IEEE 754 half precision (float16): 65504.0.
FP16_MAX = float(np.finfo(np.float16).max)
# Largest value representable as a signed 32-bit integer: 2**31 - 1.
INT32_MAX = int(np.iinfo(np.int32).max)
def collect_node_statistics(onnx_model_path):
    """Load an ONNX model and print its graph nodes grouped by operator type.

    For each distinct ``op_type``, in order of first appearance in
    ``model.graph.node``, prints the node count and the list of node names.

    Args:
        onnx_model_path: Model location, passed straight to ``onnx.load``.

    Returns:
        The model object returned by ``onnx.load``, so the caller can reuse it
        without loading the file a second time.
    """
    model = onnx.load(onnx_model_path)

    # op_type -> names of the nodes with that op_type (insertion-ordered).
    names_by_op = {}
    for graph_node in model.graph.node:
        names_by_op.setdefault(graph_node.op_type, []).append(graph_node.name)

    for op_type, op_names in names_by_op.items():
        print(f"Node Type : {op_type}, Count: {len(op_names)}")
        print(f"Node Names: {op_names}\n")

    return model
def _report_out_of_range(name, weights, low, high, label):
    """Print every element of ``weights`` lying outside the closed range [low, high].

    Args:
        name: Initializer name, used only in the warning message.
        weights: Real-valued (integer or floating) numpy array to scan.
        low: Inclusive lower bound of the allowed range.
        high: Inclusive upper bound of the allowed range.
        label: Human-readable name of the target type, e.g. ``"FP16"``.
    """
    # Two one-sided comparisons rather than np.abs(): abs() of the most
    # negative integer wraps around and would hide out-of-range values.
    offending = weights[(weights > high) | (weights < low)]
    if offending.size > 0:
        print(f"Warning: Parameter '{name}' has values outside {label} range [{low}, {high}] \n {offending.tolist()}!")


def check_zero_weights_bias(model):
    """Scan a model's initializers for all-zero tensors and FP16/INT32 overflow.

    For every initializer in ``model.graph.initializer`` a warning is printed when:

    * the tensor is non-empty and every element is zero;
    * any element lies outside the finite float16 range [-FP16_MAX, FP16_MAX]
      (it would become +/-inf if the tensor were cast to float16);
    * any element lies outside the signed 32-bit integer range
      [-INT32_MAX - 1, INT32_MAX].

    Non-numeric initializers (e.g. string tensors) are skipped. Bool and
    complex tensors get only the all-zero check, because the range checks need
    real-valued ordering.

    Args:
        model: A loaded ONNX model, e.g. the value returned by
            ``collect_node_statistics``.
    """
    int32_min = -INT32_MAX - 1  # most negative signed 32-bit integer
    for initializer in model.graph.initializer:
        name = initializer.name
        weights = numpy_helper.to_array(initializer)
        kind = weights.dtype.kind

        # b=bool, i=signed int, u=unsigned int, f=float, c=complex; anything
        # else (e.g. string/object tensors) cannot be checked numerically.
        if kind not in "biufc":
            continue

        # np.any([]) is False, so an explicit size check is needed to avoid a
        # spurious "all zeros" warning on size-0 tensors.
        if weights.size > 0 and not np.any(weights):
            print(f"Warning: Parameter '{name}' {weights.shape} is all zeros!")

        if kind not in "iuf":
            continue

        _report_out_of_range(name, weights, -FP16_MAX, FP16_MAX, "FP16")
        _report_out_of_range(name, weights, int32_min, INT32_MAX, "INT32")
if __name__ == "__main__":
    # Imported here so that importing this module stays side-effect free.
    import sys

    # Optional first CLI argument overrides the default model path:
    #     python <this_script> path/to/model.onnx
    onnx_model_path = sys.argv[1] if len(sys.argv) > 1 else "ResNet50.onnx"
    model = collect_node_statistics(onnx_model_path)
    check_zero_weights_bias(model)
