Conditional AutoEncoder的PyTorch完全實現

倦鸟已归时發表於2024-04-14

一個完整的深度學習流程,必須包含的部分有:引數配置、Dataset和DataLoader構建、模型與optimizer與Loss函式建立、訓練、驗證、儲存模型,以及讀取模型、測試集驗證等,對於生成模型來說,還應該有重構測試、生成測試。

AutoEncoder僅能夠重構見過的資料,VAE則可以透過取樣生成新資料;對於MNIST資料集來說,兩者都可以透過全連線神經網路訓練。如果我們需要用CNN來實現,也很容易。

Conditional VAE則有些特殊,它要把資料標籤轉換成One-Hot格式再拼接到資料上。MNIST資料集尚可,資料展平後也就784維;但對於一般的影像資料來說,這樣做就不可行了。

解決這個問題只需要在CNN後面接MLP(多層感知機,也就是全連線網路)就行了。對於CNN應有準確的認識:實際上,CNN從通道上看和全連線神經網路一樣,都是全連線的。如果把某一層的卷積核個數(通道數)視作全連線網路某一層的節點數,那麼兩者在結構上是一樣的;不同的是卷積核還要在影像上做複雜運算,每一通道的資料是2維的,而全連線網路每一個“通道”是1個標量。因此通常CNN的資料是四個維度(bs, c, h, w),而全連線網路的資料一般是兩個維度(bs, dim)。

CNN的特點是,隨著層數加深,影像尺寸越來越小,通道數越來越高,此時就和全連線神經網路非常像了,對每一個通道做一個全域性平均池化(GAP)或者最大池化(MaxPooling),那不就變成全連線網路的輸入了?

實現如下(Conditional AE,不是Conditional VAE;下面的程式碼沒有使用池化,而是直接用Flatten把32×3×3的特徵圖展平後送入全連線層)

編寫Encoder如下:

