import os

import numpy as np

from Utils.option import *
from data import test_dataloader
from Utils.utils import *
from Utils.metrics import *
from torchvision import utils as vutils
from models.network import Network as Net
import lpips
# Re-imported after the star imports above so that Utils.option's names take precedence.
from Utils.option import *
import torch


def _iterative_predict(model, input_img, num_passes=4):
    """Run ``model`` ``num_passes`` times, feeding each output back in as the prompt.

    Each pass takes element ``[2]`` of the model's return value as its output.

    * ``num_passes == 1``: a single call ``model(input_img, input_img)[2]``, i.e. the
      input image serves as its own prompt.
    * ``num_passes >= 2``: one unprompted call ``model(input_img)[2]`` produces the first
      prompt, followed by ``num_passes - 1`` calls ``model(input_img, prompt)[2]``, each
      using the previous output as the new prompt.

    Args:
        model: The backbone network; called with one or two positional arguments.
        input_img: The network input, passed unchanged to every call.
        num_passes: Total number of forward passes (must be >= 1).

    Returns:
        The ``[2]`` element of the final pass.

    Raises:
        ValueError: If ``num_passes`` is less than 1.
    """
    if num_passes < 1:
        raise ValueError(f'num_passes must be >= 1, got {num_passes}')
    if num_passes == 1:
        return model(input_img, input_img)[2]
    prompt_img = model(input_img)[2]
    for _ in range(num_passes - 1):
        prompt_img = model(input_img, prompt_img)[2]
    return prompt_img


def _eval(testLogger, num_passes=4):
    """Evaluate the checkpoint ``args.test_model`` on the test set and report metrics.

    For every test sample this computes PSNR, SSIM and LPIPS (VGG backbone) between the
    prediction and the label. Each per-image line is printed to stdout and passed to
    ``testLogger.write``; the averages over the test set are reported at the end.
    ``args`` is presumably the options object from ``Utils.option``. ``psnr``/``ssim`` are
    presumably provided by ``Utils.metrics`` (both via star imports) — confirm.

    Args:
        testLogger: Any object with a ``write(str)`` method; receives the log lines.
        num_passes: Number of backbone forward passes per image (see
            ``_iterative_predict``). The default of 4 matches the previously hard-coded
            inference path.
    """
    test_loader = test_dataloader(args.testset_path, batch_size=1)
    device = args.device
    loss_fn = lpips.LPIPS(net='vgg').to(device)

    backbone_model = Net(args.mode).to(device)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) load onto `device`.
    state_dict = torch.load(args.test_model, map_location=device)
    backbone_model.load_state_dict(state_dict['backbone_model'])
    backbone_model.eval()
    torch.cuda.empty_cache()

    if args.save_image:
        # save_image does not create missing directories.
        os.makedirs(args.output_dir, exist_ok=True)

    ssims_1, psnrs_1, lpips_1 = [], [], []
    for iter_idx, (input_img, label_img, name) in enumerate(test_loader):
        input_img = input_img.to(device)
        label_img = label_img.to(device)
        # Metrics are computed under no_grad too: LPIPS has trainable weights and would
        # otherwise build an autograd graph for every image.
        with torch.no_grad():
            pred = _iterative_predict(backbone_model, input_img, num_passes)
            per_psnr_1 = psnr(pred, label_img)
            per_ssim_1 = ssim(pred, label_img).item()
            per_lpips_value = loss_fn(pred, label_img).item()

        ssims_1.append(per_ssim_1)
        psnrs_1.append(per_psnr_1)
        lpips_1.append(per_lpips_value)

        line = (f'\n {name[0]} iter processing:{iter_idx + 1} psnr:{per_psnr_1:.4f} '
                f'ssim:{per_ssim_1:.4f} lpips:{per_lpips_value:.4f}')
        print(line, end='', flush=True)
        testLogger.write(line)
        if args.save_image:
            # batch_size=1, so name[0] is this sample's file name.
            # NOTE(review): normalize=True min-max rescales each saved image — confirm
            # that is intended rather than a plain clamp to [0, 1].
            vutils.save_image(pred, os.path.join(args.output_dir, f'{name[0]}'), normalize=True)

    if not psnrs_1:
        # np.mean([]) would yield nan plus a RuntimeWarning; report the empty run instead.
        msg = '\nno test images were evaluated'
        print(msg, flush=True)
        testLogger.write(msg)
        return

    avg_ssim_1 = np.mean(ssims_1)
    avg_psnr_1 = np.mean(psnrs_1)
    avg_lpips_1 = np.mean(lpips_1)
    print(f'\n------------------------------------------------------')
    testLogger.write(f'\n------------------------------------------------------')
    print(f'\navg_psnr:{avg_psnr_1:.4f} avg_ssim:{avg_ssim_1:.4f} avg_lpips:{avg_lpips_1:.4f}', end='', flush=True)
    testLogger.write(f'\navg_psnr:{avg_psnr_1:.4f} avg_ssim is:{avg_ssim_1:.4f} avg_lpips:{avg_lpips_1:.4f}')