基於深度學習的回聲消除系統與Pytorch實現

凌逆戰發表於2021-05-21

文章作者:凌逆戰

文章程式碼(pytorch實現):https://github.com/LXP-Never/AEC_DeepModel

文章地址(轉載請指明出處):https://www.cnblogs.com/LXP-Never/p/14779360.html


寫這篇文章的目的

  1. 降低全國想要做基於深度學習的回聲消除同學們一個入門門檻。萬事開頭難呀,肯定有很多小白辛苦研究了一年,連基線系統都搭建不出來的,他們肯定心心念念有誰能幫幫他們,這不,我來了。
  2. 在基於深度學習的回聲消除這一塊,網上幾乎沒人開源,github上能找到的幾乎都是基於自適應濾波器的。我個人是很提倡開源精神的,能讓更多的人能夠參與進來,小到促進這個領域的進步,大到提升國家科學競爭力,哪怕只是一小步,都需要有人做出行動。
  3. 今天我開源,明天你開源。可能以後你們的開源專案也能幫助到我。

作者獨白

  • 寫這篇文章的目的在於想做基於深度學習的回聲消除小白們一份入門教學,所以別對這篇文章有什麼創新點或者效能上的較大期待,我只是隨便搭建了一個基線系統,來進行回聲消除程式碼的講解,帶領小白入門。
  • 別問我為什麼不除錯好了再分享出來,時間精力有限,我的研究方向也不是回聲消除,我只是感興趣,也沒人給我錢支援我研究,從一個基線模型到最終一個完善的模型,是需要巨大的時間成本的,每往下走一步需要的付出精力越多,這就是科研之路。
  • 本文分享出來的系統在哪個點可以改進,可以做創新發論文,我都會在文中說明,不用感謝我?
  • 本文引用了諸多我原先的文章,遇到不懂的大家可能還需要多翻看原來的文章,知識需要積累,沒有一蹴而就的捷徑。
  • 文中若有不對之處,還請各位看官多多包含,多提意見,我會積極修改的?。覺得寫得不錯的,建議點贊關注一下,這是對我最大的支援,是給我開源精神最大的鼓勵,我以後也還會努力分享好文章給大家的。

原理

傳統演算法

主要參考我的另外一篇文章:聲學回聲消除(Acoustic Echo Cancellation)原理與實現

  圖中$x(n)$為遠端語音,$y(n)$為遠端回聲$y(n)=x(n)*w(n)$,$s(n)$為近端語音,$d(n)$為近端麥克風語音訊號。

深度學習演算法

  回聲包含線性回聲和非線性回聲

  • 線性回聲:遠端語音直接 被近端麥克風接收的回聲。
  • 非線性回聲:遠端語音經過多徑傳播後 被近端麥克風接收的回聲

  線性回聲可以通過 時延估計、端點檢測和自適應濾波器技術較好的消除,非線性回聲經過多次反射後產生了混響,聲學特性複雜,很難消除。基於深度學習的回聲消除技術,目前有這幾個方向在做:

  • 神經網路
  • 自適應濾波器+神經網路

神經網路

  利用神經網路較強的非線性擬合能力,直接消除線性回聲和非線性回聲

  • 優點:過程簡單,一步到位
  • 缺點:可能需要更復雜或精煉的模型,才能達到更好的效果。更加考驗模型的能力

自適應濾波器+神經網路

  先利用簡單的傳統方法消除線性回聲,再利用神經網路消除非線性回聲

  • 優點:有針對性的進行回聲消除,能降低神經網路的負擔
  • 缺點:能一步到位的事情,就不要把事情複雜化

圖片來源於論文:Residual acoustic echo suppression based on efficient multi-task convolutional neural network,圖中$e(n)$為自適應濾波器輸出的的殘差訊號,$u(n)$為遠端參考訊號,然後利用短時間傅立葉變換(STFT)將$e(n)$和$u(n)$轉換到頻域,串聯作為輸入特徵。同樣輸出mask。估計的近端振幅為:

$$估計的近端振幅=mask*自適應濾波器輸出$$

