# net.py
import torch
import torch.nn as nn
import lightning as L
from torchmetrics.classification import BinaryAccuracy
class AlexNet(L.LightningModule):
    """AlexNet adapted for single-logit binary classification (e.g. cat vs. dog).

    Expects 3-channel inputs resized so the conv stack produces a 256x6x6
    feature map (both 224x224 and 227x227 satisfy this).

    Args:
        num_classes: Number of output logits. Defaults to 1, trained with
            ``BCEWithLogitsLoss`` against 0/1 labels.
    """

    def __init__(self, num_classes=1):
        super().__init__()  # zero-arg super() — modern idiom
        self.save_hyperparameters()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        # Create the loss once instead of instantiating it on every step.
        self.criterion = nn.BCEWithLogitsLoss()
        # Separate metric objects so train/val running state never mixes.
        self.train_accuracy = BinaryAccuracy()
        self.val_accuracy = BinaryAccuracy()

    def forward(self, x):
        x = self.features(x)
        # flatten() is safe on non-contiguous tensors, unlike view().
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _shared_step(self, batch, prefix, metric):
        """Common train/val logic: forward, loss, accuracy, logging."""
        images, labels = batch
        logits = self(images).squeeze(1)  # (B, 1) -> (B,) for BCE
        loss = self.criterion(logits, labels.float())
        acc = metric(logits, labels)
        self.log(f'{prefix}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(f'{prefix}_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train', self.train_accuracy)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val', self.val_accuracy)

    def configure_optimizers(self):
        # Plain Adam; LR kept as in the original training recipe.
        return torch.optim.Adam(self.parameters(), lr=1e-4)
# main.py
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from net import AlexNet  # import AlexNet from net.py
# Fix all RNG seeds for reproducibility; allow TF32 matmul on capable GPUs.
L.seed_everything(42)
torch.set_float32_matmul_precision('high')

# Root directory of the dataset (expects train/val/test subfolders).
data_dir = './data'

# Preprocessing pipeline: AlexNet-sized inputs converted to tensors.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# Define the LightningDataModule
class DataModule(L.LightningDataModule):
    """LightningDataModule serving ImageFolder splits from ``<data_dir>/{train,val,test}``.

    Args:
        data_dir: Root directory containing 'train', 'val' and 'test' subfolders.
        batch_size: Batch size used for every split.
        transform: Torchvision transform applied to every image.
    """

    def __init__(self, data_dir, batch_size, transform):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform

    def setup(self, stage=None):
        # Build all three splits unconditionally, as the original did
        # (ImageFolder construction only scans paths, so this is cheap).
        self.train_dataset = datasets.ImageFolder(self.data_dir + '/train', transform=self.transform)
        self.val_dataset = datasets.ImageFolder(self.data_dir + '/val', transform=self.transform)
        self.test_dataset = datasets.ImageFolder(self.data_dir + '/test', transform=self.transform)

    def _loader(self, dataset, shuffle):
        # Single home for the shared DataLoader settings (was duplicated 3x).
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=4)

    def train_dataloader(self):
        # Only the training split is shuffled.
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)
# Instantiate the data module and the model.
data_module = DataModule(data_dir=data_dir, batch_size=32, transform=transform)
model = AlexNet(num_classes=1)

# Keep only the checkpoint with the best (highest) validation accuracy.
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',       # metric to monitor
    dirpath='checkpoints',   # where checkpoints are written
    filename='best-checkpoint',
    save_top_k=1,            # keep only the single best model
    mode='max',              # higher val_acc is better
)

# Train. 'auto' selects the GPU when one is available but still runs on
# CPU-only machines (the original hard-coded 'gpu' crashes without one).
trainer = L.Trainer(
    max_epochs=15,
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, datamodule=data_module)

# After training, report where the best checkpoint was saved.
best_model_path = checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import pandas as pd
import matplotlib.pyplot as plt
from torchmetrics.classification import BinaryAccuracy
import lightning as L
from PIL import Image
import seaborn as sns
from sklearn.metrics import confusion_matrix
from net import AlexNet  # import AlexNet from net.py
# Fix RNG seeds and matmul precision, matching the training script.
L.seed_everything(42)
torch.set_float32_matmul_precision("high")

# Test-time preprocessing. NOTE: this must match training exactly — training
# used Resize((224, 224)); the original 227 here silently evaluated the model
# on differently-scaled inputs than it was trained on.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Dataset root.
data_dir = "./data"

