30秒輕鬆實現TensorFlow物體檢測

pythontab發表於2018-03-14

Google釋出了新的TensorFlow物體檢測API,包含了預訓練模型,一個釋出模型的jupyter notebook,一些可用於使用自己資料集對模型進行重新訓練的有用指令碼。

使用該API可以快速的構建一些圖片中物體檢測的應用。這裡我們一步一步來看如何使用預訓練模型來檢測影像中的物體。

首先我們載入一些會使用的庫

import numpy as np 
import os 
import six.moves.urllib as urllib 
import sys 
import tarfile 
import tensorflow as tf 
import zipfile 
 
from collections import defaultdict 
from io import StringIO 
from matplotlib import pyplot as plt 
from PIL import Image


接下來進行環境設定

%matplotlib inline 
sys.path.append("..")

物體檢測載入

from utils import label_map_util  
from utils import visualization_utils as vis_util

準備模型


變數  任何使用export_inference_graph.py工具輸出的模型可以在這裡載入,只需簡單改變PATH_TO_CKPT指向一個新的.pb檔案。這裡我們使用“移動網SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' 
MODEL_FILE = MODEL_NAME + '.tar.gz' 
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' 
 
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') 
 
NUM_CLASSES = 90

下載模型

opener = urllib.request.URLopener() 
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) 
tar_file = tarfile.open(MODEL_FILE) 
for file in tar_file.getmembers(): 
  file_name = os.path.basename(file.name) 
  if 'frozen_inference_graph.pb' in file_name: 
    tar_file.extract(file, os.getcwd()) 
將(frozen)TensorFlow模型載入記憶體
detection_graph = tf.Graph() 
with detection_graph.as_default(): 
  od_graph_def = tf.GraphDef() 
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 
    serialized_graph = fid.read() 
    od_graph_def.ParseFromString(serialized_graph) 
    tf.import_graph_def(od_graph_def, name='')

載入標籤圖


標籤圖將索引對映到類名稱,當我們的卷積預測5時,我們知道它對應飛機。這裡我們使用內建函式,但是任何返回將整數對映到恰當字元標籤的字典都適用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 
category_index = label_map_util.create_category_index(categories)

輔助程式碼

def load_image_into_numpy_array(image): 
 (im_width, im_height) = image.size 
 return np.array(image.getdata()).reshape( (im_height, im_width, 3)).astype(np.uint8)

檢測

PATH_TO_TEST_IMAGES_DIR = 'test_images' 
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ] 
IMAGE_SIZE = (12, 8) 
with detection_graph.as_default(): 
 
 with tf.Session(graph=detection_graph) as sess: 
  for image_path in TEST_IMAGE_PATHS: 
   image = Image.open(image_path) 
   # 這個array在之後會被用來準備為圖片加上框和標籤 
   image_np = load_image_into_numpy_array(image) 
   # 擴充套件維度,應為模型期待: [1, None, None, 3] 
   image_np_expanded = np.expand_dims(image_np, axis=0) 
   image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 
   # 每個框代表一個物體被偵測到. 
   boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 
   # 每個分值代表偵測到物體的可信度. 
   scores = detection_graph.get_tensor_by_name('detection_scores:0') 
   classes = detection_graph.get_tensor_by_name('detection_classes:0') 
   num_detections = detection_graph.get_tensor_by_name('num_detections:0') 
   # 執行偵測任務. 
   (boxes, scores, classes, num_detections) = sess.run( 
     [boxes, scores, classes, num_detections], 
     feed_dict={image_tensor: image_np_expanded}) 
   # 圖形化. 
   vis_util.visualize_boxes_and_labels_on_image_array( 
     image_np, 
     np.squeeze(boxes), 
     np.squeeze(classes).astype(np.int32), 
     np.squeeze(scores), 
     category_index, 
     use_normalized_coordinates=True, 
     line_thickness=8) 
   plt.figure(figsize=IMAGE_SIZE) 
   plt.imshow(image_np)

在載入模型部分可以嘗試不同的偵測模型以比較速度和準確度,將你想偵測的圖片放入TEST_IMAGE_PATHS中執行即可。


相關文章