深度學習--實戰 LeNet5

林每天都要努力發表於2023-04-24

深度學習--實戰 LeNet5

資料集

資料集選用CIFAR-10的資料集,Cifar-10 是由 Hinton 的學生 Alex Krizhevsky、Ilya Sutskever 收集的一個用於普適物體識別的計算機視覺資料集,它包含 60000 張 32 X 32 的 RGB 彩色圖片,總共 10 個分類。其中,包括 50000 張用於訓練集,10000 張用於測試集。

模型實現

模型需要繼承nn.module

import torch
from torch import  nn


class Lenet5(nn.Module):
    """
    for cifar10 dataset.
    """
    def __init__(self):
        super(Lenet5,self).__init__()

        self.conv_unit = nn.Sequential(
            #input:[b,3,32,32] ===> output:[b,6,x,x]
            #Conv2d(Input_channel:輸入的通道數,kernel_channels:卷積核的數量,輸出的通道數,kernel_size:卷積核的大小,stride:步長,padding:邊緣補足)
            nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),

            #池化
            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),

            #卷積層
            nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),

            #池化
            nn.AvgPool2d(kernel_size=2,stride=2,padding=0)

            #output:[b,16,5,5]
        )

        #flatten

        #Linear層
        self.fc_unit=nn.Sequential(
            nn.Linear(16*5*5,120),
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10)
        )

        #測試卷積輸出到全連線層的輸入
        #tmp = torch.rand(2,3,32,32)
        #out = self.conv_unit(tmp)
        #print("conv_out:",out.shape)

        #Loss評價  Cross Entropy Loss  分類  在其中包含一個softmax()操作
        #self.criteon = nn.MSELoss()  迴歸
        #self.criteon = nn.CrossEntropyLoss()

    def forward(self,x):
        """

        :param x:[b,3,32,32]
        :return:
        """
        batchsz = x.size(0)
        #[b,3,32,32]=>[b,16,5,5]
        x = self.conv_unit(x)
        #[b,16,5,5]=>[b,16*5*5]
        x = x.view(batchsz,16*5*5)
        #[b,16*5*5]=>[b,10]
        logits = self.fc_unit(x)

        return logits

        # [b,10]
        # pred = F.softmax(logits,dim=1)  這步在CEL中包含了,所以不需要再寫一次
        #loss = self.criteon(logits,y)




def main():
    net = Lenet5()
    tmp = torch.rand(2,3,32,32)
    out = net(tmp)
    print("lenet_out:",out.shape)

if __name__ == '__main__':
    main()

訓練與測試

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from lenet5 import Lenet5
import torch.nn.functional as F
from torch import  nn,optim

def main():

    batch_size = 32
    epochs = 1000
    learn_rate = 1e-3

    #匯入圖片,一次只匯入一張
    cifer_train = datasets.CIFAR10('cifar',train=True,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download=True)

    #載入圖
    cifer_train = DataLoader(cifer_train,batch_size=batch_size,shuffle=True)

    #匯入圖片,一次只匯入一張
    cifer_test = datasets.CIFAR10('cifar',train=False,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download=True)

    #載入圖
    cifer_test = DataLoader(cifer_test,batch_size=batch_size,shuffle=True)

    #iter迭代器,__next__()方法可以獲得資料
    x, label = iter(cifer_train).__next__()
    print("x:",x.shape,"label:",label.shape)
    #x: torch.Size([32, 3, 32, 32]) label: torch.Size([32])


    device = torch.device('cuda')
    model = Lenet5().to(device)
    print(model)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=learn_rate)


    for epoch in range(epochs):
        model.train()
        for batchidx,(x,label) in enumerate(cifer_train):
            x,label = x.to(device),label.to(device)

            logits = model(x)
            #logits:[b,10]

            loss = criteon(logits,label)

            #backprop
            optimizer.zero_grad()  #梯度清零
            loss.backward()
            optimizer.step()  #梯度更新
        #
        print(epoch,loss.item())

        model.eval()
        with torch.no_grad():
            #test
            total_correct = 0
            total_num = 0
            for x,label in cifer_test:
                x,label = x.to(device),label.to(device)
                #[b,10]
                logits = model(x)
                #[b]
                pred =logits.argmax(dim=1)

                #[b] vs [b] => scalar tensor
                total_correct += torch.eq(pred,label).float().sum().item()
                total_num += x.size(0)

        acc = total_correct/total_num
        print("epoch:",epoch,"acc:",acc)


if __name__ == '__main__':
    main()

相關文章