訓練策略

  • 頻譜對映:輸入(近端麥克風語音訊譜,遠端語音訊譜),輸出(近端語音訊譜)
  • 波形對映:輸入(近端麥克風語音波形,遠端語音波形),輸出(近端語音波形)
  • 頻譜mask:輸入(近端麥克風語音訊譜,遠端語音訊譜),輸出 (mask),近端語音訊譜 = mask*近端麥克風語音訊譜
  • 時域mask:輸入(近端麥克風語音波形),輸出(近端語音mask, 遠端回聲mask),近端語音波形 = 近端語音mask*近端麥克風語音波形(這個點,我是受到語音分離的一篇文章啟發,覺得可行,所以也分享在這了,目前還沒有這方向的論文,科研工作者可以去嘗試)

  頻譜對映、波形對映、頻譜mask我在這篇文章中做了詳細的說明,時域mask在這篇文章中做了詳細的講解。

  回聲消除跟語音增強語音去混響或者語音分離很像,都是從混合語音或者汙染語音中提取乾淨的語音。因此我們如果想要在回聲消除領域找創新點的話,不妨去多看看我剛剛提的三個方向的論文。我主要參考的是語音增強和語音分離。

基線模型

  本文重點來了,我搭建的基線系統是使用神經網路直接消除回聲, 訓練策略為 頻譜mask。

資料準備

  做回聲消除任務主要有兩類資料,真實回聲資料以及合成回聲資料

  • 真實回聲資料:在真實環境中採集的回聲,目前只有微軟舉辦的 回聲消除挑戰賽中開源的資料集,我個人認為微軟資料集中真實資料集有點問題,詳情見部落格。
  • 合成回聲資料:通過RIR合成的回聲。可以使用任意的語音資料集,使用RIR-Generator生成房間衝擊響應(推薦使用MATLAB方法),再卷積遠端語音得到回聲。科研界主要使用的TIMIT資料集。AEC-Challenge 資料集也有合成資料集。

  我這裡就偷個懶,直接使用AEC-Challenge合成好了的資料集。檔案結構如下

└─Synthetic
    ├─TEST
    │  ├─echo_signal
    │  ├─farend_speech
    │  ├─nearend_mic_signal
    │  └─nearend_speech
    ├─TRAIN
    │  ├─echo_signal
    │  ├─farend_speech
    │  ├─nearend_mic_signal
    │  └─nearend_speech
    └─VAL
        ├─echo_signal
        ├─farend_speech
        ├─nearend_mic_signal
        └─nearend_speech

  如果你們想用TIMIT資料集的話(畢竟很多論文都用他),可以具體參考這篇論文的資料準備方法。我個人被這篇論文給繞暈了,資料準備看似不簡單,但用程式碼實現起來卻非常難。你們可以自己去試試。

  但不管用哪個資料集,我還是建議大家都把資料按照上面的檔案路徑結構放好,方便讀取。

  我搭建的基線系統實現的是頻譜mask的訓練策略,模型輸入為[遠端語音振幅,近端麥克風振幅],模型輸出IRM mask。IRM公式可以寫成以下幾種形式為:

$$\operatorname{IRM}=\sqrt{\frac{近端語音振幅^2}{近端語音振幅^2+遠端回聲振幅^2}}$$

$$\mathrm{IRM}=\sqrt{\frac{\text { 遠端語音振幅 }^{2}}{(\text { 近端語音振幅+遠端回聲振幅 })^{2}}}$$

$$\operatorname{IRM}=\sqrt{\frac{近端語音振幅^2}{近端麥克風語音振幅^2}}$$

  我使用的是Pytorch搭建的模型,Pytorch有一套自己的資料載入方式,我之前寫過一篇文章進行了總結:pytorch載入語音類自定義資料集 。如果你已經很熟悉了請繼續看,本文的回聲消除資料預處理程式碼如下:

