MNIST資料集詳解及視覺化處理(pytorch)

後廠村張先生發表於2020-11-24

MNIST資料集詳解及視覺化處理(pytorch)

MNIST 資料集已經是一個被”嚼爛”了的資料集, 作為機器學習在視覺領域的“hello world”,很多教程都會對它”下手”, 幾乎成為一個 “典範”。 不過有些人可能對它還不是很瞭解, 下面來介紹一下。

MNIST 資料集可在 http://yann.lecun.com/exdb/mnist/ 獲取, 它包含了四個部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB,
    包含 60,000 個樣本)

  • Training set labels: train-labels-idx1-ubyte.gz (29
    KB, 解壓後 60 KB, 包含 60,000 個標籤)

  • Test set images:t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本)

  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個標籤)

MNIST 資料集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字資料。
在這裡插入圖片描述在使用pytorch進行學習時,可以使用pytorch的處理影像視訊的torchvision工具集直接下載MNIST的訓練和測試圖片,torchvision包含了一些常用的資料集、模型和轉換函式等等,比如圖片分類、語義切分、目標識別、例項分割、關鍵點檢測、視訊分類等工具。

from torchvision import datasets, transforms

#下載測試集
train_dataset = datasets.MNIST('./data', train=True, 
                                transfrom=transforms.ToTensor(), 
                                download=True)
test_dataset =  datasets.MNIST('./data', train=False, 
                                transform=transforms.ToTensor(),
                                download=True)

下載完成後的資料集如下圖所示。

資料集的圖片是以位元組的形式進行儲存,在我們進行訓練和測試時可以直接使用 torch.utils.data.DataLoader 進行載入。

在這裡插入圖片描述雖然下載下來的資料集檔案,其具體的儲存格式我們暫時不用太過關心,但如何才能將這部分資料轉換為可見的圖片形式呢?我們可以利用pytorch自帶的工具進行檔案讀取,並提取資料儲存為可開啟的jpg檔案和txt檔案。

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="D:/MNIST/data/MNIST/raw"
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
        )
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
        )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
    if(train):
        f=open(root+'train.txt','w')
        data_path=root+'/train/'
        if(not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
            img_path=data_path+str(i)+'.jpg'
            io.imsave(img_path,img.numpy())
            f.write(img_path+' '+str(label)+'\n')
        f.close()
    else:
        f = open(root + 'test.txt', 'w')
        data_path = root + '/test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
            img_path = data_path+ str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            f.write(img_path + ' ' + str(label) + '\n')
        f.close()

convert_to_img(True)#轉換訓練集
convert_to_img(False)#轉換測試集

等待轉換完成後,可以在MNIST訓練集和測試集所在的資料夾內出現train和test兩個資料夾。
在這裡插入圖片描述裡面是已經轉換完成的jpg格式的6000張訓練資料和10000張測試資料。在test和train資料夾的上一層raw資料夾中也產生了相對應的兩個txt檔案,裡面是每張jpg格式的圖片所標註的數字。
在這裡插入圖片描述 此時我們已經能看到所有的手寫字圖片和相對應的標籤了,下一步可以使用pytorch來實現LeNet-5網路,利用訓練資料集對卷積神經網路進行訓練,訓練完成後可以使用測試集對網路進行測試,以檢驗網路的訓練結果。

當然,也可以自行任意選擇圖片輸送到網路中進行測試,或者根據網路要求,自行繪製數字然後處理成網路要求的格式進行測試。

注:MNIST的網路處理要求28x28x1的圖片輸入,我們在資料集還原出來的jpg檔案為28x28x3,因此在任選圖片進行預測時,需要先將圖片進行處理。

相關文章