Python 图像预处理

1 minute read

Published:

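本文介绍如何用 Python（OpenCV + NumPy）在模型推理前对图像做预处理：按 ROI 裁剪、等比缩放、用灰色（114）填充到固定输入尺寸、HWC→CHW 转换并转为 float32，同时把中间结果保存下来便于核对。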


import os
import cv2
import numpy as np

class ONNXForward:
    """Letterbox-style preprocessing for an ONNX detector with a fixed input size.

    ``_pre_process`` runs this pipeline:
      1. crop the image to the first ROI box;
      2. resize it with one uniform scale so it fits the network input;
      3. paste it into the top-left corner of a canvas filled with 114;
      4. permute axes (HWC -> CHW by default), add a batch axis, cast to float32.

    Intermediate results can be dumped to disk (``_dump_debug``). The ``_py``
    filename suffix suggests they are meant for comparison with another
    implementation.
    """

    def __init__(self, roi=None):
        """Create a preprocessor.

        Args:
            roi: list of ``[x0, y0, x1, y1]`` boxes in source-image pixels. Only
                the first box is used by ``_pre_process``. Defaults to
                ``[[0, 96, 1280, 864]]``.
        """
        # Presumably [N, C, H, W]; only the last two entries (H, W) are read below.
        self.input_shape  = [4, 3, 768, 1280]
        # Not read by any method of this class.
        self.output_shape = [4, 5040, 26]
        # Build a fresh list per instance: a mutable default argument would be
        # one object shared (and mutable) across every ONNXForward instance.
        self.roi = [[0, 96, 1280, 864]] if roi is None else roi

    def get_ratio(self, img_data):
        """Return the uniform scale that fits ``img_data`` inside the network input.

        The smaller of ``net_h / img_h`` and ``net_w / img_w`` is used, so the
        resized image keeps its aspect ratio and fits in both dimensions.
        """
        return min(self.input_shape[-2] / img_data.shape[0], self.input_shape[-1] / img_data.shape[1])

    def _crop_roi(self, img_data):
        """Return the part of ``img_data`` inside ``self.roi[0]``.

        A box ``[x0, y0, x1, y1]`` keeps rows ``y0:y1`` and columns ``x0:x1``
        (NumPy slicing, end-exclusive). Example: ``[0, 170, 1920, 1322]``
        keeps rows 170..1322 and columns 0..1920.

        Raises:
            ValueError: if the crop is empty (e.g. the image is smaller than the
                ROI offset). Otherwise ``get_ratio`` would divide by zero.
        """
        box = self.roi[0]
        cropped = img_data[box[1]:box[3], box[0]:box[2]]
        if cropped.shape[0] == 0 or cropped.shape[1] == 0:
            raise ValueError(
                f"ROI {box} yields an empty crop for image of shape {img_data.shape}")
        return cropped

    def _letterbox(self, img_data):
        """Resize ``img_data`` to fit the network input and pad it to full size.

        Returns:
            ``(padded_img, resized_img, r)``:
              - ``padded_img`` is uint8 with spatial size ``input_shape[-2:]``;
                ``resized_img`` sits in its top-left corner and every other
                pixel is 114.
              - ``resized_img`` keeps the input's channel layout.
              - ``r`` is the scale returned by ``get_ratio``.
        """
        net_h, net_w = self.input_shape[-2], self.input_shape[-1]
        r = self.get_ratio(img_data)
        # Truncation keeps the original sizing. Clamping to 1 stops an extremely
        # thin crop from requesting a zero-sized resize.
        new_h = max(1, int(img_data.shape[0] * r))
        new_w = max(1, int(img_data.shape[1] * r))
        resized_img = cv2.resize(img_data, (new_w, new_h),
                                 interpolation=cv2.INTER_LINEAR).astype(np.uint8)
        # cv2.resize may return a 2-D array for single-channel (H, W, 1) input.
        # Restore the input's channel axes so the paste below lines up.
        channel_dims = tuple(img_data.shape[2:])
        resized_img = resized_img.reshape((new_h, new_w) + channel_dims)
        # The canvas matches the input's channel count (3 for BGR, 4 for BGRA,
        # none for 2-D grayscale) instead of assuming 3.
        padded_img = np.full((net_h, net_w) + channel_dims, 114, dtype=np.uint8)
        padded_img[:new_h, :new_w] = resized_img
        return padded_img, resized_img, r

    @staticmethod
    def _dump_debug(basename, resized_img, blob):
        """Write the resized image and the final blob to files derived from ``basename``.

        For ``x.jpeg`` this writes ``x_py.jpeg`` (resized image) and ``x_py.xml``
        (the blob, stored under the key ``py``). Paths are relative to the
        current working directory unless ``basename`` contains a directory.
        Nothing is written when ``basename`` is empty.
        """
        if not basename:
            return
        # splitext keeps the two outputs distinct for any extension. The old
        # str.replace('.jpeg', ...) left non-.jpeg names unchanged, so both
        # files were written to the same path.
        stem, ext = os.path.splitext(basename)
        cv2.imwrite(stem + '_py' + (ext or '.jpeg'), resized_img)
        fs = cv2.FileStorage(stem + '_py.xml', cv2.FileStorage_WRITE)
        try:
            fs.write('py', blob)
        finally:
            fs.release()

    def _pre_process(self, img_data, swap=(2, 0, 1), basename=""):
        """Turn a raw image into a network-ready blob.

        Args:
            img_data: image array of shape ``(H, W)`` or ``(H, W, C)``, e.g. from
                ``cv2.imread``. Assumed uint8 (TODO confirm for other callers).
            swap: axis permutation applied to the padded image. The default
                ``(2, 0, 1)`` turns HWC into CHW. A 2-D grayscale input gets a
                channel axis first when ``swap`` has three axes.
            basename: if non-empty, debug files derived from it are written
                (see ``_dump_debug``).

        Returns:
            ``(blob, r)``:
              - ``blob`` is a C-contiguous float32 array with a leading batch
                axis of size 1 (``(1, C, H, W)`` with the default ``swap``).
                Values stay in the uint8 range and the channel order is
                unchanged (no BGR->RGB swap).
              - ``r`` is the resize scale applied to the ROI crop.

        Raises:
            ValueError: if the ROI crop is empty.
        """
        print(f"img_data.shape:{img_data.shape}")  # e.g. (1440, 1920, 3)
        img_data = self._crop_roi(img_data)
        padded_img, resized_img, r = self._letterbox(img_data)
        if padded_img.ndim == 2 and len(swap) == 3:
            # Grayscale: a 3-axis permutation needs an explicit channel axis.
            padded_img = padded_img[:, :, np.newaxis]
        blob = padded_img.transpose(swap)[np.newaxis]
        blob = np.ascontiguousarray(blob, dtype=np.float32)
        self._dump_debug(basename, resized_img, blob)
        return blob, r
    

if __name__ == "__main__":
    img_root_dir = "./data/"
    # ROI box [x0, y0, x1, y1]: keep rows 170..1322 and columns 0..1920.
    od_handler = ONNXForward(roi=[[0, 170, 1920, 1322], ])

    # os.listdir returns entries in arbitrary order. Sort so files are
    # processed (and logged) in the same order on every run.
    img_files = sorted(os.listdir(img_root_dir))
    print("img_files num", len(img_files), ":  ", img_files)

    for one_file in img_files:
        img_path = os.path.join(img_root_dir, one_file)
        origin_img = cv2.imread(img_path)
        if origin_img is None:
            # Not readable as an image (or a directory): skip it.
            continue

        print("img_path ", img_path)
        # one_file is already a bare name, so it serves directly as the
        # basename from which _pre_process derives its debug output filenames.
        results, r = od_handler._pre_process(origin_img, basename=one_file)