pytorch載入語音類自定義資料集

凌逆戰發表於2020-11-09

  pytorch對一下常用的公開資料集有很方便的API介面,但是當我們需要使用自己的資料集訓練神經網路時,就需要自定義資料集,在pytorch中,提供了一些類,方便我們定義自己的資料集合

  • torch.utils.data.Dataset:所有繼承他的子類都應該重寫  __len()__  , __getitem()__ 這兩個方法
    •  __len()__ :返回資料集中資料的數量
    •   __getitem()__ :返回支援下標索引方式獲取的一個資料
  • torch.utils.data.DataLoader:對資料集進行包裝,可以設定batch_size、是否shuffle....

第一步

  自定義的 Dataset 都需要繼承 torch.utils.data.Dataset 類,並且重寫它的兩個成員方法:

  • __len()__:讀取資料,返回資料和標籤
  • __getitem()__:返回資料集的長度
from torch.utils.data import Dataset


class AudioDataset(Dataset):
    def __init__(self, ...):
        """類的初始化"""
        pass

    def __getitem__(self, item):
        """每次怎麼讀資料,返回資料和標籤"""
        return data, label

    def __len__(self):
        """返回整個資料集的長度"""
        return total

注意事項:Dataset只負責資料的抽象,一次呼叫getiitem只返回一個樣本

案例:

  檔案目錄結構

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:讀取p225資料夾中的音訊資料

 1 class AudioDataset(Dataset):
 2     def __init__(self, data_folder, sr=16000, dimension=8192):
 3         self.data_folder = data_folder
 4         self.sr = sr
 5         self.dim = dimension
 6 
 7         # 獲取音訊名列表
 8         self.wav_list = []
 9         for root, dirnames, filenames in os.walk(data_folder):
10             for filename in fnmatch.filter(filenames, "*.wav"):  # 實現列表特殊字元的過濾或篩選,返回符合匹配“.wav”字元列表
11                 self.wav_list.append(os.path.join(root, filename))
12 
13     def __getitem__(self, item):
14         # 讀取一個音訊檔案,返回每個音訊資料
15         filename = self.wav_list[item]
16         wb_wav, _ = librosa.load(filename, sr=self.sr)
17 
18         # 取 幀
19         if len(wb_wav) >= self.dim:
20             max_audio_start = len(wb_wav) - self.dim
21             audio_start = np.random.randint(0, max_audio_start)
22             wb_wav = wb_wav[audio_start: audio_start + self.dim]
23         else:
24             wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
25 
26         return wb_wav, filename
27 
28     def __len__(self):
29         # 音訊檔案的總數
30         return len(self.wav_list)

注意事項:19-24行:每個音訊的長度不一樣,如果直接讀取資料返回出來的話,會造成維度不匹配而報錯,因此只能每次取一個音訊檔案讀取一幀,這樣顯然並沒有用到所有的語音資料,

第二步

  例項化 Dataset 物件

Dataset= AudioDataset("./p225", sr=16000)

如果要通過batch讀取資料的可直接跳到第三步,如果你想一個一個讀取資料的可以看我接下來的操作

# 例項化AudioDataset物件
train_set = AudioDataset("./p225", sr=16000)

for i, data in enumerate(train_set):
    wb_wav, filname = data
    print(i, wb_wav.shape, filname)

    if i == 3:
        break
    # 0 (8192,) ./p225\p225_001.wav
    # 1 (8192,) ./p225\p225_002.wav
    # 2 (8192,) ./p225\p225_003.wav
    # 3 (8192,) ./p225\p225_004.wav

第三步

  如果想要通過batch讀取資料,需要使用DataLoader進行包裝

為何要使用DataLoader?

  1. 深度學習的輸入是mini_batch形式
  2. 樣本載入時候可能需要隨機打亂順序,shuffle操作
  3. 樣本載入需要採用多執行緒

  pytorch提供的 DataLoader 封裝了上述的功能,這樣使用起來更方便。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

引數

  • dataset:載入的資料集(Dataset物件)
  • batch_size每個批次要載入多少個樣本(預設值:1)
  • shuffle:每個epoch是否將資料打亂
  • sampler定義從資料集中抽取樣本的策略如果指定,則不能指定洗牌。
  • batch_sampler類似於sampler,但每次返回一批索引。與batch_size、shuffle、sampler和drop_last相互排斥。
  • num_workers:使用多程式載入的程式數,0代表不使用多執行緒
  • collate_fn:如何將多個樣本資料拼接成一個batch,一般使用預設拼接方式
  • pin_memory:是否將資料儲存在pin memory區,pin memory中的資料轉到GPU會快一些
  • drop_last:dataset中的資料個數可能不是batch_size的整數倍,drop_last為True會將多出來不足一個batch的資料丟棄

