Semantic Segmentation | PSPNet Source Code Analysis: Network Training

Posted by vincent1997 on 2019-05-28

Introduction

I spent a while working on a semantic segmentation project; now that I have some time, this is a good opportunity to summarize what I learned.

Code-wise, a semantic segmentation framework is much simpler than an object detection one, but it still involves many details. In this post I use PSPNet as an example to walk through the code of a segmentation framework. Once you understand one framework, the others all look much the same.

The project is from https://github.com/speedinghzl/pytorch-segmentation-toolbox

A very important part of the framework is evaluate.py, i.e. the testing stage. Since it is rather long, I will cover testing in a separate post; this one focuses on the training process.

Overall structure

pytorch-segmentation-toolbox
    |— dataset      dataset-related code
        |— list         dataset list files
        |— datasets.py  dataset loading
    |— libs         custom PyTorch ops, e.g. BN
    |— networks     network definitions
        |— deeplabv3.py
        |— pspnet.py
    |— utils        utility functions
        |— criterion.py loss computation
        |— encoding.py  balanced multi-GPU memory usage
        |— loss.py      OHEM hard example mining
        |— utils.py     colormap conversion
    |— evaluate.py  network testing
    |— run_local.sh training script
    |— train.py     network training

train.py

The main training script. Its main steps are:

  1. Parse the training arguments, typically with argparse so they can be passed in from a shell script.
  2. Train the network: define the model, load weights, run the forward and backward passes, save checkpoints, and so on.
  3. Visualize training, using TensorBoard to plot the loss curve.
import argparse

import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import pickle
import cv2
import torch.optim as optim
import scipy.misc
import torch.backends.cudnn as cudnn
import sys
import os
from tqdm import tqdm
import os.path as osp
from networks.pspnet import Res_Deeplab
from dataset.datasets import CSDataSet

import random
import timeit
import logging
from tensorboardX import SummaryWriter
from utils.utils import decode_labels, inv_preprocess, decode_predictions
from utils.criterion import CriterionDSN, CriterionOhemDSN
from utils.encoding import DataParallelModel, DataParallelCriterion

torch_ver = torch.__version__[:3]
if torch_ver == '0.3':
    from torch.autograd import Variable

start = timeit.default_timer()

#Since ImageNet-pretrained weights are used, the ImageNet channel means must be subtracted during preprocessing.
IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32)

#These hyperparameters can be overridden in the shell script.
BATCH_SIZE = 8
DATA_DIRECTORY = 'cityscapes'
DATA_LIST_PATH = './dataset/list/cityscapes/train.lst'
IGNORE_LABEL = 255
INPUT_SIZE = '769,769'
LEARNING_RATE = 1e-2
MOMENTUM = 0.9
NUM_CLASSES = 19
NUM_STEPS = 40000
POWER = 0.9
RANDOM_SEED = 1234
RESTORE_FROM = './dataset/MS_DeepLab_resnet_pretrained_init.pth'
SAVE_NUM_IMAGES = 2
SAVE_PRED_EVERY = 10000
SNAPSHOT_DIR = './snapshots/'
WEIGHT_DECAY = 0.0005

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def get_arguments():
    """Parse all the arguments provided from the CLI.
    
    Returns:
      A list of parsed arguments.
    """
    parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,       #Batch Size
                        help="Number of images sent to the network in one step.")
    parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY,     #資料集地址
                        help="Path to the directory containing the PASCAL VOC dataset.")
    parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH,    #資料集清單
                        help="Path to the file listing the images in the dataset.")
    parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,   #忽略類別(未使用)
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--input-size", type=str, default=INPUT_SIZE,       #輸入尺寸
                        help="Comma-separated string with height and width of images.")
    parser.add_argument("--is-training", action="store_true",               #是否訓練   若不傳入為false
                        help="Whether to updates the running means and variances during the training.")
    parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,   #學習率
                        help="Base learning rate for training with polynomial decay.")
    parser.add_argument("--momentum", type=float, default=MOMENTUM,         #動量係數,用於優化引數
                        help="Momentum component of the optimiser.")
    parser.add_argument("--not-restore-last", action="store_true",          #是否儲存最後一層(未使用)
                        help="Whether to not restore last (FC) layers.")
    parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,     #類別數
                        help="Number of classes to predict (including background).")
    parser.add_argument("--start-iters", type=int, default=0,               #起始iter數
                        help="Number of classes to predict (including background).")
    parser.add_argument("--num-steps", type=int, default=NUM_STEPS,         #訓練步數   
                        help="Number of training steps.")
    parser.add_argument("--power", type=float, default=POWER,               #power係數,用於更新學習率
                        help="Decay parameter to compute the learning rate.")
    parser.add_argument("--random-mirror", action="store_true",             #資料增強 翻轉
                        help="Whether to randomly mirror the inputs during the training.")
    parser.add_argument("--random-scale", action="store_true",              #資料增強 多尺度
                        help="Whether to randomly scale the inputs during the training.")
    parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,     #隨機種子
                        help="Random seed to have reproducible results.")
    parser.add_argument("--restore-from", type=str, default=RESTORE_FROM,   #模型斷點續跑
                        help="Where restore model parameters from.")
    parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, #儲存多少張圖片(未使用)
                        help="How many images to save.")
    parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, #每多少次儲存一次斷點
                        help="Save summaries and checkpoint every often.")
    parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,       #模型儲存位置
                        help="Where to save snapshots of the model.")
    parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,     #權重衰減係數,用於正則化
                        help="Regularisation parameter for L2-loss.")
    parser.add_argument("--gpu", type=str, default='None',                      #使用哪些GPU
                        help="choose gpu device.")
    parser.add_argument("--recurrence", type=int, default=1,                #迴圈次數(未使用)
                        help="choose the number of recurrence.")
    parser.add_argument("--ft", type=bool, default=False,                   #微調模型(未使用)
                        help="fine-tune the model with large input size.")

    parser.add_argument("--ohem", type=str2bool, default='False',           #難例挖掘
                        help="use hard negative mining")
    parser.add_argument("--ohem-thres", type=float, default=0.6,
                        help="choose the samples with correct probability underthe threshold.")
    parser.add_argument("--ohem-keep", type=int, default=200000,
                        help="choose the samples with correct probability underthe threshold.")
    return parser.parse_args()

args = get_arguments()  #parse the arguments

#poly learning rate policy
def lr_poly(base_lr, iter, max_iter, power):
    return base_lr*((1-float(iter)/max_iter)**(power))

#adjust the learning rate
def adjust_learning_rate(optimizer, i_iter):
    """Polynomially decay the learning rate with the iteration number."""
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    return lr
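#For intuition, with the defaults above (base_lr=1e-2, max_iter=40000, power=0.9):
#  iter = 0     -> lr = 1e-2
#  iter = 20000 -> lr = 1e-2 * 0.5**0.9 ≈ 5.36e-3
#  iter = 39999 -> lr ≈ 7e-7, i.e. the rate decays to roughly 0 by the end of training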

#set BN layers to eval mode
def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

#set the BN momentum
def set_bn_momentum(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1 or classname.find('InPlaceABN') != -1:
        m.momentum = 0.0003

#main training function
def main():
    """Create the model and start the training."""
    writer = SummaryWriter(args.snapshot_dir)   #SummaryWriter object for visualizing training in TensorBoard
    
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
    h, w = map(int, args.input_size.split(',')) #769, 769
    input_size = (h, w) #(769, 769)

    cudnn.enabled = True

    # Create network.
    deeplab = Res_Deeplab(num_classes=args.num_classes) #build the network
    print(deeplab)

    saved_state_dict = torch.load(args.restore_from)    #load the checkpoint; saved_state_dict['conv1.weight'] = {Tensor}
    new_params = deeplab.state_dict().copy()    #state dict mapping layer names to parameters; new_params['conv1.weight'] = {Tensor}
    for i in saved_state_dict:  #drop the fully connected layer from the pretrained weights
        #Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.')  #['conv1', 'weight', '2']
        # print i_parts
        # if not i_parts[1]=='layer5':
        if not i_parts[0]=='fc':
            new_params['.'.join(i_parts[0:])] = saved_state_dict[i]

    deeplab.load_state_dict(new_params) #load the filtered state dict into the model
    #deeplab.load_state_dict(torch.load(args.restore_from)) #if no filtering is needed

    model = DataParallelModel(deeplab)  #multi-GPU data parallelism
    model.train()   #training mode; evaluate.py uses model.eval() instead
    model.float()
    # model.apply(set_bn_momentum)
    model.cuda()    #puts the model on GPU 0 as the master GPU; a different device can also be specified
    #model = model.cuda(device_ids[0])

    if args.ohem:   #whether to use hard example mining
        criterion = CriterionOhemDSN(thresh=args.ohem_thres, min_kept=args.ohem_keep)
    else:
        criterion = CriterionDSN() #CriterionCrossEntropy()
    criterion = DataParallelCriterion(criterion)    #balances the load across multiple GPUs
    criterion.cuda()    #the criterion lives on the GPU as well

    cudnn.benchmark = True  #slightly speeds up training at no extra cost; almost always enabled

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    #data loading; see datasets.py for this part
    trainloader = data.DataLoader(CSDataSet(args.data_dir, args.data_list, max_iters=args.num_steps*args.batch_size, crop_size=input_size, 
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN), 
                    batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    #optimizer
    optimizer = optim.SGD([{'params': filter(lambda p: p.requires_grad, deeplab.parameters()), 'lr': args.learning_rate }], 
                lr=args.learning_rate, momentum=args.momentum,weight_decay=args.weight_decay)
    optimizer.zero_grad()   #clear residual gradients from the previous step

    interp = nn.Upsample(size=input_size, mode='bilinear', align_corners=True)  #(unused)

    for i_iter, batch in enumerate(trainloader):
        i_iter += args.start_iters  
        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.long().cuda()
        if torch_ver == "0.3":
            images = Variable(images)
            labels = Variable(labels)

        optimizer.zero_grad()   #clear residual gradients from the previous step
        lr = adjust_learning_rate(optimizer, i_iter)    #adjust the learning rate
        preds = model(images)   #[x, x_dsn]

        loss = criterion(preds, labels) #compute the loss
        loss.backward()     #backpropagate
        optimizer.step()    #update the parameters

        #plot the lr and loss curves in TensorBoard via the SummaryWriter defined above
        if i_iter % 100 == 0:
            writer.add_scalar('learning_rate', lr, i_iter)
            writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

        #optionally visualize intermediate results during training
        # if i_iter % 5000 == 0:
        #     images_inv = inv_preprocess(images, args.save_num_images, IMG_MEAN)
        #     labels_colors = decode_labels(labels, args.save_num_images, args.num_classes)
        #     if isinstance(preds, list):
        #         preds = preds[0]
        #     preds_colors = decode_predictions(preds, args.save_num_images, args.num_classes)
        #     for index, (img, lab) in enumerate(zip(images_inv, labels_colors)):
        #         writer.add_image('Images/'+str(index), img, i_iter)
        #         writer.add_image('Labels/'+str(index), lab, i_iter)
        #         writer.add_image('preds/'+str(index), preds_colors[index], i_iter)

        print('iter = {} of {} completed, loss = {}'.format(i_iter, args.num_steps, loss.data.cpu().numpy()))

        if i_iter >= args.num_steps-1:  #save the final model
            print('save model ...')
            torch.save(deeplab.state_dict(),osp.join(args.snapshot_dir, 'CS_scenes_'+str(args.num_steps)+'.pth'))
            break

        if i_iter % args.save_pred_every == 0:  #save a checkpoint every save_pred_every steps
            print('taking snapshot ...')
            torch.save(deeplab.state_dict(),osp.join(args.snapshot_dir, 'CS_scenes_'+str(i_iter)+'.pth'))   #save only the learned parameters
            #torch.save(deeplab, PATH)  #save the whole model object and its state

    end = timeit.default_timer()
    print(end-start,'seconds')

if __name__ == '__main__':
    main()

datasets.py

In PyTorch, getting data into a model follows this sequence:

  1. Create a Dataset object, usually overriding the __len__ and __getitem__ methods. __len__ returns the dataset size, and __getitem__ supports indexing so that dataset[i] fetches the i-th sample.
  2. Create a DataLoader object, passing the Dataset in as an argument.
  3. Iterate over the DataLoader, feeding img and label into the model for training.

Here is a simple example:

dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epochs = 100
for epoch in range(num_epochs):
    for img, label in dataloader:
        pass  # forward/backward pass goes here

We also need to define the data preprocessing inside the Dataset object; here it consists of:

  1. random scaling by a factor between 0.7 and 2.1 (0.7 plus a random multiple of 0.1, up to 1.4)

  2. per-channel subtraction of the ImageNet mean
  3. a random 769x769 crop
  4. random horizontal mirroring

Note: to keep Image and Label aligned, the Label must undergo the corresponding preprocessing; the details are in the code, and a quick numeric check of the pad-and-crop step follows.
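As a quick sanity check of the pad-then-crop logic implemented below (hypothetical values, mirroring what CSDataSet.__getitem__ computes):

f_scale = 0.7   # a hypothetical draw from the random-scale step
crop_h = crop_w = 769
img_h, img_w = round(1024 * f_scale), round(2048 * f_scale)  # 717, 1434
pad_h = max(crop_h - img_h, 0)  # 769 - 717 = 52 rows of padding needed
pad_w = max(crop_w - img_w, 0)  # 1434 >= 769, so no horizontal padding
print(pad_h, pad_w)  # 52 0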

import os
import os.path as osp
import numpy as np
import random
import collections
import torch
import torchvision
import cv2
from torch.utils import data

#Cityscapes dataset loader
#crop_size = (769, 769); max_iters = num_steps * batch_size = 40000 * 8 = 320000
class CSDataSet(data.Dataset):
    def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255):
        self.root = root    #dataset root directory
        self.list_path = list_path  #dataset list file
        self.crop_h, self.crop_w = crop_size    #crop size
        self.scale = scale  #whether to apply random scaling
        self.ignore_label = ignore_label    #ignore label
        self.mean = mean    #per-channel dataset mean
        self.is_mirror = mirror #whether to mirror
        # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
        self.img_ids = [i_id.strip().split() for i_id in open(list_path)]   #list holding the path of each image and its label
        if max_iters is not None:   #during training, replicate the list to cover max_iters; e.g. max_iters=320000, len(trainset)=2975
            #the list must supply max_iters images in total (num_steps batches x batch_size images each)
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))    # 2975 * ceil(320000/2975) = 2975 * 108 = 321300
        self.files = [] #list holding the data records
        # for split in ["train", "trainval", "val"]:
        for item in self.img_ids:   #iterate over every training sample
            image_path, label_path = item   #image and label paths
            name = osp.splitext(osp.basename(label_path))[0]
            img_file = osp.join(self.root, image_path)
            label_file = osp.join(self.root, label_path)
            self.files.append({ #each list entry is a dict
                "img": img_file,
                "label": label_file,
                "name": name            #e.g. aachen_000000_000019_leftImg8bit
            })
        #mapping between the official label IDs and the 19 training classes
        self.id_to_trainid = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
                              3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
                              7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
                              14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
                              18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
                              28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
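        #e.g. official ID 7 (road) -> trainId 0, official ID 26 (car) -> trainId 13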
        print('{} images are loaded!'.format(len(self.img_ids)))

    def __len__(self):  #dataset length
        return len(self.files)  #321300

    #generate a randomly scaled sample and label
    def generate_scale_label(self, image, label):
        f_scale = 0.7 + random.randint(0, 14) / 10.0    # f_scale in [0.7, 2.1], step 0.1
        image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_LINEAR)
        label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation = cv2.INTER_NEAREST)
        return image, label

    #convert between the official label IDs and trainIds; e.g. official ID 33 corresponds to trainId 18 (the 19th class)
    def id2trainId(self, label, reverse=False):
        label_copy = label.copy()
        if reverse: #trainId2id
            for v, k in self.id_to_trainid.items():
                label_copy[label == k] = v
        else:   #id2trainId
            for k, v in self.id_to_trainid.items():
                label_copy[label == k] = v
        return label_copy

    #return one sample
    def __getitem__(self, index):
        datafiles = self.files[index]
        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)  #shape(1024,2048,3)
        label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE)    #shape(1024,2048)
        label = self.id2trainId(label)  #map the label image (IDs -1..33) to trainIds (0..18, ignore 255)
        size = image.shape  #[1024,2048,3]
        name = datafiles["name"]
        if self.scale:  #if multi-scale is enabled
            image, label = self.generate_scale_label(image, label)
        image = np.asarray(image, np.float32)
        image -= self.mean  #subtract the mean
        img_h, img_w = label.shape  #1024, 2048
        pad_h = max(self.crop_h - img_h, 0) #max(769-1024, 0)
        pad_w = max(self.crop_w - img_w, 0) #max(769-2048, 0)
        if pad_h > 0 or pad_w > 0:  #if the scaled image is smaller than crop_size, pad the borders
            img_pad = cv2.copyMakeBorder(image, 0, pad_h, 0, 
                pad_w, cv2.BORDER_CONSTANT, 
                value=(0.0, 0.0, 0.0))
            label_pad = cv2.copyMakeBorder(label, 0, pad_h, 0, 
                pad_w, cv2.BORDER_CONSTANT,
                value=(self.ignore_label,))
        else:
            img_pad, label_pad = image, label

        img_h, img_w = label_pad.shape  #1024, 2048
        h_off = random.randint(0, img_h - self.crop_h)  #random offset, e.g. 100
        w_off = random.randint(0, img_w - self.crop_w)  #e.g. 20
        # roi = cv2.Rect(w_off, h_off, self.crop_w, self.crop_h);
        image = np.asarray(img_pad[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32)   #crop a random patch ([100:100+769, 20:20+769])
        label = np.asarray(label_pad[h_off : h_off+self.crop_h, w_off : w_off+self.crop_w], np.float32) #([100:100+769, 20:20+769])
        #image = image[:, :, ::-1]  # change to BGR
        image = image.transpose((2, 0, 1))  #shape(3, 769, 769)
        if self.is_mirror:  #random horizontal mirroring
            flip = np.random.choice(2) * 2 - 1  #flip = 1 (keep) or -1 (reverse the width axis)
            image = image[:, :, ::flip]
            label = label[:, ::flip]

        return image.copy(), label.copy(), np.array(size), name #image.shape(3, 769, 769), label.shape(769, 769)

Above we defined the Dataset object CSDataSet; in train.py we then create the DataLoader object trainloader, passing CSDataSet in as an argument.

trainloader = data.DataLoader(CSDataSet(args.data_dir, args.data_list, max_iters=args.num_steps*args.batch_size, crop_size=input_size, 
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN), 
                    batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

To better understand these arguments, it is worth looking at the definition of the DataLoader class.

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): the dataset to load the data from
        batch_size (int, optional): how many samples per batch
        shuffle (bool, optional): reshuffle the data at the start of every epoch
        sampler (Sampler, optional): custom strategy for drawing samples from the dataset; if specified, shuffle must be False
        batch_sampler (Sampler, optional): like sampler, but returns the indices of a whole batch at a time; mutually exclusive with batch_size, shuffle, sampler, and drop_last
        num_workers (int, optional): how many subprocesses to use for data loading; 0 means all data is loaded in the main process (default: 0)
        collate_fn (callable, optional): merges a list of samples into a mini-batch
        pin_memory (bool, optional): if True, the data loader copies tensors into CUDA pinned memory before returning them

        drop_last (bool, optional): if True, drop the last incomplete batch; e.g. with batch_size 64 and 100 samples per epoch, the last 36 samples are discarded.
        If False (the default), the last batch is simply smaller.

        timeout (numeric, optional): if positive, the timeout for collecting a batch from the workers; a batch that takes longer is abandoned. Should always be non-negative (default: 0)
        worker_init_fn (callable, optional): per-worker initialization function. If not None, this will be called on each
        worker subprocess with the worker id (an int in [0, num_workers - 1]) as
        input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers  
        self.collate_fn = collate_fn    
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  # shuffles the indices
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

pspnet.py

To define a custom network in PyTorch, subclass nn.Module and override __init__(self) and forward, which define the network's components and its forward pass respectively. Here is a simple example.

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))
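
For reference, a quick hypothetical usage of this module; calling the instance invokes forward via __call__:

import torch

model = Model()
out = model(torch.randn(1, 1, 32, 32))  # calls forward() under the hood
print(out.shape)  # torch.Size([1, 20, 24, 24]): two 5x5 convs shrink 32 -> 28 -> 24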

Let's first look at PSPNet as presented in the paper. The architecture is very simple: a pyramid pooling module (PPM) is appended to a ResNet backbone.

[Figure: PSPNet architecture from the paper, a ResNet backbone followed by the pyramid pooling module]

In addition, PSPNet uses an auxiliary loss branch.

[Figure: the auxiliary (deep supervision) loss branch]

import torch.nn as nn
from torch.nn import functional as F
import math
import torch.utils.model_zoo as model_zoo
import torch
import numpy as np
from torch.autograd import Variable
affine_par = True
import functools

import sys, os

from libs import InPlaceABN, InPlaceABNSync
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

#ResNet Bottleneck block
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual      
        out = self.relu_inplace(out)

        return out

#pyramid pooling module (PPM)
class PSPModule(nn.Module):
    """
    Reference: 
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
    """
    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        self.stages = []
        self.stages = nn.ModuleList([self._make_stage(features, out_features, size) for size in sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=3, padding=1, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1)
            )

    def _make_stage(self, features, out_features, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        bn = InPlaceABNSync(out_features)
        return nn.Sequential(prior, conv, bn)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return bottle
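
# Shape walk-through (assuming the default PSPModule(2048, 512, sizes=(1, 2, 3, 6))
# applied to a (1, 2048, 97, 97) feature map):
#   each stage: adaptive-pool to (size, size) -> 1x1 conv to 512 channels
#               -> bilinear upsample back to 97x97, i.e. (1, 512, 97, 97)
#   concatenation with feats: 2048 + 4*512 = 4096 channels
#   bottleneck 3x3 conv: (1, 4096, 97, 97) -> (1, 512, 97, 97)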

#the full PSPNet network
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)
        #
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.relu = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1,1,1))

        
        self.head = nn.Sequential(PSPModule(2048, 512),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True))

        #auxiliary loss (DSN) branch
        self.dsn = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            InPlaceABNSync(512),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
            )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion,affine = affine_par))

        layers = []
        generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1
        layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid)))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))

        return nn.Sequential(*layers)

    def forward(self, x):   #(1,3,769,769)
        x = self.relu1(self.bn1(self.conv1(x))) #(1,64,385,385)
        x = self.relu2(self.bn2(self.conv2(x))) #(1,64,385,385)
        x = self.relu3(self.bn3(self.conv3(x))) #(1,128,385,385)
        x = self.maxpool(x) #(1,128,193,193)
        x = self.layer1(x)  #(1,256,97,97)
        x = self.layer2(x)  #(1,512,97,97)
        x = self.layer3(x)  #(1,1024,97,97)
        x_dsn = self.dsn(x) #(1,19,97,97)
        x = self.layer4(x)  #(1,2048,97,97)
        x = self.head(x)    #(1,19,97,97); upsampling back to the input size happens in the criterion
        return [x, x_dsn]

def Res_Deeplab(num_classes=21):
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    return model

PSPNet takes a 1x3x769x769 input, where 1 is the batch size, 3 the RGB channels, and 769 the crop size. It returns two outputs, both of shape 1x19x97x97: the main prediction and the auxiliary (DSN) prediction, where 19 is the number of classes and each position holds the per-class scores; both maps are upsampled back to 769x769 inside the criterion. (Note that softmax has not been applied yet, so the scores do not sum to 1.)
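
A minimal shape check along these lines can confirm this (a sketch; it assumes the compiled InPlaceABN ops and a GPU are available, since the network uses InPlaceABNSync):

import torch
from networks.pspnet import Res_Deeplab

model = Res_Deeplab(num_classes=19).cuda().eval()
dummy = torch.randn(1, 3, 769, 769).cuda()
x, x_dsn = model(dummy)   # forward() returns [x, x_dsn]
print(x.shape)            # expected (1, 19, 97, 97): the main head at 1/8 resolution
print(x_dsn.shape)        # expected (1, 19, 97, 97): the auxiliary head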

criterion.py

The loss for semantic segmentation is mainly cross-entropy. Since an auxiliary loss is used, the total loss consists of two parts:

\(total\_loss=\alpha \cdot loss_1+\beta \cdot loss_2\)

In this repository \(\alpha = 1\) and \(\beta = 0.4\). The file also defines the OHEM loss computation; see loss.py for the implementation.

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import numpy as np
from torch.nn import functional as F
from torch.autograd import Variable
from .loss import OhemCrossEntropy2d
import scipy.ndimage as nd

class CriterionDSN(nn.Module):
    '''
    DSN: we need to consider two sources of supervision for the model.
    '''
    def __init__(self, ignore_index=255, use_weight=True, reduce=True):
        super(CriterionDSN, self).__init__()
        self.ignore_index = ignore_index
        #cross-entropy loss that ignores label 255 and averages over the remaining pixels
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduce=reduce)
        if not reduce:
            print("disabled the reduce.")

    #criterion(preds, labels)
    def forward(self, preds, target):
        h, w = target.size(1), target.size(2)   #769, 769

        scale_pred = F.upsample(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
        loss1 = self.criterion(scale_pred, target)

        scale_pred = F.upsample(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
        loss2 = self.criterion(scale_pred, target)

        return loss1 + loss2*0.4

#with hard example mining (OHEM)
class CriterionOhemDSN(nn.Module):
    '''
    DSN: we need to consider two sources of supervision for the model.
    '''
    def __init__(self, ignore_index=255, thresh=0.7, min_kept=100000, use_weight=True, reduce=True):
        super(CriterionOhemDSN, self).__init__()
        self.ignore_index = ignore_index
        self.criterion1 = OhemCrossEntropy2d(ignore_index, thresh, min_kept)    #uses the OHEM-based computation
        self.criterion2 = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduce=reduce)

    def forward(self, preds, target):
        h, w = target.size(1), target.size(2)   #769, 769

        scale_pred = F.upsample(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
        loss1 = self.criterion1(scale_pred, target)

        scale_pred = F.upsample(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
        loss2 = self.criterion2(scale_pred, target)

        return loss1 + loss2*0.4

loss.py

OHEM aims to select hard examples to train the model and thereby boost performance. It has two hyperparameters: \(\theta\) and \(K\).

Hard examples are defined as pixels whose predicted probability (for their ground-truth class) is below \(\theta\), and at least \(K\) hard examples are kept per minibatch.


Concretely, PSPNet's output is passed through softmax and then filtered twice. The first filter keeps the valid region of the label (values other than 255); positions where the label is 255 never enter the loss. After the first filter, positions whose predicted probability exceeds the threshold (e.g. 0.7) are also set to 255 in the label. In the end, only the remaining region participates in the loss computation.
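Here is a compact PyTorch-only sketch of this two-stage filtering (hypothetical code, not the repo's implementation; it omits the min_kept-based threshold search):

import torch
import torch.nn.functional as F

def ohem_target(logits, target, thresh=0.7, ignore_label=255):
    """Relabel easy pixels as ignore_label so only hard ones contribute to the loss."""
    prob = F.softmax(logits, dim=1)                       # (N, C, H, W)
    # probability assigned to the ground-truth class at each pixel;
    # clamp the index so gather() stays in range where target == ignore_label
    idx = target.clamp(0, prob.size(1) - 1).unsqueeze(1)  # (N, 1, H, W)
    gt_prob = prob.gather(1, idx).squeeze(1)              # (N, H, W)
    valid = target != ignore_label                        # first filter: valid labels
    hard = gt_prob <= thresh                              # second filter: low-confidence pixels
    new_target = target.clone()
    new_target[~(valid & hard)] = ignore_label            # everything else is ignored
    return new_target

The actual OhemCrossEntropy2d below additionally raises the threshold via find_threshold when needed, so that at least min_kept pixels survive the filtering.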

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import scipy.ndimage as nd


class OhemCrossEntropy2d(nn.Module):

    def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000, factor=8):
        super(OhemCrossEntropy2d, self).__init__()
        self.ignore_label = ignore_label    #ignore label 255
        self.thresh = float(thresh)         #threshold 0.7
        # self.min_kept_ratio = float(min_kept_ratio)
        self.min_kept = int(min_kept)       #minimum number of pixels to keep
        self.factor = factor
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)

    #find the threshold
    #np_predict.shape(1, 19, 769, 769), np_target.shape(1, 769, 769)
    """
    The threshold is chosen mainly from min_kept, using the probability of the min_kept-th pixel,
    and the returned threshold can only be >= thresh.
    """
    def find_threshold(self, np_predict, np_target):
        # downsample 1/8
        factor = self.factor    #8
        predict = nd.zoom(np_predict, (1.0, 1.0, 1.0/factor, 1.0/factor), order=1)  #bilinear interpolation, shape(1, 19, 96, 96)
        target = nd.zoom(np_target, (1.0, 1.0/factor, 1.0/factor), order=0) #nearest-neighbor interpolation, shape(1, 96, 96)

        n, c, h, w = predict.shape  #1, 19, 96, 96
        min_kept = self.min_kept // (factor*factor) #int(self.min_kept_ratio * n * h * w)   #100000/64 = 1562

        input_label = target.ravel().astype(np.int32)   #flatten to 1-D, shape(9216, )
        input_prob = np.rollaxis(predict, 1).reshape((c, -1))   #roll axis 1 to axis 0, shape(19, 9216)

        valid_flag = input_label != self.ignore_label   #valid positions in the label (9216, )
        valid_inds = np.where(valid_flag)[0]    #(9013, )
        label = input_label[valid_flag] #valid labels (9013, )
        num_valid = valid_flag.sum()    #9013
        if min_kept >= num_valid:   #1562 >= 9013
            threshold = 1.0
        elif num_valid > 0: #9013 > 0
            prob = input_prob[:,valid_flag] #(19, 9013) #probabilities at the valid positions
            pred = prob[label, np.arange(len(label), dtype=np.int32)]   #probability of the ground-truth class at each valid pixel, shape(9013, )
            threshold = self.thresh     #0.7
            if min_kept > 0:    #1562>0
                k_th = min(len(pred), min_kept)-1   #min(9013, 1562)-1 = 1561
                new_array = np.partition(pred, k_th)    #partial sort: everything below index 1561 is <= new_array[1561], everything above is >=
                new_threshold = new_array[k_th]     #the pred value at index 1561, e.g. 0.03323581
                if new_threshold > self.thresh:     #the returned threshold can only be >= 0.7
                    threshold = new_threshold
        return threshold

    #generate the new labels
    #predict.shape(1, 19, 769, 769), target.shape(1, 769, 769)
    """
    Main idea:
        1. Find a suitable threshold, e.g. 0.7, via find_threshold.
        2. First filter: keep the positions whose label is not 255.
        3. Second filter: among those, keep the positions whose predicted probability is below the threshold.
        4. Rebuild the label: positions predicted above the threshold, and positions that were already 255, are all set to 255.
    """
    def generate_new_target(self, predict, target):
        np_predict = predict.data.cpu().numpy() #shape(1, 19, 769, 769)
        np_target = target.data.cpu().numpy()   #shape(1, 769, 769)
        n, c, h, w = np_predict.shape   #1, 19, 769, 769

        threshold = self.find_threshold(np_predict, np_target)  #find the threshold, e.g. 0.7

        input_label = np_target.ravel().astype(np.int32)    #shape(591361, )
        input_prob = np.rollaxis(np_predict, 1).reshape((c, -1))    #(19, 591361)

        valid_flag = input_label != self.ignore_label   #valid positions in the label (591361, )
        valid_inds = np.where(valid_flag)[0]    #(579029, )
        label = input_label[valid_flag] #first filter: labels that are not 255 (579029, )
        num_valid = valid_flag.sum()    #579029

        if num_valid > 0:
            prob = input_prob[:,valid_flag] #(19, 579029)
            pred = prob[label, np.arange(len(label), dtype=np.int32)]   #fancy indexing: probability of the ground-truth class at each valid pixel (579029, )
            kept_flag = pred <= threshold   #second filter: among the valid positions, keep those with pred <= threshold
            valid_inds = valid_inds[kept_flag]  #shape(579029, )
            print('Labels: {} {}'.format(len(valid_inds), threshold))

        label = input_label[valid_inds].copy()  #take the kept labels from the original
        input_label.fill(self.ignore_label) #shape(591361, ), every value set to 255
        input_label[valid_inds] = label #restore the labels that survived both filters; everything else stays 255
        new_target = torch.from_numpy(input_label.reshape(target.size())).long().cuda(target.get_device())  #shape(1, 769, 769)

        return new_target   #shape(1, 769, 769)


    def forward(self, predict, target, weight=None):
        """
            Args:
                predict:(n, c, h, w)    (1, 19, 769, 769)
                target:(n, h, w)        (1, 769, 769)
                weight (Tensor, optional): a manual rescaling weight given to each class.
                                           If given, has to be a Tensor of size "nclasses"
        """
        assert not target.requires_grad

        input_prob = F.softmax(predict, 1)  #softmax over the channel dimension to get probabilities
        target = self.generate_new_target(input_prob, target)   #generate the new labels
        return self.criterion(predict, target)

References

Zhao H, Shi J, Qi X, et al. Pyramid scene parsing network[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 2881-2890.

Yuan Y, Wang J. OCNet: Object context network for scene parsing[J]. arXiv preprint arXiv:1809.00916, 2018.
