import os
from math import log10

import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import pytorch_ssim
# Project-local modules and path variables (train_path, val_path, imgs_save,
# G_weights_save, D_weights_save, out_stat_path) are assumed to be defined in
# earlier cells, e.g.:
# from model import Generator, Discriminator
# from loss import GeneratorLoss
# from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder

CROP_SIZE = 96
UPSCALE_FACTOR = 4
NUM_EPOCHS = 60
train_set = TrainDatasetFromFolder(train_path, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder(val_path, upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
# In case you want to resume training from previously saved weights:
# netG.load_state_dict(torch.load(G_weights_load))
# netD.load_state_dict(torch.load(D_weights_load))
generator_criterion = GeneratorLoss()
if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())
results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': [], 'mse' : []}
# torch.autograd.set_detect_anomaly(True)
for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
    netG.train()
    netD.train()
    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        real_img = target
        if torch.cuda.is_available():
            real_img = real_img.cuda()
        z = data
        if torch.cuda.is_available():
            z = z.cuda()
        fake_img = netG(z)

        netD.zero_grad()
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        # If fake_out -> 1 and real_out -> 0, the discriminator is maximally
        # wrong and the loss is at its largest; the 1e-6 terms guard log(0).
        d_loss = -(torch.log(real_out + 1e-6) + torch.log(1 - fake_out + 1e-6))
        d_loss.backward(retain_graph=True)
        ############################
        # (2) Update G network: minimize -log(D(G(z))) + perception loss + image loss
        ############################
        netG.zero_grad()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        # Regenerate after the backward pass so the logged D(G(z)) reflects
        # the current generator output.
        fake_img = netG(z)
        fake_out = netD(fake_img).mean()
        optimizerD.step()
        optimizerG.step()

        # Accumulate per-batch losses/scores (computed before the optimizer steps)
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out.item() * batch_size
        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))
    netG.eval()
    out_path = imgs_save
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    with torch.no_grad():
        val_bar = tqdm(val_loader)
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        # val_images = []
        for val_lr, val_hr_restore, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            lr = val_lr
            hr = val_hr
            if torch.cuda.is_available():
                lr = lr.cuda()
                hr = hr.cuda()
            sr = netG(lr)

            batch_mse = ((sr - hr) ** 2).mean().item()
            valing_results['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).item()
            valing_results['ssims'] += batch_ssim * batch_size
            # PSNR = 10 * log10(MAX^2 / MSE), with MAX taken from the HR batch
            valing_results['psnr'] = 10 * log10((hr.max().item() ** 2) /
                                                (valing_results['mse'] / valing_results['batch_sizes']))
            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
            val_bar.set_description(
                desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                    valing_results['psnr'], valing_results['ssim']))
            # val_images.extend(
            #     [[lr.squeeze(0), hr.data.cpu().squeeze(0),
            #       sr.data.cpu().squeeze(0)]])
        # val_save_bar = tqdm(val_images, desc='[saving training results]')
        # index = 1
        # for image in val_save_bar:
        #     utils.save_image(image[0], out_path + "lr_" + str(index) + '.png')
        #     utils.save_image(image[1], out_path + "hr_" + str(index) + '.png')
        #     utils.save_image(image[2], out_path + "sr_" + str(index) + '.png')
        #     index += 1
    # Save per-epoch losses / scores / PSNR / SSIM
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(valing_results['psnr'])
    results['ssim'].append(valing_results['ssim'])
    results['mse'].append(valing_results['mse'] / valing_results['batch_sizes'])

    if epoch % 10 == 0:
        # Save model parameters
        torch.save(netG.state_dict(), G_weights_save + 'netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), D_weights_save + 'netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

    data_frame = pd.DataFrame(
        data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
              'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim'], 'MSE': results['mse']},
        index=range(1, epoch + 1))
    data_frame.to_csv(out_stat_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
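For context, `GeneratorLoss` above is only instantiated, never shown. In the usual SRGAN recipe it combines an adversarial term with a VGG-feature "perception" term and a pixel-space "image" term (often plus total variation), which is what the comment in step (2) refers to. Below is a minimal sketch of such a criterion, assuming that recipe and the `(fake_out, fake_img, real_img)` call signature used above; the actual class in this project may pick different layers and weights.

```python
import torch
import torch.nn as nn
from torchvision.models import vgg16

class GeneratorLossSketch(nn.Module):
    """Hypothetical SRGAN-style generator loss: image + adversarial +
    perception (VGG features) + total-variation terms."""
    def __init__(self):
        super().__init__()
        vgg = vgg16(pretrained=True)
        # Frozen VGG16 feature extractor for the perception term
        self.features = nn.Sequential(*list(vgg.features)[:31]).eval()
        for p in self.features.parameters():
            p.requires_grad = False
        self.mse = nn.MSELoss()

    def forward(self, fake_out, fake_img, real_img):
        adversarial_loss = torch.mean(1 - fake_out)   # push D(G(z)) toward 1
        perception_loss = self.mse(self.features(fake_img), self.features(real_img))
        image_loss = self.mse(fake_img, real_img)     # pixel-space MSE
        # Total variation encourages spatial smoothness
        tv_loss = (fake_img[:, :, 1:, :] - fake_img[:, :, :-1, :]).pow(2).mean() + \
                  (fake_img[:, :, :, 1:] - fake_img[:, :, :, :-1]).pow(2).mean()
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
```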
**ERROR:**

TypeError                                 Traceback (most recent call last)
<ipython-input-10-26cdce3c7fb0> in <module>()
    111                 batch_mse = ((sr - hr) ** 2).data.mean()
    112                 valing_results['mse'] += batch_mse * batch_size
--> 113                 batch_ssim = pytorch_ssim.ssim(sr,hr).item()
    114                 valing_results['ssims'] += batch_ssim * batch_size
    115                 valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))

/usr/local/lib/python3.7/dist-packages/pytorch_ssim/__init__.py in _ssim(img1, img2, window, window_size, channel, size_average)
     16
     17 def _ssim(img1, img2, window, window_size, channel, size_average = True):
---> 18     mu1 = F.conv2d(img1, window, padding = window_size/2, groups = channel)
     19     mu2 = F.conv2d(img2, window, padding = window_size/2, groups = channel)
     20
TypeError: conv2d() received an invalid combination of arguments - got (Tensor, Tensor, groups=int, padding=float), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
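The traceback points into `pytorch_ssim` itself rather than the training code: under Python 3, `window_size/2` is true division and produces a float, and current PyTorch only accepts integer (tuple-of-int or string) padding in `F.conv2d`. The one-line fix is to change every `window_size/2` in `/usr/local/lib/python3.7/dist-packages/pytorch_ssim/__init__.py` to `window_size // 2`. If editing the installed package is inconvenient, the sketch below is a self-contained replacement with the integer padding applied (assuming 4-D `(N, C, H, W)` inputs in roughly the `[0, 1]` range, as `pytorch_ssim.ssim` expects):

```python
import torch
import torch.nn.functional as F
from math import exp

def gaussian_window(window_size, sigma, channel):
    # 1-D Gaussian, normalized, then outer product -> 2-D window per channel
    gauss = torch.tensor(
        [exp(-(x - window_size // 2) ** 2 / (2.0 * sigma ** 2))
         for x in range(window_size)])
    gauss = (gauss / gauss.sum()).unsqueeze(0)         # (1, ws)
    window_2d = gauss.t() @ gauss                      # (ws, ws)
    return window_2d.expand(channel, 1, window_size, window_size).contiguous()

def ssim(img1, img2, window_size=11, sigma=1.5):
    channel = img1.size(1)
    window = gaussian_window(window_size, sigma, channel).to(img1)
    pad = window_size // 2                             # int, not float
    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)
    mu1_sq, mu2_sq, mu1_mu2 = mu1 ** 2, mu2 ** 2, mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2
    C1, C2 = 0.01 ** 2, 0.03 ** 2                      # constants for range [0, 1]
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
```

With this helper in scope, the validation line becomes `batch_ssim = ssim(sr, hr).item()` and the `TypeError` goes away.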