# Author:凌逆戰
# -*- coding:utf-8 -*-
"""
作用:資料預處理
"""
import glob
import os
import torch.nn.functional as F
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class FileDateset(Dataset):
    def __init__(self, dataset_path="./Synthetic/TRAIN", fs=16000, win_length=320, mode="train"):
        self.fs = fs
        self.win_length = win_length
        self.mode = mode

        farend_speech_path = os.path.join(dataset_path, "farend_speech")        # "./Synthetic/TRAIN/farend_speech"
        nearend_mic_signal_path = os.path.join(dataset_path, "nearend_mic_signal")  # "./Synthetic/TRAIN/nearend_mic_signal"
        nearend_speech_path = os.path.join(dataset_path, "nearend_speech")      # "./Synthetic/TRAIN/nearend_speech"

        self.farend_speech_list = sorted(glob.glob(farend_speech_path+"/*.wav"))    # 遠端語音路徑,list
        self.nearend_mic_signal_list = sorted(glob.glob(nearend_mic_signal_path+"/*.wav"))  # 近端麥克風語音路徑,list
        self.nearend_speech_list = sorted(glob.glob(nearend_speech_path+"/*.wav"))  # 近端語音路徑,list

    def spectrogram(self, wav_path):
        """
        :param wav_path: 音訊路徑
        :return: 返回該音訊的振幅和相位
        """
        wav, _ = torchaudio.load(wav_path)
        wav = wav.squeeze()

        if len(wav) < 160000:
            wav = F.pad(wav, (0,160000-len(wav)), mode="constant",value=0)

        S = torch.stft(wav, n_fft=self.win_length, hop_length=self.win_length//2,
                       win_length=self.win_length, window=torch.hann_window(window_length=self.win_length),
                       center=False, return_complex=True)   # (*, F,T)
        magnitude = torch.abs(S)        # 振幅
        phase = torch.exp(1j * torch.angle(S))  # 相位
        return magnitude, phase


    def __getitem__(self, item):
        """__getitem__是類的專有方法,使類可以像list一樣按照索引來獲取元素
        :param item: 索引
        :return:  按 索引取出來的 元素
        """
        # 遠端語音 振幅,相位 (F, T),F為頻點數,T為幀數
        farend_speech_magnitude, farend_speech_phase = self.spectrogram(self.farend_speech_list[item])  # torch.Size([161, 999])
        # 近端麥克風 振幅,相位
        nearend_mic_magnitude, nearend_mic_phase = self.spectrogram(self.nearend_mic_signal_list[item])
        # 近端語音 振幅,相位
        nearend_speech_magnitude, nearend_speech_phase = self.spectrogram(self.nearend_speech_list[item])

        X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0)  # 在頻點維度上進行拼接(161*2, 999),模型輸入

        _eps = torch.finfo(torch.float).eps  # 防止分母出現0
        mask_IRM = torch.sqrt(nearend_speech_magnitude ** 2/(nearend_mic_magnitude ** 2+_eps))  # IRM,模型輸出


        return X, mask_IRM, nearend_mic_magnitude, nearend_speech_magnitude

    def __len__(self):
        """__len__是類的專有方法,獲取整個資料的長度"""
        return len(self.farend_speech_list)


if __name__ == "__main__":
    train_set = FileDateset()
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)

    for x, y, nearend_mic_magnitude,nearend_speech_magnitude in train_loader:
        print(x.shape)  # torch.Size([64, 322, 999])
        print(y.shape)  # torch.Size([64, 161, 999])
        print(nearend_mic_magnitude.shape)

  我幾乎每行程式碼都給了註釋了,各位點個贊不過分吧?。還有不懂地方的各位可以在評論區指出。

 如果想要創新發文章的話,資料處理這裡也可以做改動:

  1. 更改mask方法,或者提出更好用的mask,我這篇文章總結了不少:基於深度學習的單通道語音增強,大家可以輪著試一試,反正我給出了程式碼。
  2. 我這裡使用的是振幅,你們可以嘗試提取一些語音其他的特徵,類似 梅爾頻譜特徵,對數功率譜等等。
  3. 在強調一遍呀,現在沒有基於時域mask的回聲消除論文,大家快去攻略佔地呀,主要參考語音分離這個領域。

模型搭建

  我這裡使用的是頻譜mask的訓練策略,模型輸入為 遠端語音振幅 和 近端麥克風振幅 的串聯,模型輸出IRM。由上可知,輸入大小為 [64, 322, 999],輸出大小為 [64, 161, 999]。那麼我們只需要隨便搭建一個模型符合這個輸入輸出就行了。

