PIL影像轉torch的tensor

海_纳百川發表於2024-04-11

1、圖片轉Tensor

from PIL import Image
import os
import numpy as np
import torch
from torchvision import transforms

pic_location = 'dataset/1.png'
img = Image.open(os.path.join(os.getcwd(), pic_location))

# 方法一
img_convert_to_numpy = np.array(img)  # (32, 32, 3)
img_convert_to_tensor1 = torch.tensor(img_convert_to_numpy.transpose(2, 0, 1) / 255)  # torch.Size([3, 32, 32])

# 方法二
transform = transforms.Compose([transforms.ToTensor()])
img_convert_to_tensor2 = transform(img_convert_to_numpy)  # torch.Size([3, 32, 32])

print(torch.equal(img_convert_to_tensor1, img_convert_to_tensor2))  # False
print(img_convert_to_tensor1.dtype)  # torch.float64
print(img_convert_to_tensor2.dtype)  # torch.float32

  首先,使用PIL庫(Python Imaging Library)讀取圖片,得到的是PIL影像物件。透過numpy.array(img)或者numpy.asarray(img),轉化為uint8的數值陣列形式,格式為 H x W x C ,數值範圍在0-255之間。(主要區別在於當資料來源是ndarray時,array仍然會copy出一個副本,佔用新的記憶體,但asarray不會)
  由於在Pytorch中,影像的格式為 C x H x W(想象一下,就是卷積核要卷積圖片的形式),所以需要用transpose進行轉置。這篇文章使用的圖片和我前一篇博文透過cifar-10資料集理解numpy陣列的高(H)、寬(W)、通道數(C)中選取的圖片一樣,透過例項圖片和程式碼解析能更好地幫助你理解。

2、細節

  tensor除法會使輸出結果的精度高一級,可能會導致後面計算型別不匹配,如float32 / float32 = float64。在上面的程式碼中,torch.equal(img_convert_to_tensor1, img_convert_to_tensor2)是等於False的。Tensor預設的dtype是float32,所以當Tensor的型別為float32時,列印Tensor是不會顯示的。
  所以,我們要進行這樣的處理:img_convert_to_tensor1 = torch.tensor(img_convert_to_numpy.transpose(2, 0, 1) / 255, dtype=torch.float32),結果就等於True了。

3、擴充:將Tensor轉換為PIL

  有batch維度的Tensor一定要透過torch.squeeze(image,dim=0)降維,然後img = transforms.ToPILImage()(image)一步搞定。

相關文章