學習筆記14:模型儲存

有何m不可 發表於 2024-06-04

轉自:https://www.cnblogs.com/miraclepbc/p/14361926.html

儲存訓練過程中在測試集上準確率最高的那一輪的模型參數（權重）

import copy

# Keep a deep copy of the weights from the epoch with the highest test
# accuracy. state_dict() holds references to the live parameter tensors, so
# without deepcopy the "best" snapshot would keep changing as training goes on.
best_acc = 0
best_model_wts = copy.deepcopy(model.state_dict())

# Per-epoch history of the four metrics returned by fit().
train_loss, train_acc = [], []
test_loss, test_acc = [], []

for epoch in range(extend_epoch):
    (epoch_loss, epoch_acc,
     epoch_test_loss, epoch_test_acc) = fit(epoch, model, train_dl, test_dl)

    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

    # Strictly greater: on a tie the earlier epoch's weights are kept.
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model_wts = copy.deepcopy(model.state_dict())

# Load the best-scoring weights back into the model.
model.load_state_dict(best_model_wts)

儲存模型

# Persist only the state dict (learned parameters and buffers), not the whole
# model object; the architecture must be rebuilt in code before loading it.
# NOTE(review): hard-coded Windows path — adjust for your environment.
PATH = 'E:/my_model.pth'
torch.save(obj=model.state_dict(), f=PATH)

重新載入模型

# Rebuild the architecture that was trained: ResNet-101 with its final
# fully-connected layer replaced by a 4-output linear layer. The head must
# match the one used when the state dict was saved, otherwise the strict
# load_state_dict below raises on the shape/key mismatch.
# No pretrained weights are requested: load_state_dict overwrites every
# parameter and buffer, so downloading ImageNet weights would be wasted work.
# (Calling resnet101() with no arguments means "no pretrained weights" in both
# the legacy `pretrained=` and the current `weights=` torchvision APIs.)
new_model = models.resnet101()
in_f = new_model.fc.in_features
new_model.fc = nn.Linear(in_f, 4)
# map_location='cpu' lets a checkpoint saved from a GPU model be loaded on a
# CPU-only machine; load_state_dict copies the tensors into new_model's own
# (CPU) parameters, and the model can be moved to a device afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
new_model.load_state_dict(torch.load(PATH, map_location='cpu'))

測試是否載入成功

# Evaluate the reloaded model on the test set to check that the saved weights
# were restored. Presumably the printed accuracy should match the best test
# accuracy reached during training, assuming the same test_dl — confirm.
new_model.to(device)  # nn.Module.to moves parameters/buffers in place
test_correct = 0
test_total = 0
new_model.eval()  # inference behaviour for BatchNorm/Dropout layers
with torch.no_grad():  # no autograd graph needed for evaluation
    for x, y in test_dl:
        x, y = x.to(device), y.to(device)
        # Only predicted class indices are needed for accuracy; no loss is
        # computed, so this check does not depend on a loss function.
        y_pred = torch.argmax(new_model(x), dim=1)
        test_correct += (y_pred == y).sum().item()
        test_total += y.size(0)
epoch_test_acc = test_correct / test_total
print(epoch_test_acc)

相關文章