Fast-RCNN解析:訓練階段程式碼導讀
這一週開始接觸RCNN相關的技術,希望用它來進行物體定位方面的研究。現記錄一些學習心得,以備查詢。——jeremy@gz
關於Fast-RCNN的解析,我們將主要分為兩個部分來介紹,其中一個是訓練部分,這個部分非常重要,是我們需要重點講解的;另一個是測試部分,這個部分關係到具體的應用,所以也是必須要了解的。本篇博文中,我們先從訓練部分講起。
訓練階段流程
在官方文件中,訓練階段的啟動指令碼如下所示:
./tools/train_net.py --gpu 0 --solver models/VGG16/solver.prototxt \
--weights data/imagenet_models/VGG16.v2.caffemodel
從這段指令碼中,我們可以知道,訓練的入口函式就在train_net.py中,其位於fast-rcnn/tools/資料夾內,我們先來看看這個檔案。
if __name__ == '__main__':
args = parse_args()
print('Called with args:')
print(args)
if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)
print('Using config:')
pprint.pprint(cfg)
if not args.randomize:
# fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
caffe.set_random_seed(cfg.RNG_SEED)
# set up caffe
caffe.set_mode_gpu()
if args.gpu_id is not None:
caffe.set_device(args.gpu_id)
imdb = get_imdb(args.imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb.name)
roidb = get_training_roidb(imdb)
output_dir = get_output_dir(imdb, None)
print 'Output will be saved to `{:s}`'.format(output_dir)
train_net(args.solver, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)
從以上的code,我們可以看到,train_net.py的主要處理過程包括以下三個部分:
(1) 首先對啟動指令碼的輸入引數進行處理,是通過如下這個函式parse_args()進行處理的。
def parse_args():
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--gpu', dest='gpu_id',
help='GPU device id to use [0]', default=0, type=int)
parser.add_argument('--solver', dest='solver',
help='solver prototxt', default=None, type=str)
parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train',default=40000, type=int)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights', default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on',default='voc_2007_trainval', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',action='store_true')
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
args = parser.parse_args()
return args
從這個函式中,我們可以瞭解到,訓練指令碼的可選輸入引數包括:
- –gpu: 這個引數指定訓練使用的GPU裝置,我的電腦只有一枚GPU,預設情況下自動開啟,其gpu_id為0;
- –solver: 這個引數指定網路的優化方法,並在其solver的prototxt指向了定義網路結構的檔案(train.prototxt);
- –weights: 這個引數指定了finetune的初始引數,我的電腦GPU不怎麼高階,只能使用caffenet進行finetune;
- –imdb: 這個引數指定了訓練所需要的訓練資料,如果你需要訓練自己的資料,那麼這個引數是必須要指定的;
(2) 然後是根據輸入的引數(–imdb 引數後面指定的資料)來準備訓練樣本,這個步驟涉及到兩個函式:一個 imdb=get_imdb(args.imdb_name)
, 另一個是roidb=get_training_roidb(imdb)
。關於這兩個函式我們下部分會花大時間來解析,這裡先不談。
(3) 最後就是訓練函式:train_net(args.solver,roidb, output_dir, pretrained_model= args.pretrained_model, max_iters= args.max_iters)
而這個 train_net() 函式是從 fast_rcnn/lib/fast_rcnn 資料夾中的 train.py 中 import 進來的。那麼接下來,我們來看看這個train.py
這個函式主要由一個類SolverWrapper和兩個函式get_training_roidb()和train_net()組成。
首先,我們來看看train_net()函式:
def train_net(solver_prototxt, roidb, output_dir,
pretrained_model=None, max_iters=40000):
"""Train a Fast R-CNN network."""
sw = SolverWrapper(solver_prototxt, roidb, output_dir,
pretrained_model=pretrained_model)
print 'Solving...'
sw.train_model(max_iters)
print 'done solving'
可以發現,該函式是通過呼叫類SolverWrapper來實現其主要功能的,因此,我們跟進到類SolverWrapper的類建構函式中去:
def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
"""Initialize the SolverWrapper."""
self.output_dir = output_dir
print 'Computing bounding-box regression targets...'
self.bbox_means, self.bbox_stds = \
rdl_roidb.add_bbox_regression_targets(roidb)
print 'done'
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print ('Loading pretrained model '
'weights from {:s}').format(pretrained_model)
self.solver.net.copy_from(pretrained_model)
self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param)
self.solver.net.layers[0].set_roidb(roidb)
初始化完成後,就是要呼叫train_model函式來進行網路訓練,我們來看一下它的主體部分:
def train_model(self, max_iters):
"""Network training loop."""
last_snapshot_iter = -1
timer = Timer()
while self.solver.iter < max_iters:
# Make one SGD update
timer.tic()
self.solver.step(1)
timer.toc()
if self.solver.iter % (10 * self.solver_param.display) == 0:
print 'speed: {:.3f}s / iter'.format(timer.average_time)
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
self.snapshot()
if last_snapshot_iter != self.solver.iter:
self.snapshot()
到此為止,網路就可以開始訓練了。
訓練資料處理
不過,關於Fast-RCNN的重頭戲我們其實還沒開始——那就是如何準備訓練資料。
在上面介紹訓練的流程中,與此相關的函式是:imdb= get_imdb(args.imdb_name)
這個函式是從從lib/datasets/資料夾中的factory.py中import進來的,我們來看一下這個函式:
def get_imdb(name):
"""Get an imdb (image database) by name."""
if not __sets.has_key(name):
raise KeyError('Unknown dataset: {}'.format(name))
return __sets[name]()
這個函式很簡單,其實就是根據字典的key來取得訓練資料。
那麼這個字典是怎麼形成的呢?看下面:
inria_devkit_path = '/home/jeremy/jWork/frcn/fast-rcnn/data/INRIA/'
for split in ['train', 'test']:
name = '{}_{}'.format('inria', split)
__sets[name] = (lambda split=split: datasets.inria(split, inria_devkit_path))
它本質上是通過lib/datasets/資料夾下面的inria.py引入的。
所以,現在我們就得開始進入inria.py(這個函式需要我們自己編寫,可以參考pascal_voc.py編寫)。
首先,我們來看看類inria的建構函式:
def __init__(self, image_set, devkit_path):
datasets.imdb.__init__(self, image_set)
self._image_set = image_set
self._devkit_path = devkit_path
self._data_path = os.path.join(self._devkit_path, 'data')
self._classes = ('__background__', # always index 0
'1001')
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
self._image_ext = ['.jpg', '.png']
self._image_index = self._load_image_set_index()
# Default to roidb handler
self._roidb_handler = self.selective_search_roidb
# Specific config options
self.config = {'cleanup' : True,
'use_salt' : True,
'top_k' : 2000}
assert os.path.exists(self._devkit_path), \
'Devkit path does not exist: {}'.format(self._devkit_path)
assert os.path.exists(self._data_path), \
'Path does not exist: {}'.format(self._data_path)
這裡面最要注意的是要根據自己訓練的類別同步修改self._classes,我這裡面只有兩類。
類 inria 構造完成後,會呼叫函式 roidb,這個函式是從類 imdb 中繼承過來的,這個函式會呼叫 _roidb_handler 來處理,其中 _roidb_handler=self.selective_search_roidb,下面我們來看看這個函式:
def selective_search_roidb(self):
"""
Return the database of selective search regions of interest.
Ground-truth ROIs are also included.
This function loads/saves from/to a cache file to speed up future calls.
"""
cache_file = os.path.join(self.cache_path,
self.name + '_selective_search_roidb.pkl')
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} ss roidb loaded from {}'.format(self.name, cache_file)
return roidb
if self._image_set != 'test':
gt_roidb = self.gt_roidb()
ss_roidb = self._load_selective_search_roidb(gt_roidb)
roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
else:
roidb = self._load_selective_search_roidb(None)
print len(roidb)
with open(cache_file, 'wb') as fid:
cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote ss roidb to {}'.format(cache_file)
return roidb
這個函式在訓練階段會首先呼叫get_roidb()
函式:
def gt_roidb(self):
"""
Return the database of ground-truth regions of interest.
This function loads/saves from/to a cache file to speed up future calls.
"""
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
return roidb
gt_roidb = [self._load_inria_annotation(index)
for index in self.image_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)
return gt_roidb
如果存在cache_file,那麼get_roidb()就會直接從cache_file中讀取資訊;如果不存在cache_file,那麼會呼叫_load_inria_annotation()來取得標註資訊。_load_inria_annotation函式如下所示:
def _load_inria_annotation(self, index):
"""
Load image and bounding boxes info from txt files of INRIA Person.
"""
filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
print 'Loading: {}'.format(filename)
def get_data_from_tag(node, tag):
return node.getElementsByTagName(tag)[0].childNodes[0].data
with open(filename) as f:
data = minidom.parseString(f.read())
objs = data.getElementsByTagName('object')
num_objs = len(objs)
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
# Load object bounding boxes into a data frame.
for ix, obj in enumerate(objs):
# Make pixel indexes 0-based
x1 = float(get_data_from_tag(obj, 'xmin')) - 1
y1 = float(get_data_from_tag(obj, 'ymin')) - 1
x2 = float(get_data_from_tag(obj, 'xmax')) - 1
y2 = float(get_data_from_tag(obj, 'ymax')) - 1
# ---------------------------------------------
# add these lines to avoid the accertion error
if x1 < 0:
x1 = 0
if y1 < 0:
y1 = 0
# ----------------------------------------------
cls = self._class_to_ind[
str(get_data_from_tag(obj, "name")).lower().strip()]
boxes[ix, :] = [x1, y1, x2, y2]
gt_classes[ix] = cls
overlaps[ix, cls] = 1.0
overlaps = scipy.sparse.csr_matrix(overlaps)
return {'boxes' : boxes,
'gt_classes': gt_classes,
'gt_overlaps' : overlaps,
'flipped' : False}
當處理完標註的資料後,接下來就要載入SS階段獲得的資料,通過如下函式完成:
def _load_selective_search_roidb(self, gt_roidb):
filename = os.path.abspath(os.path.join(self._devkit_path,
self.name + '.mat'))
assert os.path.exists(filename), \
'Selective search data not found at: {}'.format(filename)
raw_data = sio.loadmat(filename)['boxes'].ravel()
box_list = []
for i in xrange(raw_data.shape[0]):
#這個地方需要注意,如果在SS中你已經變換了box的值,那麼就不需要再改變box值的位置了
#box_list.append(raw_data[i][:, (1, 0, 3, 2)] - 1)
box_list.append(raw_data[i][:, (1, 0, 3, 2)])
return self.create_roidb_from_box_list(box_list, gt_roidb)
有一點需要注意的是,ss中獲得的box的值,和fast-rcnn中認為的box值有點差別,那就是你需要交換box的x和y座標。
未完待續……
本文地址:http://blog.csdn.net/linj_m/article/details/48930179
更多資源請關注 部落格:LinJM-機器視覺 微博:林建民-機器視覺
相關文章
- Python練手程式碼段(2020.11.11)Python
- fasttext訓練模型程式碼AST模型
- Spark SQL原始碼解析(四)Optimization和Physical Planning階段解析SparkSQL原始碼
- PyTorch 模型訓練實⽤教程(程式碼訓練步驟講解)PyTorch模型
- 人工智慧大模型的訓練階段和使用方式來分類人工智慧大模型
- [原始碼解析] 深度學習分散式訓練框架 horovod (13) --- 彈性訓練之 Driver原始碼深度學習分散式框架
- React 原始碼解析系列 - React 的 render 階段(三):completeUnitOfWorkReact原始碼
- React 原始碼解析系列 - React 的 render 階段(二):beginWorkReact原始碼
- 利用Python訓練手勢模型程式碼Python模型
- 語義分割丨PSPNet原始碼解析「網路訓練」原始碼
- 物件導向綜合訓練物件
- OpenPose訓練過程解析(2)
- 程式碼隨想錄演算法訓練營第28天 | 貪心進階演算法
- [原始碼解析] PyTorch 分散式之彈性訓練(3)---代理原始碼PyTorch分散式
- [原始碼解析] 深度學習分散式訓練框架 horovod (16) --- 彈性訓練之Worker生命週期原始碼深度學習分散式框架
- [原始碼解析] 深度學習分散式訓練框架 horovod (14) --- 彈性訓練發現節點 & State原始碼深度學習分散式框架
- Spark SQL原始碼解析(五)SparkPlan準備和執行階段SparkSQL原始碼
- React原始碼解析之Commit第一子階段「before mutation」React原始碼MIT
- Vue原始碼模板編譯階段----HTML解析器腦圖Vue原始碼編譯HTML
- [原始碼解析] PyTorch 分散式之彈性訓練(5)---Rendezvous 引擎原始碼PyTorch分散式
- [原始碼解析] 模型並行分散式訓練Megatron (5) --Pipedream Flush原始碼模型並行分散式
- [原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark原始碼深度學習分散式框架Spark
- [原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer原始碼深度學習分散式框架
- 程式碼隨想錄演算法訓練營第15天 | 二叉樹進階演算法二叉樹
- TWI工作指導的四階段法
- “安全即程式碼”目前發展到哪個階段?
- Siamese RPN 訓練網路結構解析
- [原始碼解析] PyTorch 分散式之彈性訓練(1) --- 總體思路原始碼PyTorch分散式
- [原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark原始碼深度學習分散式框架Spark
- [原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架原始碼深度學習分散式框架
- 深入解析Node.js Event Loop各階段Node.jsOOP
- 一段柯里化函式程式碼閱讀函式
- [原始碼解析] 深度學習流水線並行之PopeDream(1)--- Profile階段原始碼深度學習並行
- .Net7 GC標記階段程式碼的改變GC
- Python 高階程式設計:深入解析 CSV 檔案讀取Python程式設計
- Java基礎 --- 物件導向綜合訓練Java物件
- Python進階學習之程式碼閱讀Python
- 《The Rust Programming language》程式碼練習(part 2 進階部分)Rust
- [原始碼解析] 模型並行分散式訓練Megatron (2) --- 整體架構原始碼模型並行分散式架構