# Test split. shuffle=False keeps the DataLoader order aligned with
# test_dataset.imgs so misclassified sample paths can be recovered by index.
test_dataset = datasets.ImageFolder(data_dir + "/test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=804, shuffle=False, num_workers=4)

# Restore the best checkpoint produced by training.
best_model_path = "checkpoints/best-checkpoint.ckpt"
best_model = AlexNet.load_from_checkpoint(best_model_path, num_classes=1)

# Evaluation mode on the best available device.
best_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_model.to(device)

# Metric and loss used for reporting.
test_accuracy = BinaryAccuracy().to(device)
test_loss_fn = nn.BCEWithLogitsLoss()

# Paths of misclassified images, bucketed by true class.
mis_cat = []  # true label: cat, predicted dog
mis_dog = []  # true label: dog, predicted cat

# True/predicted labels accumulated for the confusion matrix.
true_labels = []
predicted_labels = []
# Run inference without gradients for speed and lower memory use.
with torch.no_grad():
    total_loss = 0.0
    total_samples = 0
    # Global offset into test_dataset.imgs. BUG FIX: the original indexed
    # imgs with the batch-local idx, which is wrong for every batch after
    # the first (it only worked by accident when one batch held the whole set).
    sample_offset = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = best_model(images).squeeze(1)

        # Accumulate sample-weighted loss and feed the running accuracy metric
        # so the final numbers cover the whole test set, not just the last batch.
        batch_loss = test_loss_fn(outputs, labels.float())
        test_accuracy.update(outputs, labels)
        total_loss += batch_loss.item() * labels.size(0)
        total_samples += labels.size(0)

        # Threshold sigmoid probabilities at 0.5 for the hard prediction.
        preds = (torch.sigmoid(outputs).cpu() > 0.5).numpy().astype(int)
        labels_np = labels.cpu().numpy()
        incorrect_indices = preds != labels_np
        for idx, incorrect in enumerate(incorrect_indices):
            if incorrect:
                image_path = test_dataset.imgs[sample_offset + idx][0]
                if labels_np[idx] == 0:  # true label: cat
                    mis_cat.append(image_path)
                else:  # true label: dog
                    mis_dog.append(image_path)
        sample_offset += labels.size(0)

        # Collect labels for the confusion matrix.
        true_labels.extend(labels_np)
        predicted_labels.extend(preds)

# Aggregate metrics over the full test set (original printed last-batch only).
test_loss = torch.tensor(total_loss / max(total_samples, 1))
test_acc = test_accuracy.compute()
print(f"測試損失: {test_loss.item():.4f}, 測試準確率: {test_acc.item():.4f}")
# Compute the confusion matrix over the full test set.
cm = confusion_matrix(true_labels, predicted_labels)

# Plot it as an annotated heatmap.
plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    annot_kws={"fontsize": 15, "fontweight": "bold"},
    xticklabels=["Cat", "Dog"],
    yticklabels=["Cat", "Dog"],
)
plt.xlabel("Predicted Labels", fontsize=14)
plt.ylabel("True Labels", fontsize=14)
plt.title("confusion_matrix", fontsize=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

# Ensure the output directory exists — savefig does not create it and would
# raise FileNotFoundError on a fresh checkout.
os.makedirs("results", exist_ok=True)
plt.savefig("results/confusion_matrix.png", bbox_inches="tight")
plt.show()
# Visualize misclassified images, one row per true class.
def plot_images(cat_paths, dog_paths, n=5):
    """Plot up to n misclassified cats (top row) and dogs (bottom row).

    Args:
        cat_paths: Paths of images whose true label is cat (predicted dog).
        dog_paths: Paths of images whose true label is dog (predicted cat).
        n: Maximum number of images per row.
    """
    plt.figure(figsize=(15, 10))
    # Top row: true cats the model predicted as dogs.
    for i, img_path in enumerate(cat_paths[:n]):
        img = Image.open(img_path)
        plt.subplot(2, n, i + 1)
        plt.imshow(img)
        # BUG FIX(review): the original titled true-cat images "Dog" and
        # vice versa, contradicting how mis_cat/mis_dog are populated. If the
        # swapped titles were meant to show the *predicted* label, revert this.
        plt.title(f"Cat {i + 1}")
        plt.axis("off")
    # Bottom row: true dogs the model predicted as cats.
    for i, img_path in enumerate(dog_paths[:n]):
        img = Image.open(img_path)
        plt.subplot(2, n, i + 1 + n)
        plt.imshow(img)
        plt.title(f"Dog {i + 1}")
        plt.axis("off")
    # Create the output directory before saving.
    os.makedirs("results", exist_ok=True)
    plt.savefig("results/mistake.png")
    plt.show()

print("顯示分類錯誤的貓和狗影像:")
plot_images(mis_cat, mis_dog, n=5)