返回:資料載入器

案例:

# 例項化AudioDataset物件
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

for (i, data) in enumerate(train_loader):
    wav_data, wav_name = data
    print(wav_data.shape)   # torch.Size([8, 8192])
    print(i, wav_name)
    # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
    # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

我們來吃幾個栗子消化一下:

栗子1

  這個例子就是本文一直舉例的,栗子1只是合併了一下而已

  檔案目錄結構

  • p225
    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:讀取p225資料夾中的音訊資料

 1 import fnmatch
 2 import os
 3 import librosa
 4 import numpy as np
 5 from torch.utils.data import Dataset
 6 from torch.utils.data import DataLoader
 7 
 8 
 9 class Aduio_DataLoader(Dataset):
10     def __init__(self, data_folder, sr=16000, dimension=8192):
11         self.data_folder = data_folder
12         self.sr = sr
13         self.dim = dimension
14 
15         # 獲取音訊名列表
16         self.wav_list = []
17         for root, dirnames, filenames in os.walk(data_folder):
18             for filename in fnmatch.filter(filenames, "*.wav"):  # 實現列表特殊字元的過濾或篩選,返回符合匹配“.wav”字元列表
19                 self.wav_list.append(os.path.join(root, filename))
20 
21     def __getitem__(self, item):
22         # 讀取一個音訊檔案,返回每個音訊資料
23         filename = self.wav_list[item]
24         print(filename)
25         wb_wav, _ = librosa.load(filename, sr=self.sr)
26 
27         # 取 幀
28         if len(wb_wav) >= self.dim:
29             max_audio_start = len(wb_wav) - self.dim
30             audio_start = np.random.randint(0, max_audio_start)
31             wb_wav = wb_wav[audio_start: audio_start + self.dim]
32         else:
33             wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
34 
35         return wb_wav, filename
36 
37     def __len__(self):
38         # 音訊檔案的總數
39         return len(self.wav_list)
40 
41 
42 train_set = Aduio_DataLoader("./p225", sr=16000)
43 train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
44 
45 
46 for (i, data) in enumerate(train_loader):
47     wav_data, wav_name = data
48     print(wav_data.shape)   # torch.Size([8, 8192])
49     print(i, wav_name)
50     # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
51     # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

注意事項

  1. 27-33行:每個音訊的長度不一樣,如果直接讀取資料返回出來的話,會造成維度不匹配而報錯,因此只能每次取一個音訊檔案讀取一幀,這樣顯然並沒有用到所有的語音資料,
  2. 48行:我們在__getitem__中並沒有將numpy陣列轉換為tensor格式,可是第48行顯示資料是tensor格式的。這裡需要引起注意

栗子2

  相比於案例1,案例二才是重點,因為我們不可能每次只從一音訊檔案中讀取一幀,然後讀取另一個音訊檔案,通常情況下,一段音訊有很多幀,我們需要的是按順序的讀取一個batch_size的音訊幀,先讀取第一個音訊檔案,如果滿足一個batch,則不用讀取第二個batch,如果不足一個batch則讀取第二個音訊檔案,來補充。

  我給出一個建議,先按順序讀取每個音訊檔案,以窗長8192、幀移4096對語音進行分幀,然後拼接。得到(幀數,幀長,1)(frame_num, frame_len, 1)的陣列儲存到h5中。然後用上面講到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 讀取資料。

具體實現程式碼:

  第一步:建立一個H5_generation指令碼用來將資料轉換為h5格式檔案:

  第二步:通過Dataset從h5格式檔案中讀取資料

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py

def load_h5(h5_path):
    # load training data
    with h5py.File(h5_path, 'r') as hf:
        print('List of arrays in input file:', hf.keys())
        X = np.array(hf.get('data'), dtype=np.float32)
        Y = np.array(hf.get('label'), dtype=np.float32)
    return X, Y


class AudioDataset(Dataset):
    """資料載入器"""
    def __init__(self, data_folder):
        self.data_folder = data_folder
        self.X, self.Y = load_h5(data_folder)   # (3392, 8192, 1)

    def __getitem__(self, item):
        # 返回一個音訊資料
        X = self.X[item]
        Y = self.Y[item]

        return X, Y

    def __len__(self):
        return len(self.X)


train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)


for (i, wav_data) in enumerate(train_loader):
    X, Y = wav_data
    print(i, X.shape)
    # 0 torch.Size([64, 8192, 1])
    # 1 torch.Size([64, 8192, 1])
    # ...

我嘗試在__init__中生成h5檔案,但是會導致記憶體爆炸,就很奇怪,因此我只好分開了,

參考

pytorch學習(四)—自定義資料集(講的比較詳細)

 

相關文章