一個完整的深度學習流程,必須包含的部分有:引數配置、Dataset和DataLoader構建、模型與optimizer與Loss函式建立、訓練、驗證、儲存模型,以及讀取模型、測試集驗證等,對於生成模型來說,還應該有重構測試、生成測試。
AutoEncoder僅能夠重構見過的資料,VAE可以透過取樣生成新資料;對於MNIST資料集來說,兩者都可以透過全連線神經網路訓練。但如果我們需要用CNN來實現,也很容易。
Conditional VAE則有些特殊,它要把資料標籤轉換成One-Hot格式再拼接到資料上,MNIST資料集尚可,資料拉開也就784維度,那麼對於一般的影像資料來說就不可行了。
解決這個問題只需要在CNN後面接MLP(多層感知機,也就是全連線網路)就行了。對於CNN應有準確的認識:實際上,CNN從通道上看,和全連線神經網路一樣,都是全連線的。如果把某一層的卷積核個數(通道數)視做全連線網路某一層的節點數,那麼兩者結構上是一樣的;不同的是卷積核還要在影像上做複雜運算,每一通道的資料是2維的,而全連線網路每一個“通道”是1個標量。因此通常CNN的資料是四個維度(bs, c, h, w),而全連線網路的資料一般是兩個維度(bs, dim)。
CNN的特點是,隨著層數加深,影像尺寸越來越小,通道數越來越高,此時就和全連線神經網路非常像了,對每一個通道做一個全域性平均池化(GAP)或者最大池化(MaxPooling),那不就變成全連線網路的輸入了?
實現如下(Conditional AE,不是Conditional VAE):
編寫Encoder如下:
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """CNN encoder for 1x28x28 (MNIST-sized) images.

    Maps an image batch to latent vectors of size ``encoded_space_dim``.
    When ``iscond`` is True, a condition vector (e.g. a one-hot label of
    length ``cond_dim``) is concatenated to the flattened CNN features
    before the MLP head.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        # Spatial path: 28x28 -> 14x14 -> 7x7 -> 3x3, channels 1 -> 8 -> 16 -> 32.
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )
        self.flatten = nn.Flatten(start_dim=1)
        # Bug fix: the output layer previously used a hard-coded 128 as its
        # input size, which broke any call with fc2_input_dim != 128.
        in_features = 3 * 3 * 32 + (cond_dim if iscond else 0)
        self.encoder_lin = nn.Sequential(
            nn.Linear(in_features, fc2_input_dim),
            nn.ReLU(True),
            nn.Linear(fc2_input_dim, encoded_space_dim)
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Encode x of shape (bs, 1, 28, 28); cond_vec (bs, cond_dim) is required iff iscond."""
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        if self.iscond:
            x = torch.cat([x, cond_vec], dim=1)
        return self.encoder_lin(x)