import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Convolutional encoder that maps a single-channel image to a latent code.

    Three stride-2 convolutions feed a two-layer MLP. The MLP expects a
    flattened 32x3x3 feature map, which is what 1x28x28 inputs (e.g. MNIST)
    produce. When ``iscond`` is true, a condition vector is concatenated to
    the flattened features before the MLP.

    Args:
        encoded_space_dim: Size of the latent code returned by ``forward``.
        fc2_input_dim: Width of the hidden fully connected layer.
        iscond: Whether ``forward`` concatenates a condition vector.
        cond_dim: Length of the condition vector (used only if ``iscond``).
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True),
        )

        self.flatten = nn.Flatten(start_dim=1)

        # The condition vector, when used, is appended to the flattened
        # CNN features, so it widens only the first linear layer.
        lin_in_features = 3 * 3 * 32 + (cond_dim if iscond else 0)
        self.encoder_lin = nn.Sequential(
            nn.Linear(lin_in_features, fc2_input_dim),
            nn.ReLU(True),
            # Must match the hidden width above; a literal 128 here breaks
            # every non-default fc2_input_dim.
            nn.Linear(fc2_input_dim, encoded_space_dim),
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Encode ``x`` (and ``cond_vec`` when conditioning is enabled).

        Args:
            x: Image batch accepted by the convolutional stack.
            cond_vec: Condition batch of width ``cond_dim``; required when
                the encoder was built with ``iscond=True``, ignored otherwise.

        Returns:
            Tensor of shape ``(batch, encoded_space_dim)``.

        Raises:
            ValueError: If conditioning is enabled but ``cond_vec`` is None.
        """
        x = self.flatten(self.encoder_cnn(x))
        if self.iscond:
            if cond_vec is None:
                raise ValueError("cond_vec is required when iscond=True")
            x = torch.cat([x, cond_vec], dim=1)
        return self.encoder_lin(x)

Decoder用轉置卷積,如下:

class Decoder(nn.Module):
    """Transposed-convolution decoder that maps a latent code to an image.

    A two-layer MLP expands the latent code (optionally concatenated with a
    condition vector) to a 32x3x3 feature map, which three transposed
    convolutions upsample to a single-channel 28x28 image. The output is
    passed through a sigmoid, so values lie in [0, 1].

    Args:
        encoded_space_dim: Size of the latent code fed to ``forward``.
        fc2_input_dim: Width of the hidden fully connected layer.
        iscond: Whether ``forward`` concatenates a condition vector.
        cond_dim: Length of the condition vector (used only if ``iscond``).
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        # The condition vector, when used, is appended to the latent code,
        # so it widens only the first linear layer.
        lin_in_features = encoded_space_dim + (cond_dim if iscond else 0)
        self.decoder_lin = nn.Sequential(
            nn.Linear(lin_in_features, fc2_input_dim),
            nn.ReLU(True),
            # Must match the hidden width above; a literal 128 here breaks
            # every non-default fc2_input_dim.
            nn.Linear(fc2_input_dim, 3 * 3 * 32),
            nn.ReLU(True),
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))
        # Spatial size per stage: 3 -> 7 -> 14 -> 28.
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Decode latent codes ``x`` (and ``cond_vec`` when conditioning is on).

        Args:
            x: Latent batch of shape ``(batch, encoded_space_dim)``.
            cond_vec: Condition batch of width ``cond_dim``; required when
                the decoder was built with ``iscond=True``, ignored otherwise.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with values in [0, 1].

        Raises:
            ValueError: If conditioning is enabled but ``cond_vec`` is None.
        """
        if self.iscond:
            if cond_vec is None:
                raise ValueError("cond_vec is required when iscond=True")
            x = torch.cat([x, cond_vec], dim=1)
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        return torch.sigmoid(x)

訓練直接寫一個Trainer,初始化、資料集、模型初始化部分如下:

class TrainerMNIST(object):
    """Holds configuration and lazily-built state for the MNIST autoencoder.

    Construction only records settings; models, optimizer, loss and data
    loaders start out as ``None`` and are filled in by other methods.

    Args:
        istrain: When true, the timestamped run directory is created
            immediately.
    """

    def __init__(self, istrain=False):
        self.istrain = istrain

        # Hyper-parameters and model options, looked up by key elsewhere.
        self.configs = {
            "lr": 0.001,
            "weight_decay": 1e-5,
            "batch_size": 256,
            "d": 4,
            "fc_input_dim": 128,
            "seed": 3407,
            "epochs": 12,
            "iscond": True,
            "cond_dim": 10,
        }

        # Placeholders populated later.
        self.model_e = None
        self.model_d = None
        self.optim = None
        self.loss_fn = None
        self.train_loader = None
        self.valid_loader = None
        self.test_dataset = None
        self.test_loader = None

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Minute-resolution timestamp keeps each run's outputs separate.
        now = time.localtime()
        self.timestr = time.strftime("%Y%m%d%H%M", now)
        self.run_save_dir = "".join(
            ["../run/aeconv_mnist/", self.timestr, 'conditional_mix/']
        )
        if istrain:
            os.makedirs(self.run_save_dir, exist_ok=True)

        # Roughly two thirds of a batch, truncated to an integer.
        batch_size = self.configs["batch_size"]
        self.mix_rate = int(0.667 * batch_size)

    def __setup_dataset(self):
        """Build the MNIST datasets and data loaders.

        The test dataset and loader are always built. When ``self.istrain``
        is true, the training set is additionally split 80/20 into training
        and validation subsets with their own loaders. When
        ``configs["iscond"]`` is true, targets are passed through
        ``onehot(cond_dim)``.
        """
        data_dir = 'F:/DATAS/mnist'
        batch_size = self.configs["batch_size"]

        if self.istrain:
            train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
            train_dataset.transform = transforms.Compose([transforms.ToTensor(), ])
            if self.configs["iscond"]:
                # presumably returns a callable mapping a label to a one-hot
                # vector of length cond_dim; verify against the helper
                train_dataset.target_transform = onehot(self.configs["cond_dim"])
            m = len(train_dataset)
            # Derive the training size from the validation size so the two
            # always sum to m; truncating both independently can drop a
            # sample and make random_split raise.
            n_val = int(m * 0.2)
            train_data, val_data = random_split(train_dataset, [m - n_val, n_val])
            # Reshuffle training batches every epoch.
            self.train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
            self.valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)

        self.test_dataset = torchvision.datasets.MNIST(data_dir, train=False, download=True)
        self.test_dataset.transform = transforms.Compose([transforms.ToTensor(), ])
        if self.configs["iscond"]:
            self.test_dataset.target_transform = onehot(self.configs["cond_dim"])
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size, shuffle=True)

        print("Built Dataset and DataLoader")

    def setup_models(self):
        """Create encoder, decoder and loss; add the optimizer when training.

        Both networks are built from the same ``self.configs`` entries and
        moved to ``self.device``. A single Adam optimizer covering both
        networks is created only if ``self.istrain`` is true.
        """
        cfg = self.configs
        net_kwargs = {
            "encoded_space_dim": cfg["d"],
            "fc2_input_dim": cfg["fc_input_dim"],
            "iscond": cfg["iscond"],
            "cond_dim": cfg["cond_dim"],
        }
        self.model_e = Encoder(**net_kwargs)
        self.model_d = Decoder(**net_kwargs)
        for net in (self.model_e, self.model_d):
            net.to(self.device)

        self.loss_fn = torch.nn.MSELoss()

        if self.istrain:
            # One parameter group per network, sharing lr and weight decay.
            param_groups = [{"params": net.parameters()} for net in (self.model_e, self.model_d)]
            self.optim = torch.optim.Adam(
                param_groups,
                lr=cfg["lr"],
                weight_decay=cfg["weight_decay"],
            )
        print("Built Model and Optimizer and Loss Function")

然後編寫訓練方法和測試方法, 以及繪圖:

    def train_usl(self):
        """Train the autoencoder with a pure reconstruction (MSE) objective.

        Calls ``utils.setup_seed`` with the configured seed, builds the data
        loaders and models, then for each epoch: optimises over the training
        loader, computes the MSE over the whole validation set, saves both
        state dicts into ``self.run_save_dir`` and writes a reconstruction
        preview. After the last epoch the loss curves are saved as
        ``LossIter.png`` in the same directory.

        Raises:
            RuntimeError: If the trainer was created with ``istrain=False``
                (no optimizer, training loader or run directory exists).
        """
        if not self.istrain:
            raise RuntimeError("train_usl() requires TrainerMNIST(istrain=True)")

        # Use the configured seed rather than a duplicated literal.
        utils.setup_seed(self.configs["seed"])
        self.__setup_dataset()
        self.setup_models()
        diz_loss = {"train_loss": [], "val_loss": []}

        for epoch in range(self.configs["epochs"]):
            self.model_e.train()
            self.model_d.train()
            train_loss = []
            print(f"Train Epoch {epoch}")
            for i, (x, y) in enumerate(self.train_loader):
                x = x.to(self.device)
                decoded_data = self._reconstruct(x, y)
                loss_value = self.loss_fn(decoded_data, x)

                self.optim.zero_grad()
                loss_value.backward()
                self.optim.step()

                # Plain float: detached from the graph and safe to log/plot.
                batch_loss = loss_value.item()
                if i % 15 == 0:
                    print(f"\t partial train loss (single batch: {batch_loss:.6f})")
                train_loss.append(batch_loss)
            train_loss_value = np.mean(train_loss)

            self.model_e.eval()
            self.model_d.eval()
            with torch.no_grad():
                conc_out = []
                conc_label = []
                for x, y in self.valid_loader:
                    x = x.to(self.device)
                    decoded_data = self._reconstruct(x, y)
                    conc_out.append(decoded_data.cpu())
                    conc_label.append(x.cpu())
                # One MSE over all validation pixels, so a smaller final
                # batch is not over-weighted.
                val_loss_value = self.loss_fn(torch.cat(conc_out), torch.cat(conc_label)).item()
            print(f"\t Epoch {epoch} valid loss: {val_loss_value}")

            diz_loss["train_loss"].append(train_loss_value)
            diz_loss["val_loss"].append(val_loss_value)
            torch.save(self.model_e.state_dict(), self.run_save_dir + '{}_epoch_{}.pth'.format("aeconve_cond", epoch))
            torch.save(self.model_d.state_dict(), self.run_save_dir + '{}_epoch_{}.pth'.format("aeconvd_cond", epoch))
            self.plot_ae_outputs(epoch_id=epoch)

        fig = plt.figure(figsize=(10, 8))
        plt.semilogy(diz_loss["train_loss"], label="Train")
        plt.semilogy(diz_loss["val_loss"], label="Valid")
        plt.xlabel("Epoch")
        plt.ylabel("Average Loss")
        plt.legend()
        plt.savefig(self.run_save_dir+"LossIter.png", format="png", dpi=300)
        # pyplot keeps figures alive until they are closed explicitly.
        plt.close(fig)

    def _reconstruct(self, x, y):
        """Encode then decode ``x``, passing ``y`` as condition when enabled.

        ``x`` must already be on ``self.device``; ``y`` is moved there only
        when ``configs["iscond"]`` is true and is otherwise ignored.
        """
        if self.configs["iscond"]:
            y = y.to(self.device)
            return self.model_d(self.model_e(x, y), y)
        return self.model_d(self.model_e(x))

    def plot_ae_outputs(self, epoch_id, n_images=5):
        """Save a figure comparing test images with their reconstructions.

        The first ``n_images`` items of ``self.test_dataset`` are shown in
        the top row and their reconstructions in the bottom row. The figure
        is written to ``self.run_save_dir`` as
        ``test_plot_epoch_{epoch_id}.png``. Both models are left in eval mode.

        Args:
            epoch_id: Value embedded in the output file name.
            n_images: Number of test images to display (default 5).
        """
        # Switching to eval mode once is enough; it does not change per image.
        self.model_e.eval()
        self.model_d.eval()

        fig = plt.figure()
        for i in range(n_images):
            # Fetch the sample once so the dataset transforms run only once.
            sample, target = self.test_dataset[i]
            img = sample.unsqueeze(0).to(self.device)
            with torch.no_grad():
                if self.configs["iscond"]:
                    y = target.unsqueeze(0).to(self.device)
                    rec_img = self.model_d(self.model_e(img, y), y)
                else:
                    rec_img = self.model_d(self.model_e(img))

            ax = plt.subplot(2, n_images, i+1)
            plt.imshow(img.cpu().squeeze().numpy(), cmap="gist_gray")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            # Put each row title on the middle column only.
            if i == n_images//2:
                ax.set_title("Original images")

            ax = plt.subplot(2, n_images, i+1+n_images)
            plt.imshow(rec_img.cpu().squeeze().numpy(), cmap="gist_gray")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if i == n_images//2:
                ax.set_title("Reconstructed images")
        plt.savefig(self.run_save_dir+f"test_plot_epoch_{epoch_id}.png", format="png", dpi=300)
        # This runs once per epoch; without closing, figures pile up in memory.
        plt.close(fig)

main函式裡執行如下:

if __name__ == '__main__':
    # Training-mode trainer: the constructor creates the timestamped run directory.
    trainer = TrainerMNIST(istrain=True)
    # NOTE(review): the training call below is commented out, so running this
    # entry point as-is only creates the run directory; uncomment to train.
    # trainer.train_usl()

訓練完畢後,在執行儲存的目錄裡面就生成了模型檔案、重構影像、損失函式迭代曲線影像。

再編寫一個測試生成影像的函式:

    def recon_test_one(self, resume_path, epoch=None):
        """Generate one image per class label from random latent codes.

        Loads the encoder and decoder checkpoints from ``resume_path``,
        draws one standard-normal latent vector per class, decodes each
        together with a distinct one-hot label, saves the grid as
        ``generate_0.png`` inside ``resume_path`` and shows it.

        Args:
            resume_path: Directory holding ``aeconve_cond_epoch_{N}.pth`` and
                ``aeconvd_cond_epoch_{N}.pth``.
            epoch: Checkpoint index ``N`` to load. Defaults to the last
                configured epoch, ``configs["epochs"] - 1``.
        """
        if epoch is None:
            epoch = self.configs["epochs"] - 1
        if (self.model_e is None) or (self.model_d is None):
            self.setup_models()

        # map_location lets checkpoints saved on GPU load on a CPU-only host.
        state_dict_e = torch.load(os.path.join(resume_path, f'aeconve_cond_epoch_{epoch}.pth'),
                                  map_location=self.device)
        self.model_e.load_state_dict(state_dict_e)
        state_dict_d = torch.load(os.path.join(resume_path, f'aeconvd_cond_epoch_{epoch}.pth'),
                                  map_location=self.device)
        self.model_d.load_state_dict(state_dict_d)
        self.model_e.eval()
        self.model_d.eval()
        print(self.model_d)

        # Sizes come from the config so they always match the built models.
        n_classes = self.configs["cond_dim"]
        z = torch.randn(size=(n_classes, self.configs["d"]), device=self.device)
        # Row i is the one-hot vector for label i.
        y_label = torch.zeros(size=(n_classes, n_classes))
        labels = torch.arange(0, n_classes).unsqueeze(1)
        y_label.scatter_(1, labels, 1)
        y_label = y_label.to(self.device)

        with torch.no_grad():
            recon_images = self.model_d(z, y_label)
        # Drop only the channel axis so a single-class batch keeps its shape.
        recon_images = recon_images.squeeze(1).cpu().numpy()

        n_cols = 5
        n_rows = -(-n_classes // n_cols)  # ceiling division
        plt.figure(0)
        for i in range(n_classes):
            plt.subplot(n_rows, n_cols, i+1)
            plt.imshow(recon_images[i])
            plt.xticks([])
            plt.yticks([])
            plt.title(f"generate_{i}")
        # Join like the checkpoint paths so a missing trailing slash still
        # saves inside resume_path.
        plt.savefig(os.path.join(resume_path, "generate_0.png"), format="png", dpi=300)
        plt.show()


if __name__ == '__main__':
    # Inference only: istrain=False avoids creating a new, empty timestamped
    # run directory and building an optimizer that is never used.
    # Switch back to istrain=True before re-enabling trainer.train_usl().
    trainer = TrainerMNIST(istrain=False)
    # trainer.test_data()
    # trainer.train_usl()
    trainer.recon_test_one("../run/aeconv_mnist/202404141930conditional_mix/")

測試結果發現其實效果並不好,個人覺得這是Conditional AE自身的問題,學習到的表徵並不充分,沒有涵蓋生成新資料的連續空間,如果是Conditional VAE,就能做到。

相關文章