# Per-image quality metrics comparing the prediction `pred` against the
# reference `label_img`. `psnr`, `ssim` and `loss_fn` are defined elsewhere.
# Judging by the variable names, they compute PSNR, SSIM and an LPIPS-style
# perceptual distance. TODO confirm against their definitions.
#
# Optional (currently disabled): rescale `pred` so its mean matches the mean
# of `label_img`, then clamp to [0, 1]. The clamp presumably reflects the
# image value range; confirm. To use it, uncomment these lines and pass
# `pred_adjust` instead of `pred` to the metrics below.
# pred_mean = pred.float().mean()
# label_img_mean = label_img.float().mean()
# pred_adjust = torch.clamp((pred * (label_img_mean / pred_mean)), 0, 1)

# NOTE(review): unlike the SSIM/LPIPS results, the PSNR result is not passed
# through `.item()`. It may stay a tensor, depending on what `psnr` returns.
# Confirm this is intended before aggregating or logging it.
per_psnr_1 = psnr(pred, label_img)
# `.item()` turns a single-element result (e.g. a 0-dim tensor) into a
# Python scalar.
per_ssim_1 = ssim(pred, label_img).item()
per_lpips_value = loss_fn(pred, label_img).item()