PyTorch-based data augmentation for object detection (tensor data-flow version)

Posted by 專注的阿熊 on 2021-02-03

Random scaling

# Shared imports for all transforms in this post.
import math
import random

import cv2
import numpy as np
import torch


class randomScale(object):
    def __call__(self, image, target):
        # Keep the height fixed and stretch the width by a factor in [0.8, 1.2]
        # to deform the image; box x coordinates are scaled accordingly.
        if random.random() < 0.3:
            image = np.array(image)
            image = np.transpose(image, (1, 2, 0))   # CHW -> HWC for OpenCV
            boxes = target["boxes"]
            scale = random.uniform(0.8, 1.2)
            height, width, c = image.shape
            image = cv2.resize(image, (int(width * scale), height))
            # Scale xmin and xmax of every box by the same factor; y stays unchanged.
            scale_tensor = torch.FloatTensor([[scale, 1, scale, 1]]).expand_as(boxes)
            boxes = boxes * scale_tensor
            image = np.transpose(image, (2, 0, 1))   # back to CHW
            image = torch.from_numpy(image)
            target["boxes"] = boxes
        return image, target
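All of the transforms in this post share the same (image, target) call signature, so they can be chained with a small Compose-style wrapper. The wrapper below is a sketch of my own, not part of the original code:

class Compose(object):
    """Chain (image, target) transforms of the kind defined in this post."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

# Example: build an augmentation pipeline for a detection dataset.
# train_transforms = Compose([randomScale(), randomBlur(), Cutout(), Random_crop()])
# image, target = train_transforms(image, target)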

Random blur

class randomBlur(object):
    def __call__(self, image, target):
        if random.random() < 0.3:
            image = np.array(image)
            image = np.transpose(image, (1, 2, 0))   # CHW -> HWC for OpenCV
            image = cv2.blur(image, (5, 5))          # 5x5 mean filter
            image = np.transpose(image, (2, 0, 1))   # back to CHW
            image = torch.from_numpy(image)
        return image, target
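If you prefer to stay entirely in tensor land, torchvision's GaussianBlur operates directly on CHW tensors and avoids the numpy round-trip. A minimal sketch of an alternative blur transform, assuming torchvision >= 0.8 (my own variant, not from the original post):

from torchvision import transforms as T

class randomGaussianBlur(object):
    """Tensor-only alternative to randomBlur; boxes are unaffected by blurring."""
    def __init__(self, p=0.3, kernel_size=5):
        self.p = p
        self.blur = T.GaussianBlur(kernel_size, sigma=(0.1, 2.0))

    def __call__(self, image, target):
        if random.random() < self.p:
            image = self.blur(image)
        return image, target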

Random erasing (occlusion)
Masking out patches improves robustness to occlusion; two classic algorithms are provided: Cutout and RandomErasing.

class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes=6, length=50):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, image, target):
        """
        Args:
            image (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes patches of size length x length cut out of it.
        """
        if random.random() < 0.3:
            img = image
            h = img.shape[1]
            w = img.shape[2]
            mask = np.ones((h, w), np.float32)
            for n in range(self.n_holes):
                # Pick a random centre and zero out a length x length square around it.
                y = np.random.randint(h)
                x = np.random.randint(w)
                y1 = np.clip(y - self.length // 2, 0, h)
                y2 = np.clip(y + self.length // 2, 0, h)
                x1 = np.clip(x - self.length // 2, 0, w)
                x2 = np.clip(x + self.length // 2, 0, w)
                mask[y1:y2, x1:x2] = 0.
            mask = torch.from_numpy(mask)
            mask = mask.expand_as(img)   # broadcast the 2-D mask over all channels
            img = img * mask
            image = img
        return image, target


class RandomErasing(object):

    '''
    Performs Random Erasing, from "Random Erasing Data Augmentation" by Zhong et al.
    The erase probability is fixed at 0.3 in __call__.
    sl: min erasing area (as a fraction of the image area)
    sh: max erasing area (as a fraction of the image area)
    r1: min aspect ratio of the erased patch
    mean: fill value for the erased region
    '''
    def __init__(self, sl=0.01, sh=0.25, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, image, target):
        if random.random() < 0.3:
            image = np.array(image)
            boxes = target["boxes"].numpy()
            # Areas of the ground-truth boxes, used to cap the erased patch size.
            area_box = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            for attempt in range(100):
                area = image.shape[1] * image.shape[2]
                target_area = random.uniform(self.sl, self.sh) * area
                aspect_ratio = random.uniform(self.r1, 1 / self.r1)
                # Give up if the patch would be more than 3x larger than every box.
                if (target_area > area_box * 3).all():
                    break
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < image.shape[2] and h < image.shape[1]:
                    x1 = random.randint(0, image.shape[1] - h)
                    y1 = random.randint(0, image.shape[2] - w)
                    if image.shape[0] == 3:
                        image[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                        image[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                        image[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                    else:
                        image[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    break   # erase one region and stop
            image = torch.from_numpy(image)
        return image, target

Random cropping

class Random_crop(object):
    def __call__(self, image, target):
        if random.random() < 0.3:
            boxes = target["boxes"]
            labels = target["labels"]
            image = np.array(image)
            image = np.transpose(image, (1, 2, 0))   # CHW -> HWC
            # Box centres, used to decide which boxes survive the crop.
            center = (boxes[:, 2:] + boxes[:, :2]) / 2
            height, width, c = image.shape
            # Crop window: 60%-100% of the original size, at a random position.
            h = random.uniform(0.6 * height, height)
            w = random.uniform(0.6 * width, width)
            x = random.uniform(0, width - w)
            y = random.uniform(0, height - h)
            x, y, h, w = int(x), int(y), int(h), int(w)
            center = center - torch.FloatTensor([[x, y]]).expand_as(center)
            mask1 = (center[:, 0] > 0) & (center[:, 0] < w)
            mask2 = (center[:, 1] > 0) & (center[:, 1] < h)
            mask = (mask1 & mask2).view(-1, 1)
            # Keep only boxes whose centre falls inside the crop window.
            boxes_in = boxes[mask.expand_as(boxes)].view(-1, 4)
            # if len(boxes_in) == 0:
            #     return image, target   # no box survives; skip the crop
            box_shift = torch.FloatTensor([[x, y, x, y]]).expand_as(boxes_in)
            boxes_in = boxes_in - box_shift
            boxes_in[:, 0] = boxes_in[:, 0].clamp_(min=0, max=w)
            boxes_in[:, 2] = boxes_in[:, 2].clamp_(min=0, max=w)
            boxes_in[:, 1] = boxes_in[:, 1].clamp_(min=0, max=h)
            boxes_in[:, 3] = boxes_in[:, 3].clamp_(min=0, max=h)
            labels_in = labels[mask.view(-1)]
            img_cropped = image[y:y + h, x:x + w, :]
            image = np.transpose(img_cropped, (2, 0, 1))   # back to CHW
            image = torch.from_numpy(image)
            target["labels"] = labels_in
            target["boxes"] = boxes_in
        return image, target
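A quick way to sanity-check the whole pipeline is to run it on dummy data using the Compose sketch from earlier. The snippet below is a minimal smoke test of my own; the shapes, boxes, and class ids are made up for illustration:

# Minimal smoke test on fake data (hypothetical shapes and values).
if __name__ == "__main__":
    image = torch.rand(3, 480, 640)                         # CHW float image
    target = {
        "boxes": torch.FloatTensor([[50, 60, 200, 220],      # xmin, ymin, xmax, ymax
                                    [300, 100, 450, 300]]),
        "labels": torch.LongTensor([1, 2]),
    }
    pipeline = Compose([randomScale(), randomBlur(), Cutout(),
                        RandomErasing(), Random_crop()])
    image, target = pipeline(image, target)
    print(image.shape, target["boxes"].shape, target["labels"])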


From the "ITPUB Blog", link: http://blog.itpub.net/69946337/viewspace-2755843/. Please credit the source when republishing; otherwise legal liability may be pursued.
