...
...
for i in tqdm(range(1000)):
    # NOTE(review): the pasted snippet lost its indentation. Everything below
    # is assumed to sit inside this `if`, i.e. the Jacobians are rebuilt only
    # on every 50th iteration -- confirm against the original script.
    if i % 50 == 0:
        # One Jacobian per group of model outputs; row j holds the flattened
        # gradient of output j with respect to every model parameter.
        J1 = torch.zeros((D1, n_params))
        J2 = torch.zeros((D2, n_params))
        J3 = torch.zeros((D3, n_params))

        # Draw `kernel_size` distinct residual points for this rebuild.
        batch_ind = np.random.choice(len(x_res), kernel_size, replace=False)
        x_train = x_res[batch_ind]
        t_train = t_res[batch_ind]

        pred_res = model(x_train, t_train)
        pred_left = model(x_left, t_left)
        pred_upper = model(x_upper, t_upper)
        pred_lower = model(x_lower, t_lower)

        # (target Jacobian, outputs back-propagated for each row, row count).
        # For J3 the lower and upper outputs are back-propagated one after the
        # other without zeroing in between, so their gradients are summed in
        # .grad before the row is stored.
        jacobian_specs = (
            (J1, (pred_res,), len(x_train)),
            (J2, (pred_left,), len(x_left)),
            (J3, (pred_lower, pred_upper), len(x_lower)),
        )
        for jac, outputs, n_rows in jacobian_specs:
            for j in range(n_rows):
                model.zero_grad()
                for out in outputs:
                    # retain_graph=True: the same forward graph is reused
                    # for every row.
                    out[j].backward(retain_graph=True)
                flat_grads = [p.grad.view(-1) for p in model.parameters()]
                jac[j, :] = torch.cat(flat_grads)
...
...
Here is the code I have roughly modified; I am not sure whether it is correct.
def _flat_param_grads(params):
    """Concatenate the current ``.grad`` of every parameter into one 1-D tensor.

    A parameter whose ``.grad`` is ``None`` (it did not take part in the last
    backward pass) contributes a zero block of its own size, device and dtype,
    so the result always has one entry per parameter element.
    """
    return torch.cat([
        (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
        for p in params
    ])


# One Jacobian per group of quantities; row j holds the flattened gradient of
# quantity j with respect to every model parameter.
J1 = torch.zeros((D1, n_params))
J2 = torch.zeros((D2, n_params))
J3 = torch.zeros((D3, n_params))

# Draw `kernel_size` distinct residual points.
batch_ind = np.random.choice(len(x_res), kernel_size, replace=False)
x_train, t_train = x_res[batch_ind], t_res[batch_ind]

pred_res = model(x_train, t_train)
pred_left = model(x_left, t_left)
pred_upper = model(x_upper, t_upper)
pred_lower = model(x_lower, t_lower)

# Second derivatives of the residual prediction with respect to x and t.
# create_graph=True keeps the derivatives differentiable with respect to the
# parameters. grad_outputs is shaped after the tensor being differentiated.
u_x = torch.autograd.grad(
    pred_res, x_train, grad_outputs=torch.ones_like(pred_res),
    retain_graph=True, create_graph=True)[0]
u_xx = torch.autograd.grad(
    u_x, x_train, grad_outputs=torch.ones_like(u_x),
    retain_graph=True, create_graph=True)[0]
u_t = torch.autograd.grad(
    pred_res, t_train, grad_outputs=torch.ones_like(pred_res),
    retain_graph=True, create_graph=True)[0]
u_tt = torch.autograd.grad(
    u_t, t_train, grad_outputs=torch.ones_like(u_t),
    retain_graph=True, create_graph=True)[0]
wave_opt = u_tt - 4 * u_xx  # wave operator
del u_x, u_xx, u_t, u_tt

# Time derivative of the prediction at the `left` points.
pred_t = torch.autograd.grad(
    pred_left, t_left, grad_outputs=torch.ones_like(pred_left),
    retain_graph=True, create_graph=True)[0]

# The parameter set does not change while the rows are filled, so enumerate
# it once instead of once per row.
params = list(model.parameters())

# J1: gradient of the wave operator at each sampled residual point.
for j in range(len(x_train)):
    model.zero_grad()
    wave_opt[j].backward(retain_graph=True)
    J1[j, :] = _flat_param_grads(params)

# J2: gradient of the time derivative at each `left` point.
for j in range(len(x_left)):
    model.zero_grad()
    pred_t[j].backward(retain_graph=True)
    J2[j, :] = _flat_param_grads(params)

# J3: the three backward calls accumulate into .grad, so each row is the sum
# of the gradients of pred_left[j], pred_lower[j] and pred_upper[j].
# NOTE(review): this needs pred_left and pred_upper to have at least
# len(x_lower) rows, and it mixes three different points into one row --
# confirm that this summation is what the kernel is meant to contain.
for j in range(len(x_lower)):
    model.zero_grad()
    pred_left[j].backward(retain_graph=True)
    pred_lower[j].backward(retain_graph=True)
    pred_upper[j].backward(retain_graph=True)
    J3[j, :] = _flat_param_grads(params)

# NOTE(review): the parameters' .grad still holds the last J3 row here; make
# sure the training step zeroes gradients before its own backward pass.