1、圖片轉Tensor
from PIL import Image
import os
import numpy as np
import torch
from torchvision import transforms
pic_location = 'dataset/1.png'
img = Image.open(os.path.join(os.getcwd(), pic_location))
# 方法一
img_convert_to_numpy = np.array(img) # (32, 32, 3)
img_convert_to_tensor1 = torch.tensor(img_convert_to_numpy.transpose(2, 0, 1) / 255) # torch.Size([3, 32, 32])
# 方法二
transform = transforms.Compose([transforms.ToTensor()])
img_convert_to_tensor2 = transform(img_convert_to_numpy) # torch.Size([3, 32, 32])
print(torch.equal(img_convert_to_tensor1, img_convert_to_tensor2)) # False
print(img_convert_to_tensor1.dtype) # torch.float64
print(img_convert_to_tensor2.dtype) # torch.float32
首先,使用PIL庫(Python Imaging Library)讀取圖片,得到的是PIL影像物件。透過numpy.array(img)
或者numpy.asarray(img)
,轉化為uint8的數值陣列形式,格式為 H x W x C ,數值範圍在0-255之間。(主要區別在於當資料來源是ndarray時,array仍然會copy出一個副本,佔用新的記憶體,但asarray不會)
由於在Pytorch中,影像的格式為 C x H x W(想象一下,就是卷積核要卷積圖片的形式),所以需要用transpose進行轉置。這篇文章使用的圖片和我前一篇博文透過cifar-10資料集理解numpy陣列的高(H)、寬(W)、通道數(C)中選取的圖片一樣,透過例項圖片和程式碼解析能更好地幫助你理解。
2、細節
tensor除法會使輸出結果的精度高一級,可能會導致後面計算型別不匹配,如float32 / float32 = float64。在上面的程式碼中,torch.equal(img_convert_to_tensor1, img_convert_to_tensor2)
是等於False的。Tensor預設的dtype是float32,所以當Tensor的型別為float32時,列印Tensor是不會顯示的。
所以,我們要進行這樣的處理:img_convert_to_tensor1 = torch.tensor(img_convert_to_numpy.transpose(2, 0, 1) / 255, dtype=torch.float32)
,結果就等於True了。
3、擴充:將Tensor轉換為PIL
有batch維度的Tensor一定要透過torch.squeeze(image,dim=0)
降維,然後img = transforms.ToPILImage()(image)
一步搞定。