基於 PyTorch 實現貓狗分類

專注的阿熊發表於2021-06-08

from __future__ import print_function, division

from PIL import Image

from torchvision import transforms

import torch.nn.functional as F

import torch

import torch.nn as nn

import torch.nn.parallel

# 定義網路

class ConvNet(nn.Module):
    """Six-convolution CNN that emits one sigmoid probability per image.

    Every convolution uses a 3x3 kernel with no padding and every pooling
    layer is a 2x2 max pool.  ``fc1`` takes 4608 flattened features, which
    is 128 * 6 * 6 -- the feature-map size reached from a 3x150x150 input.
    """

    def __init__(self):
        super(ConvNet, self).__init__()

        # Stage 1: 3 -> 32 channels, then downsample.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.max_pool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: 32 -> 64 channels, then downsample.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.max_pool2 = nn.MaxPool2d(kernel_size=2)

        # Stage 3: two 64 -> 64 convolutions, then downsample.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.max_pool3 = nn.MaxPool2d(kernel_size=2)

        # Stage 4: 64 -> 128 -> 128 channels, then downsample.
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3)
        self.max_pool4 = nn.MaxPool2d(kernel_size=2)

        # Classifier head: flattened features -> 512 hidden units -> 1 output.
        self.fc1 = nn.Linear(in_features=4608, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid outputs for image batch ``x``."""
        batch = x.size(0)

        x = self.max_pool1(F.relu(self.conv1(x)))
        x = self.max_pool2(F.relu(self.conv2(x)))

        x = F.relu(self.conv3(x))
        x = self.max_pool3(F.relu(self.conv4(x)))

        x = F.relu(self.conv5(x))
        x = self.max_pool4(F.relu(self.conv6(x)))

        # Flatten everything except the batch dimension.
        flat = x.view(batch, -1)

        hidden = F.relu(self.fc1(flat))
        return torch.sigmoid(self.fc2(hidden))

# Path of the saved model file (the whole model object is loaded from it).
model_save_path = 'E:\\Cat_And_Dog\\kaggle\\model.pth'

# ------------------------ Load data --------------------------- #

# Inference-time preprocessing.
# Random augmentations (flips, random crops, colour jitter) belong to training
# only: applied here they make the prediction for one and the same image vary
# from run to run.  Only the deterministic steps are kept: resize to the
# 150x150 resolution the previous pipeline produced, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5.
# NOTE(review): assumes the network was trained on 150x150 inputs (ConvNet.fc1
# expects 4608 = 128 * 6 * 6 features) -- confirm against the training script.
transform_test = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Order matters: it must match the class order used during training.
class_names = ['cat', 'dog']

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ------------------------ Load model and predict --------------------------- #

# map_location lets a checkpoint saved on a GPU be opened on a CPU-only host.
# NOTE(security): torch.load unpickles arbitrary Python objects -- only load
# model files from a trusted source.  Recent PyTorch releases default to
# weights_only=True, in which case loading a fully pickled model like this one
# additionally requires passing weights_only=False.
model = torch.load(model_save_path, map_location=device)

# The model and the input tensor must live on the same device.
model.to(device)

model.eval()

# print(model)

# Force three channels so grayscale / RGBA files still fit the 3-channel conv1.
image_PIL = Image.open('E:\\Cat_And_Dog\\kaggle\\cats_and_dogs_small\\test\\cats\\cat.1500.jpg').convert('RGB')

image_tensor = transform_test(image_PIL)

# Add the batch dimension, then move the input onto the model's device
# (a device mismatch raises at the forward pass).
image_tensor = image_tensor.unsqueeze(0).to(device)

# No gradients are needed for prediction; skip building the autograd graph.
with torch.no_grad():
    out = model(image_tensor)

# Threshold the sigmoid output at 0.5: 1 -> class_names[1], 0 -> class_names[0].
pred = (out >= 0.5).long()

# A single image was fed in, so pred holds exactly one element.
print(class_names[int(pred.item())])


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946337/viewspace-2776017/,如需轉載,請註明出處,否則將追究法律責任。

相關文章