在Python以陣列的方式讀取影像,並轉為Torch張量

weixin_44828787發表於2020-11-12

圖片路徑:

img_path = ‘./datasets/proposal-flow-pascal/PF-dataset-PASCAL/JPEGImages/2009_003820.jpg’

讀取圖片的2種不同方法:

from PIL import Image
image_1 = Image.open(img_path)

from skimage import io
image_2 = io.imread(img_path)

這2種方法讀取的影像的shape是一樣的,都是numpy陣列型別:(228, 300, 3)

print(image_1.shape)
print(image_2.shape)

對(影像)陣列進行升維(通過numpy.升維)

source_image = np.expand_dims(source_image.transpose((2,0,1)),0)

source_image的shape為:(1, 3, 228, 300)

print(source_image.shape)

將numpy陣列轉為torch張量:

source_image = torch.Tensor(source_image)

source_image的shape為:torch.Size([1, 3, 228, 300])

上一句也可以替換為,將numpy陣列轉為torch張量(同時歸一化):

source_image = torch.Tensor(source_image.astype(np.float32)/255.0)

source_image的shape為:torch.Size([1, 3, 228, 300])

將torch張量轉為變數

image_var = Variable(source_image,requires_grad=False)

image_var的shape為:image_var_shape: torch.Size([1, 3, 228, 300])

對(影像)進行降維
image_var_1 = image_var_1[0] # 4維降3維

調整不同維度的順序(對於變數)
image_var_1 = image_var_1.permute(1, 2, 0)

調整不同維度的順序(對於陣列)
source_image = source_image.transpose(1, 2, 0)

轉為陣列
image_var_1 = image_var_1.numpy()

Python讀取影像時,讀取為numpy陣列,用plt顯示影像時,也必須是numpy陣列的格式才能顯示。

張量和變數必須先轉為陣列才能作為影像顯示。

相關文章