Decoder用轉置卷積,如下:
class Decoder(nn.Module):
    """Transposed-convolution decoder: latent vector -> 1x28x28 image in (0, 1).

    Mirror of ``Encoder``; when ``iscond`` is True the condition vector is
    concatenated to the latent code before the MLP.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim=128, iscond=False, cond_dim=10):
        super().__init__()
        # Bug fix: the second linear layer previously used a hard-coded 128
        # as its input size, which broke any call with fc2_input_dim != 128.
        in_features = encoded_space_dim + (cond_dim if iscond else 0)
        self.decoder_lin = nn.Sequential(
            nn.Linear(in_features, fc2_input_dim),
            nn.ReLU(True),
            nn.Linear(fc2_input_dim, 3 * 3 * 32),
            nn.ReLU(True)
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))
        # Spatial path: 3x3 -> 7x7 -> 14x14 -> 28x28.
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)
        )
        self.iscond = iscond

    def forward(self, x, cond_vec=None):
        """Decode latent x (bs, encoded_space_dim); cond_vec required iff iscond."""
        if self.iscond:
            x = torch.cat([x, cond_vec], dim=1)
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        # sigmoid keeps pixels in (0, 1) to match ToTensor-scaled inputs
        return torch.sigmoid(x)
訓練直接寫一個Trainer,初始化、資料集、模型初始化部分如下:
class TrainerMNIST(object):
    """Trainer for the (conditional) convolutional AutoEncoder on MNIST.

    Owns the hyper-parameter configuration, dataset/loaders, encoder/decoder
    models, optimizer and loss function.
    """

    def __init__(self, istrain=False):
        # istrain: True enables the training split/loaders and creates the run dir.
        self.istrain = istrain
        self.configs = {
            "lr": 0.001,
            "weight_decay": 1e-5,
            "batch_size": 256,
            "d": 4,                  # latent dimension
            "fc_input_dim": 128,     # hidden width of the MLP heads
            "seed": 3407,
            "epochs": 12,
            "iscond": True,          # conditional AE: concatenate one-hot labels
            "cond_dim": 10,          # number of classes / one-hot length
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.timestr = time.strftime("%Y%m%d%H%M", time.localtime())
        # one run directory per launch, keyed by timestamp
        self.run_save_dir = "../run/aeconv_mnist/" + self.timestr + 'conditional_mix/'
        if istrain:
            os.makedirs(self.run_save_dir, exist_ok=True)
        self.model_e, self.model_d = None, None
        self.optim, self.loss_fn = None, None
        self.train_loader, self.valid_loader = None, None
        self.test_dataset, self.test_loader = None, None
        self.mix_rate = int(0.667 * self.configs["batch_size"])

    def __setup_dataset(self):
        """Build MNIST datasets and DataLoaders (train/valid only when istrain)."""
        data_dir = 'F:/DATAS/mnist'
        if self.istrain:
            train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
            train_transform = transforms.Compose([transforms.ToTensor(), ])
            train_dataset.transform = train_transform
            if self.configs["iscond"]:
                # one-hot encode labels so they can be concatenated to features
                train_dataset.target_transform = onehot(self.configs["cond_dim"])
            m = len(train_dataset)
            # 80/20 train/validation split
            train_data, val_data = random_split(train_dataset, [int(m - m * 0.2), int(m * 0.2)])
            # Bug fix: the training loader previously used the default
            # shuffle=False; training batches should be re-shuffled each epoch.
            self.train_loader = torch.utils.data.DataLoader(train_data,
                                                            batch_size=self.configs["batch_size"],
                                                            shuffle=True)
            self.valid_loader = torch.utils.data.DataLoader(val_data,
                                                            batch_size=self.configs["batch_size"])
        self.test_dataset = torchvision.datasets.MNIST(data_dir, train=False, download=True)
        test_transform = transforms.Compose([transforms.ToTensor(), ])
        self.test_dataset.transform = test_transform
        if self.configs["iscond"]:
            self.test_dataset.target_transform = onehot(self.configs["cond_dim"])
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset,
                                                       batch_size=self.configs["batch_size"],
                                                       shuffle=True)
        print("Built Dataset and DataLoader")

    def setup_models(self):
        """Instantiate encoder/decoder (+ optimizer and MSE loss when training)."""
        self.model_e = Encoder(encoded_space_dim=self.configs["d"],
                               fc2_input_dim=self.configs["fc_input_dim"],
                               iscond=self.configs["iscond"],
                               cond_dim=self.configs["cond_dim"])
        self.model_d = Decoder(encoded_space_dim=self.configs["d"],
                               fc2_input_dim=self.configs["fc_input_dim"],
                               iscond=self.configs["iscond"],
                               cond_dim=self.configs["cond_dim"])
        self.model_e.to(self.device)
        self.model_d.to(self.device)
        self.loss_fn = torch.nn.MSELoss()
        if self.istrain:
            # one optimizer over both models' parameters
            paras_to_optimize = [
                {"params": self.model_e.parameters()},
                {"params": self.model_d.parameters()}
            ]
            self.optim = torch.optim.Adam(paras_to_optimize, lr=self.configs["lr"],
                                          weight_decay=self.configs["weight_decay"])
        print("Built Model and Optimizer and Loss Function")
然後編寫訓練方法和測試方法, 以及繪圖:
def train_usl(self):
    """Unsupervised training loop (plain AutoEncoder reconstruction).

    Trains encoder/decoder for the configured number of epochs, validates
    after every epoch, checkpoints both models and plots reconstructions
    per epoch, then saves the train/valid loss curves.
    """
    # Fix: honour the configured seed instead of a hard-coded 3407.
    utils.setup_seed(self.configs["seed"])
    self.__setup_dataset()
    self.setup_models()
    diz_loss = {"train_loss": [], "val_loss": []}
    for epoch in range(self.configs["epochs"]):
        self.model_e.train()
        self.model_d.train()
        train_loss = []
        print(f"Train Epoch {epoch}")
        for i, (x, y) in enumerate(self.train_loader):
            x = x.to(self.device)
            if self.configs["iscond"]:
                # y is the one-hot label used as the condition vector
                y = y.to(self.device)
                encoded_data = self.model_e(x, y)
                decoded_data = self.model_d(encoded_data, y)
            else:
                encoded_data = self.model_e(x)
                decoded_data = self.model_d(encoded_data)
            loss_value = self.loss_fn(x, decoded_data)
            self.optim.zero_grad()
            loss_value.backward()
            self.optim.step()
            if i % 15 == 0:
                print(f"\t partial train loss (single batch: {loss_value.data:.6f})")
            train_loss.append(loss_value.detach().cpu().numpy())
        train_loss_value = np.mean(train_loss)
        # ---- validation: reconstruct the whole split, then one MSE ----
        self.model_e.eval()
        self.model_d.eval()
        with torch.no_grad():
            conc_out = []
            conc_label = []
            for x, y in self.valid_loader:
                x = x.to(self.device)
                if self.configs["iscond"]:
                    y = y.to(self.device)
                    encoded_data = self.model_e(x, y)
                    decoded_data = self.model_d(encoded_data, y)
                else:
                    encoded_data = self.model_e(x)
                    decoded_data = self.model_d(encoded_data)
                conc_out.append(decoded_data.cpu())
                conc_label.append(x.cpu())
            conc_out = torch.cat(conc_out)
            conc_label = torch.cat(conc_label)
            val_loss = self.loss_fn(conc_out, conc_label)
        val_loss_value = val_loss.data
        print(f"\t Epoch {epoch} test loss: {val_loss.item()}")
        diz_loss["train_loss"].append(train_loss_value)
        diz_loss["val_loss"].append(val_loss_value)
        # checkpoint both halves every epoch
        torch.save(self.model_e.state_dict(),
                   self.run_save_dir + '{}_epoch_{}.pth'.format("aeconve_cond", epoch))
        torch.save(self.model_d.state_dict(),
                   self.run_save_dir + '{}_epoch_{}.pth'.format("aeconvd_cond", epoch))
        self.plot_ae_outputs(epoch_id=epoch)
    plt.figure(figsize=(10, 8))
    plt.semilogy(diz_loss["train_loss"], label="Train")
    plt.semilogy(diz_loss["val_loss"], label="Valid")
    plt.xlabel("Epoch")
    plt.ylabel("Average Loss")
    plt.legend()
    plt.savefig(self.run_save_dir + "LossIter.png", format="png", dpi=300)
    plt.close()  # fix: release the figure instead of leaking it

def plot_ae_outputs(self, epoch_id):
    """Plot 5 test images (top row) and their reconstructions (bottom row)."""
    # eval mode once, outside the loop (was re-set every iteration)
    self.model_e.eval()
    self.model_d.eval()
    plt.figure()
    for i in range(5):
        ax = plt.subplot(2, 5, i + 1)
        img = self.test_dataset[i][0].unsqueeze(0).to(self.device)
        with torch.no_grad():
            if self.configs["iscond"]:
                # NOTE(review): assumes target_transform yields a tensor label — as set up in __setup_dataset
                y = self.test_dataset[i][1].unsqueeze(0).to(self.device)
                rec_img = self.model_d(self.model_e(img, y), y)
            else:
                rec_img = self.model_d(self.model_e(img))
        plt.imshow(img.cpu().squeeze().numpy(), cmap="gist_gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == 5 // 2:
            ax.set_title("Original images")
        ax = plt.subplot(2, 5, i + 1 + 5)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap="gist_gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == 5 // 2:
            ax.set_title("Reconstructed images")
    plt.savefig(self.run_save_dir + f"test_plot_epoch_{epoch_id}.png", format="png", dpi=300)
    plt.close()  # fix: avoid leaking one open figure per epoch
main函式里執行如下:
if __name__ == '__main__':
    trainer = TrainerMNIST(istrain=True)
    # Fix: the call was commented out, so running the script only created
    # the run directory and never actually trained.
    trainer.train_usl()
訓練完畢後,在執行儲存的目錄裡面就生成了模型檔案、重構影像、損失函式迭代曲線影像。
再編寫一個測試生成影像的函式:
def recon_test_one(self, resume_path, epoch=11):
    """Load a saved conditional encoder/decoder pair and generate one image per digit.

    Args:
        resume_path: directory containing the saved ``*_epoch_{epoch}.pth`` files.
        epoch: checkpoint epoch to load (generalised from the previously
            hard-coded 11; default keeps the old behaviour).
    """
    if (self.model_e is None) or (self.model_d is None):
        self.setup_models()
    # Load weights first, then switch to eval mode.
    state_dict_e = torch.load(os.path.join(resume_path, f'aeconve_cond_epoch_{epoch}.pth'))
    self.model_e.load_state_dict(state_dict_e)
    state_dict_d = torch.load(os.path.join(resume_path, f'aeconvd_cond_epoch_{epoch}.pth'))
    self.model_d.load_state_dict(state_dict_d)
    self.model_e.eval()
    self.model_d.eval()
    print(self.model_d)
    # One random latent per class; sizes come from configs instead of magic 4/10.
    z = torch.randn(size=(10, self.configs["d"]), device=self.device)
    y_label = torch.zeros(size=(10, self.configs["cond_dim"]))
    labels = torch.arange(0, 10).unsqueeze(1)
    y_label.scatter_(1, labels, 1)  # rows become one-hot vectors 0..9
    y_label = y_label.to(self.device)
    with torch.no_grad():  # inference only: no autograd graph needed
        recon_images = self.model_d(z, y_label)
    recon_images = recon_images.squeeze().detach().cpu().numpy()
    plt.figure(0)
    for i in range(10):
        plt.subplot(2, 5, i + 1)
        plt.imshow(recon_images[i])
        plt.xticks([])
        plt.yticks([])
        plt.title(f"generate_{i}")
    plt.savefig(resume_path + "generate_0.png", format="png", dpi=300)
    plt.show()


if __name__ == '__main__':
    trainer = TrainerMNIST(istrain=True)
    # trainer.test_data()
    # trainer.train_usl()
    trainer.recon_test_one("../run/aeconv_mnist/202404141930conditional_mix/")
測試結果發現其實效果並不好,個人覺得這是Conditional AE自身的問題,學習到的表徵並不充分,沒有涵蓋生成新資料的連續空間,如果是Conditional VAE,就能做到。