Python/PyTorch Coordinate-System Transforms and Dimension Transforms

Posted by liuliu55 on 2024-04-22

Preface

Tensor axis permutation (coordinate-system transforms) and dimension (shape) transforms come up constantly in deep learning, so this post records how they differ to avoid confusing them.

Coordinate-System Transforms

A coordinate-system transform (a generalization of matrix transposition) reorders the axes of a tensor/array without touching the underlying data.

PyTorch

import torch

def info(tensor):
    print(f"tensor: {tensor}")
    print(f"tensor size: {tensor.size()}")
    print(f"tensor is contiguous: {tensor.is_contiguous()}")
    print(f"tensor stride: {tensor.stride()}")

tensor = torch.rand([1,2,3])
info(tensor)

# output:
# tensor: tensor([[[0.9516, 0.2289, 0.0042],
#          [0.2808, 0.4321, 0.8238]]])
# tensor size: torch.Size([1, 2, 3])
# tensor is contiguous: True
# tensor stride: (6, 3, 1)

per_tensor = tensor.permute(0,2,1)
info(per_tensor)

# output:
# tensor: tensor([[[0.9516, 0.2808],
#          [0.2289, 0.4321],
#          [0.0042, 0.8238]]])
# tensor size: torch.Size([1, 3, 2])
# tensor is contiguous: False
# tensor stride: (6, 1, 3)
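A quick way to confirm that permute only changes the view and never copies data is to compare storage pointers. A minimal sketch (using torch.arange instead of random values so the result is deterministic):

```python
import torch

t = torch.arange(6).reshape(1, 2, 3)   # contiguous, stride (6, 3, 1)
p = t.permute(0, 2, 1)                 # view with shape (1, 3, 2), stride (6, 1, 3)

# Same underlying storage: permute allocates no new memory
print(t.data_ptr() == p.data_ptr())    # True

# Writing through the permuted view is visible in the original tensor
p[0, 2, 1] = 100
print(t[0, 1, 2].item())               # 100
```

Because only the strides change, permute is O(1) regardless of tensor size.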

NumPy

import numpy as np

def np_info(array):
    print(f"array: {array}")
    print(f"array size: {array.shape}")
    print(f"array is contiguous: {array.flags['C_CONTIGUOUS']}")
    print(f"array stride: {array.strides}")

array = np.random.rand(1,2,3)
np_info(array)

# output:
# array: [[[0.58227139 0.32251543 0.12221412]
#   [0.72647191 0.42323578 0.65290986]]]
# array size: (1, 2, 3)
# array is contiguous: True
# array stride: (48, 24, 8)

trans_array = np.transpose(array, (0,2,1))
np_info(trans_array)

# output:
# array: [[[0.58227139 0.72647191]
#   [0.32251543 0.42323578]
#   [0.12221412 0.65290986]]]
# array size: (1, 3, 2)
# array is contiguous: False
# array stride: (48, 8, 24)
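NumPy behaves the same way: np.transpose returns a view over the same buffer, and np.ascontiguousarray can produce a C-contiguous copy when one is needed. A minimal sketch:

```python
import numpy as np

a = np.arange(6.0).reshape(1, 2, 3)   # C-contiguous, strides (48, 24, 8)
t = np.transpose(a, (0, 2, 1))        # view, strides (48, 8, 24)

print(np.shares_memory(a, t))         # True: no copy was made
print(t.flags['C_CONTIGUOUS'])        # False

c = np.ascontiguousarray(t)           # explicit C-contiguous copy
print(np.shares_memory(a, c))         # False
print(c.flags['C_CONTIGUOUS'])        # True
```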

So for a high-dimensional tensor, permutation/transposition does not change the relative positions of the data in memory at all; it merely rotates the data (hyper)cube, i.e., changes the angle from which the (hyper)cube is observed.

Dimension Transforms

tensor.view()

view() reshapes a tensor to the desired size without affecting its contiguous property.
view() acts like a reference to the original tensor: operations through it act directly on the original, no copy is made, and the output shares the same underlying storage as the input.

view_tensor = tensor.view(3,2,1)
info(view_tensor)

# output:
# tensor: tensor([[[0.9516],
#          [0.2289]],
# 
#         [[0.0042],
#          [0.2808]],
# 
#         [[0.4321],
#          [0.8238]]])
# tensor size: torch.Size([3, 2, 1])
# tensor is contiguous: True
# tensor stride: (2, 1, 1)
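Since the view shares storage with the original tensor, an in-place write through the view changes the original as well. A minimal sketch:

```python
import torch

t = torch.zeros(1, 2, 3)
v = t.view(3, 2, 1)

# Same storage, so no extra memory is allocated
print(v.data_ptr() == t.data_ptr())   # True

# Modifying the view modifies the original tensor
v[0, 0, 0] = 7.0
print(t[0, 0, 0].item())              # 7.0
```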

However, calling view() on a tensor whose contiguous flag is False raises an error:

view_per_tensor = per_tensor.view(2,3)

#output:
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# /tmp/ipykernel_388070/1679121630.py in <module>
# ----> 1 view_per_tensor  = per_tensor.view(2,3)
#       2 # info(per_tensor)
#       3 info(view_per_tensor)
#       4 print(view_per_tensor.data_ptr() == per_tensor.data_ptr())

# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
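As the error message hints, one fix is to call .contiguous() first, which copies the data into a contiguous layout so that view() succeeds. A minimal sketch:

```python
import torch

t = torch.rand(1, 2, 3)
p = t.permute(0, 2, 1)                # non-contiguous view, stride (6, 1, 3)

# p.view(2, 3) would raise RuntimeError; copy to a contiguous layout first
c = p.contiguous()                    # copy with shape (1, 3, 2), stride (6, 2, 1)
v = c.view(2, 3)                      # now view() works

print(c.is_contiguous())              # True
print(c.data_ptr() == t.data_ptr())   # False: contiguous() copied the data
```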

tensor.reshape()

torch.Tensor.reshape() works on any tensor. It is roughly equivalent to torch.Tensor.view() when possible and torch.Tensor.contiguous().view() otherwise. In other words, reshape does not necessarily allocate new memory: if the tensor is contiguous, it effectively calls the view() implementation; only when the tensor is non-contiguous and its strides are incompatible with the new shape does it make a deep copy of the data.

reshape_per_tensor = per_tensor.reshape(2,3) 
info(reshape_per_tensor)

# output:
# tensor: tensor([[0.9516, 0.2808, 0.2289],
#         [0.4321, 0.0042, 0.8238]])
# tensor size: torch.Size([2, 3])
# tensor is contiguous: True
# tensor stride: (3, 1)
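reshape's copy-or-view behavior can be checked directly with data_ptr(): on a contiguous tensor it returns a view, while on a non-contiguous tensor with incompatible strides it copies. A minimal sketch:

```python
import torch

t = torch.rand(1, 2, 3)

# Contiguous input: reshape is just a view, no copy
r1 = t.reshape(2, 3)
print(r1.data_ptr() == t.data_ptr())   # True

# Non-contiguous input with incompatible strides: reshape copies
p = t.permute(0, 2, 1)
r2 = p.reshape(2, 3)
print(r2.data_ptr() == t.data_ptr())   # False
```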

