Torch.stack()
1. 概念
在一個新的維度上連線一個張量序列
2. 引數
- tensors (sequence)需要連線的張量序列
- dim (int)在第dim個維度上連線
注意輸入的張量shape要完全一致,且dim必須小於len(tensors)。
3. 舉例
3.1 四個shape為[3, 3]的張量
以下面這4個張量,每個張量shape為[3, 3]。
1 a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
2 b = torch.Tensor([[10,20,30],[40,50,60],[70,80,90]])
3 c = torch.Tensor([[100,200,300],[400,500,600],[700,800,900]])
4 d = torch.Tensor([[1000,2000,3000],[4000,5000,6000],[7000,8000,9000]])
3.1.1 dim=0的情況下,直接來看結果。
torch.stack((a,b,c,d),dim=0)
此時在第0個維度上連線,新張量的shape可以發現為[4, 3, 3],4代表在第0個維度有4項。
觀察可以得知:即初始的四個張量,即a、b、c、d四個初始張量。
可以理解為新張量的第0個維度上連線a、b、c、d。
3.1.2 dim=1的情況下
torch.stack((a,b,c,d),dim=1)
此時在第1個維度上連線,新張量的shape可以發現為[3, 4, 3],4代表在第1個維度有4項。
觀察可以得知:
- 新張量[0][0]為a[0],[0][1]為b[0],[0][2]為c[0],[0][3]為d[0]
- 新張量[1][0]為a[1],[1][1]為b[1],[1][2]為c[1],[1][3]為d[1]
- 新張量[2][0]為a[2],[2][1]為b[2],[2][2]為c[2],[2][3]為d[2]
可以理解為新張量的第1個維度上連線a、b、c、d的第0個維度單位,具體地說,在新張量[i]中連線a[i]、b[i]、c[i]、d[i],即將a[i]賦給新張量[i][0]、b[i]賦給新張量[i][1]、c[i]賦給新張量[i][2]、d[i]賦給新張量[i][3]。
3.1.2 dim=2的情況下
此時在第2個維度上連線,新張量的shape可以發現為[3, 3, 4],4代表在第2個維度有4項。
觀察可以得知:
- 新張量[0][0][0]為a[0][0],[0][0][1]為b[0][0],[0][0][2]為c[0][0],[0][0][3]為d[0][0]
- 新張量[0][1][0]為a[0][1],[0][1][1]為b[0][1],[0][1][2]為c[0][1],[0][1][3]為d[0][1]
- 新張量[0][2][0]為a[0][2],[0][2][1]為b[0][2],[0][2][2]為c[0][2],[0][2][3]為d[0][2]
- 新張量[1][0][0]為a[1][0],[1][0][1]為b[1][0],[1][0][2]為c[1][0],[1][0][3]為d[1][0]
- 新張量[1][1][0]為a[1][1],[1][1][1]為b[1][1],[1][1][2]為c[1][1],[1][1][3]為d[1][1]
- 新張量[1][2][0]為a[1][2],[1][2][1]為b[1][2],[1][2][2]為c[1][2],[1][2][3]為d[1][2]
- 新張量[2][0][0]為a[2][0],[2][0][1]為b[2][0],[2][0][2]為c[2][0],[2][0][3]為d[2][0]
- 新張量[2][1][0]為a[2][1],[2][1][1]為b[2][1],[2][1][2]為c[2][1],[2][1][3]為d[2][1]
- 新張量[2][2][0]為a[2][2],[2][2][1]為b[2][2],[2][2][2]為c[2][2],[2][2][3]為d[2][2]]
可以理解為新張量的第2個維度上連線a、b、c、d的第1個維度的單位,具體地說,在新張量[i][j]中連線a[i][j]、b[i][j]、c[i][j]、d[i][]j。
3.1.3 總結
通過dim=0、1、2的情況,可以總結並推漲出規律:
假設有n個[x,y]的張量,當dim=z時。新張量在第z個維度上連線n個張量第z-1維度的單位,具體來說,新張量[i][i+1]..[i+z-1]中依次連線n個向量[i][i+1]..[i+z-1]。
3.2 7個shape為[5, 7, 4, 2]的張量
1 a1 = torch.rand([5, 7, 4, 3])
2 a2 = a1 + 1
3 a3 = a2 + 1
4 a4 = a3 + 1
5 a5 = a4 + 1
6 a6 = a5 + 1
7 a7 = a6 + 1
假設dim=3時連線
test = torch.stack((a1, a2, a3, a4, a5, a6, a7), dim=3)
7個張量在第3個維度連線後形成的新張量賦為test,test的shape為[5, 7, 4, 7, 3],代表在第3個維度有7項。
隨機(在新張量[0][0][0]到新張量[4][6][3]區間內)檢視一個新張量第3維度上的單位:
a = test[0][1][2]
再根據總結的規律,將7個向量中的[0][1][2]連線起來,再次檢視,驗證了規律。
b = torch.zeros(0)
for i in (a1, a2, a3, a4, a5, a6, a7):
b = torch.cat((b, i[0][1][2]), dim=0)
4. 理解
通過shape來看,假設shape為[a, b, c... z],有n個shape相同的張量,在dim=x時連線n個張量,可以得到新張量,shape為[a, b, c, ... n, ...z],其中n所在維度即為第x個維度。
然後即可通過新張量[i][i+1]..[i+x-1]看作索引,對應的資料為n個張量[i][i+1][i+x-1]按順序連線。
因此,遇到一個多個張量在dim=x情況下做stack操作時,要根據x的大小進行區分:
- x=0:新增維度為第0維度,直接將每個張量本身按順序連線起來
- x>0:通過新張量0到x-1這些維度索引,可以找到每個初始張量在0到x-1維度的單位,這些單位即是新張量在第x維度上的每個單位,根據順序連線每個單位,連線的順序即是第x維度索引的值。然後對第0到x-1每個維度都重複一遍如上操作,即得出多個張量在dim=x情況下連線的新張量。