在深度學習中訓練模型都是小批次小批次地最佳化訓練的,即每次都會從原資料集中取出一小批次進行訓練,完成一次權重更新後,再從原資料集中取下一個小批次資料,然後再訓練再更新。
另外,原資料集往往很大,不可能一次性的全部載入模型,只能一小批一小批地載入。訓練完了就扔了,再載入下一小批。
準備資料
import pandas as pd
import numpy as np
data = np.random.rand(128, 3) # 128x3
data = pd.DataFrame(data, columns=['feature_1', 'feature_2', 'label'])
Dataset和Dataloader使用模板
class MyDataset(Dataset):
def __init__(self, data):
super().__init__()
'''
有兩種寫法:
1、將全部資料都載入進記憶體裡,適用於少量資料;
2、當資料量或者標籤量很大時,比如圖片,就把這些資料或者標籤放到檔案或資料庫裡去,只需在此方法中初始化定義這些檔案索引的列表即可。
'''
# 以下2個方法都是魔法方法
def __getitem__(self, index): # 實現索引資料集中的某一個資料
# 表示將來例項化這個物件後,它能支援下標(索引)操作,也就是能透過索引把裡面的資料拿出來。
# ...
# 資料解析
# ...
return 某一個資料
def __len__(self): # 返回資料集的長度
return len(self.data)
my_dataset = MyDataset(data)
train_loader = DataLoader(dataset=my_dataset, # 傳遞資料集
batch_size=32, #一個小批次容量是多少
shuffle=True, # 資料集順序是否要打亂,一般是要的。測試資料集一般沒必要
num_workers=0) # 需要幾個程式來一次性讀取這個小批次資料
建立一個完整的Dataset,使用上面自己生成的資料集。
from torch.utils.data import Dataset # Dataset是個抽象類,只能用於繼承
from torch.utils.data import DataLoader # DataLoader需例項化,用於載入資料
class MyDataset(Dataset):
def __init__(self, data):
super().__init__()
self.data = data
self.features = self.data[['feature_1', 'feature_2']].values # [[]]是返回結果是一個df,[]可能返回一個Serial;
self.label = self.data[['label']].values
# .values 取DataFrame或Series的值,返回值是一個numpy ndarray的副本;而不是對原始資料的引用。這意味著你可以對返回的陣列執行任何操作,而不會影響原始的pandas物件。
def __getitem__(self, index): # 引數index必寫
return self.features[index], self.label[index]
def __len__(self):
return len(self.data)
# 例項化
my_dataset = MyDataset(data)
train_loader = DataLoader(dataset=my_dataset, # 要傳遞的資料集
batch_size=32, #一個小批次資料的大小是多少
shuffle=True, # 資料集順序是否要打亂,一般是要的。測試資料集一般沒必要
num_workers=0) # 需要幾個程式來一次性讀取這個小批次資料
dataloader 本質上是一個可迭代物件!
關於可迭代物件以及迭代器的說明:https://www.cnblogs.com/kphang/p/17026932.html
所以就有這說明裡提到的幾種方式進行訪問。
# 方式1:每一輪就是Dataloader中設定的batch_size的大小。
for batch_features, batch_labels in train_loader:
pass
# 方式2:
for i, batch_data in enumerate(train_loader):
print(i) # 0 1 2 3 # 共4輪,因為batch_size設定的是32,資料項共128個;
在 PyTorch 中,dataloader 會遍歷完所有的資料。每次迭代會返回一個批次的資料。你可以使用 for 迴圈來迭代 dataloader,每次迴圈會返回一個批次的資料。當 dataloader 迭代完所有資料時,for 迴圈將會結束。
也可以使用 iter() 函式來獲取迭代器,然後使用 next() 函式來迭代資料。例如:
# 方式3:
iterator = iter(dataloader)
while True:
try:
data = next(iterator)
# Process the data
except StopIteration:
break