一、函式介紹
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
-
對於4D輸入,
input
維度為 \((N,C,H_{in},W_{in})\),grid
維度為 \((N,H_{out},W_{out},2)\) ,則output
維度為 \((N,C,H_{out},W_{out})\) -
對於5D輸入,
input
維度為 \((N,C,D_{in},H_{in},W_{in})\),grid
維度為 \((N,D_{out},H_{out},W_{out},3)\) ,則output
維度為 \((N,C,D_{out},H_{out},W_{out})\) -
gird
儲存著用於在輸入特徵圖上進行元素取樣的座標偏移量。grid的元素值通常在 \(\left [-1, 1 \right ]\) 之間, \(\left (-1, -1 \right )\) 表示取輸入特徵圖左上角的元素, \(\left (1, 1 \right )\) 表示取輸入特徵圖右下角的元素。
二、示例程式碼
import torch
import torch.nn.functional as F
# 定義一個 4x4 的輸入張量
input_tensor = torch.tensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
], dtype=torch.float).view(1, 1, 4, 4)
print(input_tensor)
# 定義取樣點,歸一化座標在 [-1, 1] 範圍內
grid = torch.tensor([[
[[-0.5, -0.5],
[0.5, -0.5]],
[[-0.5, 0.5],
[0.5, 0.5]],
]], dtype=torch.float)
print(grid)
# 使用 F.grid_sample 進行取樣
output = F.grid_sample(input_tensor, grid, align_corners=True)
print(output)
計算過程
假設輸入張量的尺寸為 (4, 4),取樣點座標的歸一化範圍在 [-1, 1],我們將其轉換為張量座標的範圍 [0, 3]。
歸一化座標轉換公式
歸一化座標轉換公式如下:
示例計算 1:歸一化取樣點 [-0.5, -0.5]
對於歸一化取樣點 [-0.5, -0.5],我們將其轉換為輸入張量的實際座標:
這樣,歸一化座標 [-0.5, -0.5] 對應的輸入張量實際座標為 [0.75, 0.75]。
假設取樣點 (x, y) 對應輸入張量的座標 [0.75, 0.75],我們可以確定其周圍的四個畫素值:
左上角畫素 (0, 0)
右上角畫素 (0, 1)
左下角畫素 (1, 0)
右下角畫素 (1, 1)
使用雙線性插值公式計算插值值:
top_left = input_tensor[0, 0, 0, 0] # 1
top_right = input_tensor[0, 0, 0, 1] # 2
bottom_left = input_tensor[0, 0, 1, 0] # 5
bottom_right = input_tensor[0, 0, 1, 1] # 6
value = (1-0.75)*(1-0.75)*f(0,0) + (1-0.25)*(1-0.75)*f(0,1) \
+ (1-0.75)*(1-0.25)*f(1,0) + (1-0.25)*(1-0.25)*f(1,1)
value = (1 - 0.75) * (1 - 0.75) * 1 + 0.75 * (1 - 0.75) * 2 \
+ (1 - 0.75) * 0.75 * 5 + 0.75 * 0.75 * 6 = 4.75
補充知識:
1.性插值法(linear interpolation)
假設我們已知座標 ((x0, y0) 與 (x1, y1),要得到 [x0, x1] 區間內某一位置 x 在直線上的值。根據圖中所示,我們得到
由於 x 值已知,所以可以從公式得到 y 的值
2.雙線性插值法(bilinear interpolation)
在數學上,雙線性插值是有兩個變數的插值函式的線性插值擴充套件,其核心思想是在兩個方向分別進行一次線性插值。
如座標圖所示,用橫縱座標代表影像畫素的位置,f(x,y)代表該畫素點(x,y)的彩色值或灰度值。
假設我們已知函式f(x,y)
在 Q11 = (x1, y1)
、Q12 = (x1, y2)
, Q21 = (x2, y1)
以及 Q22 = (x2, y2)
四個點的值
若想得到未知函式f(x,y)
在點P=(x, y)
的值,首先在 x 方向進行線性插值,得到
然後在 y 方向進行線性插值,得到
這樣就得到所要的結果 f(x, y)
,
2.1 單位正方形
如果選擇一個座標系統使得 f 的四個已知點座標分別為 (0, 0)、(0, 1)、(1, 0) 和 (1, 1),那麼插值公式就可以化簡為
或者用矩陣運算表示為
2.2 非線性
雙線性插值的結果不是線性的,它是兩個線性函式的積。在單位正方形上,雙線性插值可以記作
常數的數目(4個)對應於給定的 f 的資料點數目
雙線性插值的結果與插值的順序無關。首先進行 y 方向的插值,然後進行 x 方向的插值,所得到的結果是一樣的。雙線性插值的一個顯然的三維空間延伸是三線性插值。
參考文章:
- 一文徹底弄懂 PyTorch 的 F.grid_sample
- PyTorch中grid_sample的使用方法
- 通俗易懂】詳解torch.nn.functional.grid_sample函式:可實現對特徵圖的水平/垂直翻轉
- 雙線性插值(Bilinear Interpolation) 原理、存在的問題及其解決方案、OpenCV程式碼實現