如何部署自己的SSD檢測模型到AndroidTFLite上
TensorFlow Object Detection API 上提供了使用SSD部署到TFLite執行上去的方法, 可是這套API封裝太死板, 如果你要自己實現了一套SSD的訓練演算法,應該怎麼才能部署到TFLite上呢?
首先,拋開後處理的部分,你的SSD模型(無論是VGG-SSD和Mobilenet-SSD), 你最終的模型的輸出是對class_predictions和bbox_predictions; 並且是encoded的
Encoding的方式:
class_predictions: M個Feature Layer, Feature Layer的大小(寬高)視網路結構而定; 每個Feature Layer有Num_Anchor_Depth_of_this_layer x Num_classes個channels
box_predictions: M個Feature Layer; 每個Feature Layer有Num_Anchor_Depth_of_this_layer x 4個channes 這4個channel分別代表dy,dx,h,w, 即bbox中心距離anchor中心座標的偏移量和寬高
注:通常,為了平衡loss之間的大小, 不會直接編碼dy,dx,w,h的原始值,而是dy/anchor_h*scale0, dx/anchor_w*scale0, log(h/anchor_h)*scale1, log(w/anchor_w)*scale1, 也就是偏移量的絕對值除anchor寬高得到相對值,然後再乘上一個scale, 經驗值 scale0取5,scale1取10; 對於h,w是對得到相對值後先取log再乘以scale, h/anchor_h的範圍在1附近, 取log後可以轉換到0附近;所以在解碼的時候需要做對應相反的變換;
在後面TFLite_Detection_PostProcess的Op實現裡就有這麼一段:
然後我們需要的是做的是decode出來對每個class的confidence和location的預測值
後處理
在Object Detection API的 export_tflite_ssd_graph_lib.py檔案中,你可以看到,它區別與直接freeze pb的操作就在於最後替換了後處理的部分;
frozen_graph_def = exporter.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=checkpoint_to_use,
output_node_names=`,`.join([
`raw_outputs/box_encodings`, `raw_outputs/class_predictions`,
`anchors`
]),
restore_op_name=`save/restore_all`,
filename_tensor_name=`save/Const:0`,
clear_devices=True,
output_graph=“,
initializer_nodes=“)
# Add new operation to do post processing in a custom op (TF Lite only)
if add_postprocessing_op:
transformed_graph_def = append_postprocessing_op(
frozen_graph_def, max_detections, max_classes_per_detection,
nms_score_threshold, nms_iou_threshold, num_classes, scale_values)
else:
# Return frozen without adding post-processing custom op
transformed_graph_def = frozen_graph_def
後處理的部分,其實看程式碼也很簡單,就是增加了一個叫TFLite_Detection_PostProcess的node,用於解碼和非極大抑制. 這個node的輸入就是上面提到的box_predictions和class_predictions, 還有anchors的編碼; 用這個node的目的只TFLite並不支援tf.contrib.image.non_max_surpression操作
Reshape過程:
這裡需要明確,TFLite_Detection_PostProcess 這個op對raw_outputs/box_encodings, raw_outputs/class_predictions, anchors的Shape是有一個定製要求的
raw_outputs/box_encodings.shape=[1, num_anchors,4]
raw_outputs/class_predictions.shape=[1, num_anchors,num_classes+1]
anchors.shape=[1,num_anchors,4]
這裡需要注意:1, 這三個都必須是3維的Tensor; 2.raw_outputs/class_predictions.shape的最後一個維度是包含background的classes, 也就是是num_classes+1; TFLite_Detection_PostProcess還有一個引數num_classes, 這個引數值是不包含background的, 所以也就導致TFLite_Detection_PostProcess的輸出的class index是從0計數的;
with tf.variable_scope(`raw_outputs`):
cls_pred = [tf.reshape(pred, [-1, num_classes]) for pred in cls_pred]
location_pred = [tf.reshape(pred, [-1, 4]) for pred in location_pred]
cls_pred = tf.concat(cls_pred, axis=0)
location_pred = tf.expand_dims(tf.concat(location_pred, axis=0),0, name=`box_encodings`)
cls_pred=tf.nn.softmax(cls_pred)
tf.identity(tf.expand_dims(cls_pred,0), name=`class_predictions`)
這段程式碼就是用來reshape成要求的輸入的, 需要注意的是對class_prediction需要做依次softmax或者sigmoid, 具體選擇哪種取決於你是否允許一個anchor對應多個類;
對於anchors, 這其實是一constant的值:
num_anchors = anchor_cy.get_shape().as_list()
with tf.Session() as sess:
y_out, x_out, h_out, w_out = sess.run([anchor_cy, anchor_cx, anchor_h, anchor_w])
encoded_anchors = tf.constant(
np.transpose(np.stack((y_out, x_out, h_out, w_out))),
dtype=tf.float32,
shape=[num_anchors[0], 4])
注意: 之前我使用tf.stack合成這個值的時候發現,TFLite只支援axis=0的時候的tf.stack, 否則就會轉換是吧
匯出pb
新增完後處理,既可以匯出一個帶有後處理功能的pb檔案了; 如果你不新增後處理,把它放在CPU上後續去做,其實也可以省去不少麻煩;
binary_graph = os.path.join(output_dir, `tflite_graph.pb`)
with tf.gfile.GFile(binary_graph, `wb`) as f:
f.write(transformed_graph_def.SerializeToString())
txt_graph = os.path.join(output_dir, `tflite_graph.pbtxt`)
with tf.gfile.GFile(txt_graph, `w`) as f:
f.write(str(transformed_graph_def))
注意: 匯出的pb如果包含後處理, 是沒辦法用正常的TF執行的,必須轉成tflite執行
匯出tflite
bazel run –config=opt tensorflow/contrib/lite/toco:toco —
–input_file=$OUTPUT_DIR/tflite_graph.pb
–output_file=$OUTPUT_DIR/detect.tflite
–input_shapes=1,300,300,3
–input_arrays=normalized_input_image_tensor
–output_arrays=`TFLite_Detection_PostProcess`,`TFLite_Detection_PostProcess:1`,`TFLite_Detection_PostProcess:2`,`TFLite_Detection_PostProcess:3`
–inference_type=QUANTIZED_UINT8
–mean_values=128
–std_values=128
–change_concat_input_ranges=false
–allow_custom_ops
or
bazel run -c opt tensorflow/lite/toco:toco —
–input_file=$OUTPUT_DIR/tflite_graph.pb
–output_file=$OUTPUT_DIR/detect.tflite
–input_shapes=1,300,300,3
–input_arrays=normalized_input_image_tensor
–output_arrays=`TFLite_Detection_PostProcess`,`TFLite_Detection_PostProcess:1`,`TFLite_Detection_PostProcess:2`,`TFLite_Detection_PostProcess:3`
–inference_type=FLOAT
–allow_custom_ops
匯出的過程中,可能遇到Converting unsupported operation: TFLite_Detection_PostProcess 這個提示, 正常如果是TF在1.10以上就忽略這個提示好了
然後你可以先用python的程式載入這個tflite去測試一下
注意: 這時候會發現一個問題, TFLite_Detection_PostProcess的NMS操作是忽略類標籤的,如果你設定max_classes_per_detection=1; 但是如果你設定成>1的值, 會發現它吧background的標籤也算進來了, 導致出來很多誤檢測的bbox;
部署Android
然後,你可以嘗試部署到Android上, 在不使用NNAPI的時候正常,但是如果是NNAPI就需要自己實現相關操作了,否則會crash掉
相關文章
- 目標檢測之SSD
- 目標檢測---教你利用yolov5訓練自己的目標檢測模型YOLO模型
- SSD 目標檢測 Keras 版Keras
- 【目標檢測從放棄到入門】SSD / RCNN / YOLO通俗講解CNNYOLO
- SSD物體檢測演算法詳解演算法
- SSD固態硬碟檢測工具:SSDReporter mac版硬碟Mac
- 將智慧合約部署到Rinkeby測試鏈上
- TF專案實戰(基於SSD目標檢測)——人臉檢測1
- 如何將專案部署到伺服器上伺服器
- LabVIEW+OpenVINO在CPU上部署新冠肺炎檢測模型實戰View模型
- 如何確定計算節點能不能檢測到儲存節點上的磁碟
- 單級式目標檢測方法概述:YOLO與SSDYOLO
- 在MCU端部署GRU模型實現鼾聲檢測:科技與健康管理的融合模型
- 如何將React專案,部署到Web伺服器的Tomcat 上ReactWeb伺服器Tomcat
- 如何擁有自己的專屬GPT-本地部署目前最強大模型llama3GPT大模型
- 如何把本地的Django專案部署到伺服器(親測)Django伺服器
- 基於函式計算快速搭建基於人工智慧的目標檢測系統(自己部署自己抽盲盒)函式人工智慧
- 如何避免在網頁抓取時被檢測到?網頁
- 如何把本地網站部署到雲伺服器上網站伺服器
- 從Densebox到Dubox:更快、效能更優、更易部署的anchor-free目標檢測
- 用SSD-Pytorch訓練自己的資料集PyTorch
- 在安卓上執行yolov8目標檢測模型(ncnn)安卓YOLO模型CNN
- 深度學習“吃雞外掛”——目標檢測 SSD 實驗深度學習
- 目標檢測入門系列手冊六:SSD訓練教程
- 深度學習與CV教程(13) | 目標檢測 (SSD,YOLO系列)深度學習YOLO
- 0-目標檢測模型的基礎模型
- 釋出自己的jar到Maven Repository公服上JARMaven
- 不同網站檢測到的ip不同網站
- R2CNN模型——用於文字目標檢測的模型CNN模型
- ResNet50-SSD300模型圖模型
- modelscope上的模型如何下載?模型
- 如何開始定製你自己的大型語言模型模型
- 如何更快的找到自己所需的模型關聯型別?模型型別
- SSIS 部署篇-如何部署SSIS包到SqlServerSQLServer
- 如何檢視ssd壽命?教你macOS 免安裝用指令即可查詢SSD健康度/壽命Mac
- 固態硬碟壽命檢測方法 怎麼看SSD還能用多久?硬碟
- 目標檢測 YOLO v3 訓練 人臉檢測模型YOLO模型
- 如何檢測圖中的環?