import time

import torch

# AdamW (decoupled weight decay) and cosine annealing with warm restarts,
# from the adamw / cosine_scheduler modules used in this script.
import adamw
import cosine_scheduler

# model, opt and src_set are assumed to be defined earlier in the script.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

optimizer = adamw.AdamW(model.parameters(), lr=opt.lr, weight_decay=0)
scheduler = cosine_scheduler.CosineLRWithRestarts(
    optimizer, batch_size=opt.batch_size, epoch_size=len(src_set),
    restart_period=5, t_mult=1.2)
def train(train_gen, model, criterion, optimizer, epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(train_gen, 1):
        nr = batch[0].to(device)
        hr = batch[1].to(device)
        optimizer.zero_grad()
        loss = criterion(model(nr), hr)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.batch_step()  # per-batch LR update required by CosineLRWithRestarts
        if iteration % 1000 == 0:
            # loss is a 0-dim tensor; .item() extracts the float for formatting
            print('===> Epoch[{e}]({it}/{dl}): Loss: {l:.4f}'.format(
                e=epoch, it=iteration, dl=len(train_gen), l=loss.item()))
    current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    epoch_loss_average = epoch_loss / len(train_gen)
    print('===> {ct} Epoch {e} Complete: Avg Loss: {avg_loss:.4f}, Sum Loss: {sum_loss:.4f}'
          .format(e=epoch, avg_loss=epoch_loss_average, sum_loss=epoch_loss, ct=current_time))
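
# A minimal sketch of the epoch-level driver, assuming the surrounding script
# supplies a DataLoader; `training_gen` and `num_epochs` are illustrative
# names, not from the original. With this cosine_scheduler, scheduler.step()
# is called once at the start of every epoch to advance the restart schedule,
# on top of the per-batch scheduler.batch_step() calls inside train().
for epoch in range(1, num_epochs + 1):
    scheduler.step()
    train(training_gen, model, criterion, optimizer, epoch)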