# Author:凌逆戰
# -*- coding:utf-8 -*-
"""
作用:隨便搭建的模型,只要符合輸入大小是[64, 322, 999],輸出大小是[64, 161, 999],就能跑通
"""
import torch.nn as nn
import torch


class Base_model(nn.Module):
    def __init__(self):
        super(Base_model, self).__init__()
        # [batch, channel, input_size] (B, F, T)
        # [64, 322, 999] ---> [64, 161, 999]
        self.model = nn.Sequential(
            nn.Conv1d(in_channels=322, out_channels=322, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(in_channels=322, out_channels=322, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(in_channels=322, out_channels=161, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(in_channels=161, out_channels=161, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: 麥克風訊號和遠端訊號的特徵串聯在一起作為輸入特徵 (322, 206)
        :return: IRM_mask * input = 近端語音對數譜
        """
        Estimated_IRM = self.model(x)

        return Estimated_IRM


if __name__ == "__main__":
    model = Base_model().cuda()
    x = torch.randn(8, 322, 999).to("cuda")  # 輸入 [8, 322, 999]
    y = model(x)  # 輸出 [8, 161, 999]
    print(y.shape)

  模型是一個可以創新的點,大家可以改成目前比較流行的模型來發文章。我這裡就隨便搭建了。

如果想要創新發文章的話,模型搭建這裡也可以做改動:

  • 使用時序模型來更多的考量語音幀間相關性,如LSTM、TCN,注意力機制等等,反正現在的模型五花八門,看著誰好用借鑑過來用,然後魔改一下,有良好的效果的話,就能寫論文了。

訓練模組

  訓練模組其實是最沒啥創新的,所有寫的正兒八經的程式碼,訓練模型幾乎都一樣,但是這一塊卻是卡住所有新人的較大關卡。不懂的人覺得難的要死,懂的人覺得簡單地一批。

  訓練模組的具體流程有以下幾部分:

  1. 命令列解析
  2. 資料集載入
  3. 檢測模型儲存地址是否存在,如果不存在則建立
  4. 例項化模型
  5. 例項化優化器(一般使用Adam優化器)
  6. 準備事件檔案,方便Tensorboard視覺化
  7. 如果接著上一次檢查點訓練,則載入模型
  8. 迴圈epochs,開始訓練(前向傳播,反向傳播)
  9. 驗證模型(根據驗證集的損失和度量,對模型的超引數進行調整)
import os
import torch
from torch.utils.data import DataLoader
from torch import nn
import argparse
from tensorboardX import SummaryWriter

from data_preparation.data_preparation import FileDateset
from model.Baseline import Base_model
from model.ops import pytorch_LSD


def parse_args():
    parser = argparse.ArgumentParser()
    # 重頭開始訓練 defaule=None, 繼續訓練defaule設定為'/**.pth'
    parser.add_argument("--model_name", type=str, default=None, help="是否載入模型繼續訓練 '/50.pth' None")
    parser.add_argument("--batch-size", type=int, default=16, help="")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument('--lr', type=float, default=3e-4, help='學習率 (default: 0.01)')
    parser.add_argument('--train_data', default="./data_preparation/Synthetic/TRAIN", help='資料集的path')
    parser.add_argument('--val_data', default="./data_preparation/Synthetic/VAL", help='驗證樣本的path')
    parser.add_argument('--checkpoints_dir', default="./checkpoints/AEC_baseline", help='模型檢查點檔案的路徑(以繼續培訓)')
    parser.add_argument('--event_dir', default="./event_file/AEC_baseline", help='tensorboard事件檔案的地址')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    print("GPU是否可用:", torch.cuda.is_available())  # True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 例項化 Dataset
    train_set = FileDateset(dataset_path=args.train_data)  # 例項化訓練資料集
    val_set = FileDateset(dataset_path=args.val_data)  # 例項化驗證資料集

    # 資料載入器
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True)

    # ###########    儲存檢查點的地址(如果檢查點不存在,則建立)   ############
    if not os.path.exists(args.checkpoints_dir):
        os.makedirs(args.checkpoints_dir)

    ################################
    #          例項化模型          #
    ################################
    model = Base_model().to(device)  # 例項化模型
    # summary(model, input_size=(322, 999))  # 模型輸出 torch.Size([64, 322, 999])
    # ###########    損失函式   ############
    criterion = nn.MSELoss(reduce=True, size_average=True, reduction='mean')

    ###############################
    # 建立優化器 Create optimizers #
    ###############################
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, )
    # lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20], gamma=0.1)

    # ###########    TensorBoard視覺化 summary  ############
    writer = SummaryWriter(args.event_dir)  # 建立事件檔案

    # ###########    載入模型檢查點   ############
    start_epoch = 0
    if args.model_name:
        print("載入模型:", args.checkpoints_dir + args.model_name)
        checkpoint = torch.load(args.checkpoints_dir + args.model_name)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint['epoch']
        # lr_schedule.load_state_dict(checkpoint['lr_schedule'])  # 載入lr_scheduler

    for epoch in range(start_epoch, args.epochs):
        model.train()  # 訓練模型
        for batch_idx, (train_X, train_mask, train_nearend_mic_magnitude, train_nearend_magnitude) in enumerate(
                train_loader):
            train_X = train_X.to(device)  # 遠端語音cat麥克風語音 [batch_size, 322, 999] (, F, T)
            train_mask = train_mask.to(device)  # IRM [batch_size 161, 999]
            train_nearend_mic_magnitude = train_nearend_mic_magnitude.to(device)
            train_nearend_magnitude = train_nearend_magnitude.to(device)

            # 前向傳播
            pred_mask = model(train_X)  # [batch_size, 322, 999]--> [batch_size, 161, 999]
            train_loss = criterion(pred_mask, train_mask)

            # 近端語音訊號頻譜 = mask * 麥克風訊號頻譜 [batch_size, 161, 999]
            pred_near_spectrum = pred_mask * train_nearend_mic_magnitude
            train_lsd = pytorch_LSD(train_nearend_magnitude, pred_near_spectrum)

            # 反向傳播
            optimizer.zero_grad()  # 將梯度清零
            train_loss.backward()  # 反向傳播
            optimizer.step()  # 更新引數

            # ###########    視覺化列印   ############
        print('Train Epoch: {} Loss: {:.6f} LSD: {:.6f}'.format(epoch + 1, train_loss.item(), train_lsd.item()))

        # ###########    TensorBoard視覺化 summary  ############
        # lr_schedule.step()  # 學習率衰減
        # writer.add_scalar(tag="lr", scalar_value=model.state_dict()['param_groups'][0]['lr'], global_step=epoch + 1)
        writer.add_scalar(tag="train_loss", scalar_value=train_loss.item(), global_step=epoch + 1)
        writer.add_scalar(tag="train_lsd", scalar_value=train_lsd.item(), global_step=epoch + 1)
        writer.flush()

        # 神經網路在驗證資料集上的表現
        model.eval()  # 測試模型
        # 測試的時候不需要梯度
        with torch.no_grad():
            for val_batch_idx, (val_X, val_mask, val_nearend_mic_magnitude, val_nearend_magnitude) in enumerate(
                    val_loader):
                val_X = val_X.to(device)  # 遠端語音cat麥克風語音 [batch_size, 322, 999] (, F, T)
                val_mask = val_mask.to(device)  # IRM [batch_size 161, 999]
                val_nearend_mic_magnitude = val_nearend_mic_magnitude.to(device)
                val_nearend_magnitude = val_nearend_magnitude.to(device)

                # 前向傳播
                val_pred_mask = model(val_X)
                val_loss = criterion(val_pred_mask, val_mask)

                # 近端語音訊號頻譜 = mask * 麥克風訊號頻譜 [batch_size, 161, 999]
                val_pred_near_spectrum = val_pred_mask * val_nearend_mic_magnitude
                val_lsd = pytorch_LSD(val_nearend_magnitude, val_pred_near_spectrum)

            # ###########    視覺化列印   ############
            print('  val Epoch: {} \tLoss: {:.6f}\tlsd: {:.6f}'.format(epoch + 1, val_loss.item(), val_lsd.item()))
            ######################
            # 更新tensorboard    #
            ######################
            writer.add_scalar(tag="val_loss", scalar_value=val_loss.item(), global_step=epoch + 1)
            writer.add_scalar(tag="val_lsd", scalar_value=val_lsd.item(), global_step=epoch + 1)
            writer.flush()

        # # ###########    儲存模型   ############
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch + 1,
                # 'lr_schedule': lr_schedule.state_dict()
            }
            torch.save(checkpoint, '%s/%d.pth' % (args.checkpoints_dir, epoch + 1))


if __name__ == "__main__":
    main()

  咳咳咳,這個註釋量,你們愛了沒有,很詳細了,還看不懂說明你的基礎太差了,別看這篇文章了,打基礎去吧,基礎很重要。

如果想要創新發文章的話,損失這裡也可以做改動:

  • 使用一個更加全面的損失函式引導模型訓練,我言盡於此,剩下的靠大家自己領悟了。

推理階段

將模型預測的近端語音振幅和近端麥克風語音相位相乘得到近端語音的複數表示,經過短時傅立葉逆變換得到近端語音波形。這裡需要補一點基礎知識:

複數的幾種表示形式:

  • 實部、虛部(直角座標系):$a+bj$      ($a$是實部,$b$是虛部)
  • 幅值、相位(指數系):$re^{j\theta }$  ($r$是幅值,$\theta$是相角,$e^{j\theta }$是相位)
  • 兩種形式互換:$e^{j\theta }=cos\theta+isin\theta$,$re^{j\theta }=r(cos\theta+jsin\theta)=rcos\theta+jrsin\theta$

因此,實部$a=rcos\theta$,虛部$b=rsin\theta$,

幅值$r=\sqrt{a^2+b^2}$,相角$\theta=tan^{-1}(\frac{b}{a})$

還有一種是極座標表示法:$r\angle \theta $

結合上述補充知識,以及複數矩陣D(F, T),我們可以得到一下頻譜資訊

  • 複數的實部:  real = np.real(D(F, T)) 
  • 複數的虛部: imag= np.imag(D(F, T)) 
  • 幅值:  magnitude = np.abs(D(F, T)) 或  magnitude = np.sqrt(real**2+imag**2) 
  • 相角: angle = np.angle(D(F, T)) 
  • 相位: phase = np.exp(1j * np.angle(D(F, T))) 

librosa提供了專門將複數矩陣D(F, T)分離為幅值$S$和相位$P$的函式,$D=S*P$

librosa.magphase(D, power=1)

引數

  • D:經過stft得到的複數矩陣
  • power:幅度譜的指數,例如,1代表能量,2代表功率,等等。

返回

  • D_mag:幅值$D$,
  • D_phase:相位$P$, phase = exp(1.j * phi) , phi 是複數矩陣的相位角 np.angle(D) 

當然我們也可以通過上面的公式自己求

# Author:凌逆戰
# -*- coding:utf-8 -*-
"""
作用:通過模型生成近端語音
"""
import librosa
import matplotlib
import torchaudio
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt
from model.Baseline import Base_model
from matplotlib.ticker import FuncFormatter
import numpy as np

plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示符號


def spectrogram(wav_path, win_length=320):
    wav, _ = torchaudio.load(wav_path)
    wav = wav.squeeze()

    if len(wav) < 160000:
        wav = F.pad(wav, (0, 160000 - len(wav)), mode="constant", value=0)
    # if len(wav) != 160000:
    #     print(wav_path)
    #     print(len(wav))

    S = torch.stft(wav, n_fft=win_length, hop_length=win_length // 2,
                   win_length=win_length, window=torch.hann_window(window_length=win_length),
                   center=False, return_complex=True)
    magnitude = torch.abs(S)
    phase = torch.exp(1j * torch.angle(S))
    return magnitude, phase


fs = 16000
farend_speech = "./farend_speech/farend_speech_fileid_9992.wav"
nearend_mic_signal = "./nearend_mic_signal/nearend_mic_fileid_9992.wav"
nearend_speech = "./nearend_speech/nearend_speech_fileid_9992.wav"
echo_signal = "./echo_signal/echo_fileid_9992.wav"

print("GPU是否可用:", torch.cuda.is_available())  # True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

farend_speech_magnitude, farend_speech_phase = spectrogram(farend_speech)  # 遠端語音  振幅,相位
nearend_mic_magnitude, nearend_mic_phase = spectrogram(nearend_mic_signal)  # 近端麥克風語音 振幅,相位
nearend_speech_magnitude, nearend_speech_phase = spectrogram(nearend_speech)  # 近端語音振 幅,相位

farend_speech_magnitude = farend_speech_magnitude.to(device)
nearend_mic_phase = nearend_mic_phase.to(device)
nearend_mic_magnitude = nearend_mic_magnitude.to(device)

nearend_speech_magnitude = nearend_speech_magnitude.to(device)
nearend_speech_phase = nearend_speech_phase.to(device)

model = Base_model().to(device)  # 例項化模型
checkpoint = torch.load("../checkpoints/AEC_baseline/10.pth")
model.load_state_dict(checkpoint["model"])

X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0)
X = X.unsqueeze(0)
per_mask = model(X)  # [1, 322, 999]-->[1, 161, 999]

per_nearend_magnitude = per_mask * nearend_mic_magnitude  # 預測的近端語音 振幅

complex_stft = per_nearend_magnitude * nearend_mic_phase  # 振幅*相位=語音複數表示
print("complex_stft", complex_stft.shape)  # [1, 161, 999]

per_nearend = torch.istft(complex_stft, n_fft=320, hop_length=160, win_length=320,
                          window=torch.hann_window(window_length=320).to("cuda"))

torchaudio.save("./predict/nearend_speech_fileid_9992.wav", src=per_nearend.cpu().detach(), sample_rate=fs)
# print("近端語音", per_nearend.shape)    # [1, 159680]

y, _ = librosa.load(nearend_speech, sr=fs)
time_y = np.arange(0, len(y)) * (1.0 / fs)
recover_wav, _ = librosa.load("./predict/nearend_speech_fileid_9992.wav", sr=16000)
time_recover = np.arange(0, len(recover_wav)) * (1.0 / fs)

plt.figure(figsize=(8,6))
ax_1 = plt.subplot(3, 1, 1)
plt.title("近端語音和預測近端波形圖", fontsize=14)
plt.plot(time_y, y, label="近端語音")
plt.plot(time_recover, recover_wav, label="深度學習生成的近端語音波形")
plt.xlabel('時間/s', fontsize=14)
plt.ylabel('幅值', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998)
plt.subplots_adjust(hspace=0.809, wspace=0.365)  # 調整子圖間距
plt.legend()

norm = matplotlib.colors.Normalize(vmin=-200, vmax=-40)
ax_2 = plt.subplot(3, 1, 2)
plt.title("近端語音訊譜", fontsize=14)
plt.specgram(y, Fs=fs, scale_by_freq=True, sides='default', cmap="jet", norm=norm)
plt.xlabel('時間/s', fontsize=14)
plt.ylabel('頻率/kHz', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998)
plt.subplots_adjust(hspace=0.809, wspace=0.365)  # 調整子圖間距

ax_3 = plt.subplot(3, 1, 3)
plt.title("深度學習生成的近端語音訊譜", fontsize=14)
plt.specgram(recover_wav, Fs=fs, scale_by_freq=True, sides='default', cmap="jet", norm=norm)
plt.xlabel('時間/s', fontsize=14)
plt.ylabel('頻率/kHz', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998)
plt.subplots_adjust(hspace=0.809, wspace=0.365)  # 調整子圖間距

def formatnum(x, pos):
    return '$%d$' % (x / 1000)


formatter = FuncFormatter(formatnum)
ax_2.yaxis.set_major_formatter(formatter)
ax_3.yaxis.set_major_formatter(formatter)


plt.show()

  為了方便視覺化對比,我順便把波形圖可語譜圖畫了出來

如果這篇文章對你有幫助,點個贊是對我最大的鼓勵。

關注我,我將分享更有價值的文章!

相關文章