行人重識別(17)——程式碼實踐之區域性對齊最小距離演算法(local_distance.py)

東方旅行者發表於2021-01-03

!轉載請註明原文地址!——東方旅行者

更多行人重識別文章移步我的專欄:行人重識別專欄

區域性對齊最小距離演算法

一、區域性對齊最小距離演算法作用

本檔案用於定義區域性對齊最小距離演算法。
用於解決兩張圖片行人部件不對齊的問題。如下圖所示
區域性未對齊

二、區域性對齊最小距離演算法編寫思路

主要的函式batch_local_dist,該函式呼叫shortest_dist函式與batch_euclidean_dist函式。

batch_euclidean_dist用於計算區域性特徵的歐氏距離。輸入兩個三維張量X(假設維度為(32,128,64)),Y(假設維度為(32,32,64))。首先需要判斷這兩個張量是否是三維,然後需要判斷張量的第一維數值與第三維數值是否對應相等。先計算各自張量的平方,然後使用維度擴充套件在進行維度交換是這兩個張量維度大小相同。然後使用torch.baddbmm_進行批矩陣相乘,計算x2+y2-2xy,x維度為(32,128,64),y為(32,32,64)交換維度後(32,64,32),最終結果維度(32,128,32)。然後開平方返回兩區域性分支的歐氏距離。

shortest_dist用於計算區域性特徵的區域性對齊最小距離。設圖片A有區域性特徵8段,圖片B有區域性特徵6段,則設AB距離矩陣大小為8×6,則dist(3,4)就代表圖片A的前4段區域性特徵與圖片B的前5段區域性特徵的距離。由此我們可以知道dist(7,5)就是圖片A的前8段區域性特徵與圖片B的前6段區域性特徵的距離,即圖片A與圖片B的距離。在計算最小距離時同時具有區域性對齊的作用。在距離矩陣邊界上,最小距離只有一條通路,所以只能加和。在距離矩陣內部,可以選擇從上方的通路與左側通路,從中選取最小的路徑。

batch_local_dist根據輸入的區域性特徵先進行判斷張量是否合法,然後使用batch_euclidean_dist計算區域性特徵的歐氏距離,然後對距離矩陣進行歸一化,使用shortest_dist計算區域性對齊最小距離,返回該距離。

三、程式碼

import torch

"""
本檔案用於定義區域性對齊最小距離演算法
"""
def batch_euclidean_dist(x, y):
    """
    計算區域性特徵的歐氏距離
    輸入x(N, m, d),y(N, n, d)
    輸出dist(N, m, n)
    其中N為batch_size,m為x的local part,n為y的local part
    """
    assert len(x.size()) == 3#需要判斷x是否是三維
    assert len(y.size()) == 3#需要判斷y是否是三維
    assert x.size(0) == y.size(0)#需要判斷x的第一維的數值是否等於y第一維的數值
    assert x.size(-1) == y.size(-1)#需要判斷x的第二維的數值是否等於y第二維的數值
    N, m, d = x.size()
    N, n, d = y.size()
    #經過計算後xx與yy維度都是維度(N, m, n)
    xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n)
    #先擴充套件維度在交換維度的原因時,如果n大於m的話無法使用.expand(N, m, n)函式,因為原始張量第二維數值為n大於目標數值n,所以只能先擴充套件第三維然後再交換
    yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1)
    dist = xx + yy
    #計算x^2+y^2-2xy
    dist.baddbmm_(1, -2, x, y.permute(0, 2, 1))#進行三維矩陣相乘,x為N, m, d,y為N, n, d交換維度後N, d, n,最終結果維度N,m,d
    dist = dist.clamp(min=1e-12).sqrt() #維度N,m,d
    return dist

def shortest_dist(dist_mat):
    """
    根據距離矩陣計算區域性對齊最小距離
    
    設圖片A有區域性特徵8段,圖片B有區域性特徵6段,則設AB距離矩陣大小為8×6,則dist(3,4)就代表圖片A的前4段區域性特徵與圖片B的前5段區域性特徵的距離。
    由此我們可以知道dist(7,5)就是圖片A的前8段區域性特徵與圖片B的前6段區域性特徵的距離,即圖片A與圖片B的距離。
    
    在計算最小距離時同時具有區域性對齊的作用
    
    輸入dist_mat(m, n, N)
    輸出dist(N)
    其中N為batch_size,m為x的local part,n為y的local part
    """
    m, n = dist_mat.size()[:2]#獲取輸入距離矩陣前兩維
    dist = [[0 for _ in range(n)] for _ in range(m)]#初始化距離矩陣,型別list,元素也為list
    for i in range(m):
        for j in range(n):
            if (i == 0) and (j == 0):#初始化邊界
                dist[i][j] = dist_mat[i, j]
            elif (i == 0) and (j > 0):#當i為0時,最小距離只有一種,該種情況屬於距離矩陣邊界
                dist[i][j] = dist[i][j - 1] + dist_mat[i, j]
            elif (i > 0) and (j == 0):#當j為0時,最小距離只有一種,該種情況屬於距離矩陣邊界
                dist[i][j] = dist[i - 1][j] + dist_mat[i, j]
            else:#在位於距離矩陣內部時,可以選擇從上方或從左側的距離,所以選取其中更小的距離
                dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j]
    dist = dist[-1][-1]#最後返回距離矩陣右下角元素,即為兩區域性特徵張量的區域性對齊最小距離
    return dist

def batch_local_dist(x, y):
    """
    根據區域性特徵計算最小距離
    輸入x(N, m, d),y(N, n, d)
    輸出dist(n)
    """
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)
    #維度(N, m, n)
    dist_mat = batch_euclidean_dist(x, y)
    #歸一化維度(N, m, n)
    #dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
    #輸入維度(m, N, n),輸出維度(n)
    dist = shortest_dist(dist_mat.permute(1, 2, 0))
    return dist

if __name__=='__main__':
    x=torch.randn(32,64,64)
    y=torch.randn(32,32,64)
    local_dist=batch_local_dist(x,y)
    print(local_dist)

相關文章