Thank you very much for providing the training code.
I tested both your pre-trained model and models I trained myself on the Levin dataset, but the results are lower than expected. The PSNR values are as follows:
- pre-trained model provided by you: 20.056 dB
- model trained with your training code: 20.1841 dB
- model trained with my own training code: 20.629 dB
However, your other paper, "Photon-Limited Blind Deconvolution using Unsupervised Iterative Kernel Estimation", reports a PSNR of 22.45 dB for the P4IP method under the same conditions and on the same dataset.
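For scale, PSNR = 10·log10(MAX²/MSE). When MAX is the same for both numbers, a PSNR gap maps directly to an MSE ratio. This assumes MAX = 1, i.e. images stored in [0, 1]. The short sketch below checks the gap between the best local result and the reported value. The helper name is hypothetical.

```python
import math

def mse_ratio_from_psnr_gap(psnr_low: float, psnr_high: float) -> float:
    """Return MSE_low / MSE_high implied by two PSNR values with the same MAX."""
    # PSNR = 10 * log10(MAX^2 / MSE)  =>  MSE_low / MSE_high = 10 ** (gap / 10)
    return 10 ** ((psnr_high - psnr_low) / 10)

# Best local result (20.629 dB) vs. the value reported in the other paper (22.45 dB)
ratio = mse_ratio_from_psnr_gap(20.629, 22.45)
print(f"local MSE is about {ratio:.2f}x the reported MSE")
```

The gap of about 1.8 dB therefore means the local reconstructions have roughly 50% more squared error. That is large enough to point at a systematic difference rather than noise-seed variation.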
Which part of my setup should I change to reproduce the reported results, or where does my problem arise? Please correct me if I am wrong. Best wishes.
Note: the Poisson peak is 10 and the test dataset is Levin, which you can get from Google Drive.
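For reference, the test script uses the observation model y ~ Poisson(peak · x_blurred), with x_blurred assumed to lie in [0, 1]. At peak = 10 the per-pixel SNR is only about sqrt(10 · x). The following is a minimal NumPy sketch of that simulation. Two choices here are assumptions and differ from the test code:

- It uses `np.random.default_rng` instead of the legacy `np.random.poisson`.
- It clips negative values, which the legacy call would reject.

The function name and the flat test image are hypothetical.

```python
import numpy as np

def simulate_photon_limited(blurred: np.ndarray, peak: float, seed: int = 0) -> np.ndarray:
    """Draw a photon-limited observation y ~ Poisson(peak * blurred), in photon counts.

    `blurred` is assumed to lie in [0, 1].
    """
    rng = np.random.default_rng(seed)
    # Clip small negative values (e.g. from FFT-based blurring); Poisson needs lam >= 0
    lam = peak * np.clip(blurred, 0.0, None)
    return rng.poisson(lam).astype(np.float32)

# Hypothetical flat image at intensity 0.5 with peak = 10
x = np.full((256, 256), 0.5)
y = simulate_photon_limited(x, peak=10.0)
print(y.mean(), y.var())  # both close to 10 * 0.5 = 5, since Poisson mean == variance
print(np.sqrt(10 * 0.5))  # per-pixel SNR for this intensity, about 2.24
```

Because the noise is this strong, the PSNR on a single noise draw can shift slightly with the random seed. It cannot plausibly account for a gap of almost 2 dB, though.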
My test code is as follows:
import torch
import numpy as np
from scipy.io import loadmat
from skimage.metrics import peak_signal_noise_ratio as cal_psnr
from models.network_p4ip import P4IP_Net

np.random.seed(20000320)
data_dir = "E:/Project/Image_Processing/Poisson_Deblurring/TestDataSet/Levin.mat"
data = loadmat(data_dir)["data"]
psnr = 0.0
peak = 10.0

model = P4IP_Net()
# model_file = "docs/p4ip_100epoch.pth"  # pre-trained
model_file = "model_zoo/p4ip_net_6epoch.pth"
checkpoint = torch.load(model_file)
model.load_state_dict(checkpoint)
model = model.cuda()
model.eval()  # inference mode (matters if the network contains BatchNorm/Dropout)

with torch.no_grad():  # gradients are not needed at test time
    for i in range(data.shape[0]):
        sharp, blurry, kernel = data[i]
        # Simulate photon-limited noise: counts ~ Poisson(peak * blurry)
        blurry = np.float32(np.random.poisson(blurry * peak))
        kernel = np.float32(kernel)
        # Add batch/channel dims: peak -> (1, 1, 1, 1), images -> (1, 1, H, W)
        peak_m = np.array([peak])
        peak_m = peak_m[np.newaxis, np.newaxis, np.newaxis]
        kernel = kernel[np.newaxis, np.newaxis]
        blurry = blurry[np.newaxis, np.newaxis]
        # Convert to tensors on the GPU
        blurry_t = torch.from_numpy(blurry.copy()).cuda()
        kernel_t = torch.from_numpy(kernel.copy()).cuda()
        # float32 to match the dtype of blurry_t and kernel_t
        peak_t = torch.tensor(peak_m, dtype=torch.float32).cuda()
        restores = model(blurry_t, kernel_t, peak_t)
        restored = restores[-1]  # last element of the returned list
        restored_m = restored.cpu().numpy().squeeze()
        psnr += cal_psnr(image_true=sharp, image_test=restored_m)

print(psnr / data.shape[0])
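The evaluation protocol itself may be worth ruling out. The sketch below is hypothetical: `evaluate`, `psnr` and `restore_fn` are illustrative names, not part of the repository. It computes PSNR with an explicit data range of 1.0, which assumes the Levin images are stored in [0, 1]. It also clips each restored image to [0, 1] and averages over several Poisson noise realizations. The setup described above does not say whether the 22.45 dB figure used clipping, boundary cropping, or multiple noise draws, so these are only candidates to compare.

```python
import numpy as np

def psnr(ref: np.ndarray, est: np.ndarray, data_range: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(data_range^2 / MSE)."""
    mse = np.mean((ref.astype(np.float64) - est.astype(np.float64)) ** 2)
    return float(10.0 * np.log10(data_range ** 2 / mse))

def evaluate(dataset, restore_fn, peak: float, n_trials: int = 3, seed: int = 0) -> float:
    """Mean PSNR over all images and `n_trials` Poisson noise draws per image.

    `dataset` yields (sharp, blurred, kernel) triples with images in [0, 1];
    `restore_fn(noisy, kernel, peak)` returns an estimate of the sharp image.
    """
    rng = np.random.default_rng(seed)
    scores = []
    for sharp, blurred, kernel in dataset:
        for _ in range(n_trials):
            # Photon-limited observation in counts
            noisy = rng.poisson(peak * np.clip(blurred, 0.0, None)).astype(np.float32)
            # Clip the estimate to the valid intensity range before scoring
            restored = np.clip(restore_fn(noisy, kernel, peak), 0.0, 1.0)
            scores.append(psnr(sharp, restored))
    return float(np.mean(scores))

# Toy check: a hypothetical "restorer" that only rescales counts by 1 / peak
toy = [(np.full((32, 32), 0.5), np.full((32, 32), 0.5), np.ones((3, 3)) / 9)]
print(evaluate(toy, lambda y, k, p: y / p, peak=10.0))
```

To compare protocols, wrap the P4IP forward pass inside `restore_fn` and run `evaluate` with and without the clipping step.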