Comments (14)
This is set following the original paper.
Theoretically, most methods can do that, but for GradDrop it can only be True.
LibMTL/LibMTL/weighting/GradDrop.py
Line 40 in 8ccfc9a
The dimensions here need to be aligned.
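For context on why GradDrop is restricted to rep_grad=True: its sign-dropout step acts elementwise on the per-task gradients of the shared representation, so those representation gradients, one per task with matching dimensions, must be available. A rough sketch of the idea with made-up values, not LibMTL's exact implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Per-task gradients w.r.t. a shared representation (toy values).
grads = np.array([[ 0.5, -1.0, 0.2],
                  [-0.5,  2.0, 0.1]])

# Sign purity P: near 1 where tasks agree on a positive sign, near 0 where
# they agree on a negative sign, 0.5 where they fully conflict.
P = 0.5 * (1 + grads.sum(axis=0) / (np.abs(grads).sum(axis=0) + 1e-8))

U = rng.random(P.shape)  # one random draw per representation element
# Keep positive entries where U < P and negative entries where U >= P,
# so every surviving entry in a column shares the same sign.
masks = ((U < P) & (grads > 0)) | ((U >= P) & (grads < 0))
new_grads = grads * masks
```

Since the sign statistics are computed per element of the representation, the per-task gradient matrices must share that representation's shape, which is the dimension alignment mentioned above.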
Sorry, I'm new to MTL loss optimization. From your conversation, it sounds like PCGrad, GradVac, CAGrad, and similar methods can all actually use rep_grad=True, right?
@TangJiakai It can be implemented that way, but whether doing so is reasonable, and how well it works, are both unknown.
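To make the rep_grad distinction concrete, here is a minimal sketch (the toy encoder and heads are made up for illustration): rep_grad=False differentiates each task loss w.r.t. the shared parameters, while rep_grad=True differentiates w.r.t. the shared representation tensor itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(4, 8)                   # shared module (illustrative)
heads = [nn.Linear(8, 1), nn.Linear(8, 1)]  # two task heads (illustrative)
x = torch.randn(3, 4)

rep = encoder(x)  # shared representation
losses = [head(rep).pow(2).mean() for head in heads]

# rep_grad=False: per-task gradients w.r.t. the shared *parameters*.
param_grads = torch.autograd.grad(losses[0], list(encoder.parameters()),
                                  retain_graph=True)

# rep_grad=True: per-task gradients w.r.t. the shared *representation*.
rep_grad = torch.autograd.grad(losses[1], rep, retain_graph=True)[0]
```

The parameter gradients have the shapes of the encoder's weights, while the representation gradient has the shape of rep itself; gradient-surgery methods can in principle operate on either, which is why the choice is a per-method design decision rather than a hard rule.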
OK, good. I want to try CAGrad, and what you said reassures me. I'll try it and see how it works. Thank you!
I tried it on my task. With five losses, after the CAGrad minimize step the weights are all 0.2, so it seems nothing was really optimized. I work on graph tasks, and the only shared parameters are the Embeddings.
@TangJiakai Are you using CAGrad with rep_grad=True?
@Baijiong-Lin I didn't adopt your framework wholesale. I only added the backward part of CAGrad from LibMTL to my own training code, i.e., given multiple losses, how the parameters' gradients are modified. You can think of it as my rep_grad=True (even though your framework does not allow that for CAGrad).
By the way:
- I treat the L2 regularization loss on the parameters as another loss (i.e., a task). Is that reasonable?
- train_loss1: 50.8130, train_loss2: 143.3860, train_loss3: 2705.6676, train_loss4: 2707.7221, train_loss5: 29.0251
Those are the raw values of my losses. After CAGrad optimization the weights are all 0.2, i.e., still uniform, so CAGrad did not have the expected effect. I visualized how the weights change over training.
See, all five tasks stay at 0.2.
@TangJiakai First, I cannot verify whether your implementation is correct. Second, if CAGrad does not have the expected effect, you could ask CAGrad's authors. The L2 regularization term is not a multi-task learning problem; it is an auxiliary learning problem.
@Baijiong-Lin OK, thank you! The implementation should be fine; the model part is copied from yours, and I changed very little.
```python
# Methods of the weighting class; requires torch, numpy, and scipy.
import torch
import numpy as np
from scipy.optimize import minimize


def get_share_params(self):
    # All model parameters are treated as shared here.
    return self.model.parameters()

def zero_grad_share_params(self):
    self.optimizer.zero_grad()

def _compute_grad_dim(self):
    # Record the number of elements of each shared parameter.
    self.grad_index = []
    for param in self.get_share_params():
        self.grad_index.append(param.data.numel())
    self.grad_dim = sum(self.grad_index)

def _grad2vec(self):
    # Flatten the current .grad of all shared parameters into one vector.
    grad = torch.zeros(self.grad_dim)
    count = 0
    for param in self.get_share_params():
        if param.grad is not None:
            beg = 0 if count == 0 else sum(self.grad_index[:count])
            end = sum(self.grad_index[:(count+1)])
            grad[beg:end] = param.grad.data.view(-1)
        count += 1
    return grad

def _compute_grad(self, losses, mode):
    '''
    mode: backward, autograd
    '''
    grads = torch.zeros(self.task_num, self.grad_dim).to(self.device)
    for tn in range(self.task_num):
        if mode == 'backward':
            losses[tn].backward(retain_graph=True) if (tn+1) != self.task_num else losses[tn].backward()
            grads[tn] = self._grad2vec()
        elif mode == 'autograd':
            grad = list(torch.autograd.grad(losses[tn], self.get_share_params(), retain_graph=True))
            grads[tn] = torch.cat([g.view(-1) for g in grad])
        else:
            raise ValueError('No support {} mode for gradient computation'.format(mode))
        self.zero_grad_share_params()
    return grads

def _reset_grad(self, new_grads):
    # Write the combined gradient vector back into each parameter's .grad.
    count = 0
    for param in self.get_share_params():
        if param.grad is not None:
            beg = 0 if count == 0 else sum(self.grad_index[:count])
            end = sum(self.grad_index[:(count+1)])
            param.grad.data = new_grads[beg:end].contiguous().view(param.data.size()).data.clone()
        count += 1

def _backward(self, losses):
    self.task_num = len(losses)
    calpha, rescale = self.calpha, self.rescale
    self._compute_grad_dim()
    grads = self._compute_grad(losses, mode='backward')
    GG = torch.matmul(grads, grads.t()).cpu()  # [num_tasks, num_tasks]
    g0_norm = (GG.mean()+1e-8).sqrt()  # norm of the average gradient
    x_start = np.ones(self.task_num) / self.task_num
    bnds = tuple((0, 1) for x in x_start)
    cons = ({'type': 'eq', 'fun': lambda x: 1-sum(x)})
    A = GG.numpy()
    b = x_start.copy()
    c = (calpha * g0_norm + 1e-8).item()
    def objfn(x):
        return (x.reshape(1, -1).dot(A).dot(b.reshape(-1, 1)) + c*np.sqrt(x.reshape(1, -1).dot(A).dot(x.reshape(-1, 1))+1e-8)).sum()
    res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
    w_cpu = res.x
    ww = torch.Tensor(w_cpu).to(self.device)
    gw = (grads * ww.view(-1, 1)).sum(0)
    gw_norm = gw.norm()
    lmbda = c / (gw_norm+1e-8)
    g = grads.mean(0) + lmbda * gw
    if rescale == 0:
        new_grads = g
    elif rescale == 1:
        new_grads = g / (1+calpha**2)
    elif rescale == 2:
        new_grads = g / (1+calpha)
    else:
        raise ValueError('No support rescale type {}'.format(rescale))
    self._reset_grad(new_grads)
    return w_cpu
```
I should only need to pass the tasks' losses to _backward(). I'll double-check, and if that doesn't work I'll ask the CAGrad authors. Thanks!
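As a sanity check in that direction, the weight solve inside _backward can be run in isolation on synthetic gradients; a minimal sketch (all values and names are illustrative). Note that when the tasks are symmetric in the Gram matrix, uniform weights are exactly the optimum, which is one benign way to end up with all weights equal:

```python
import numpy as np
from scipy.optimize import minimize

# Two synthetic task gradients (rows), standing in for the grads matrix above.
grads = np.array([[1.0, 0.0,  0.5],
                  [0.0, 1.0, -0.5]])
calpha = 0.5
task_num = grads.shape[0]

GG = grads @ grads.T                 # Gram matrix [num_tasks, num_tasks]
g0_norm = np.sqrt(GG.mean() + 1e-8)  # norm proxy of the average gradient
x_start = np.ones(task_num) / task_num
c = calpha * g0_norm + 1e-8

def objfn(x):
    # Same objective as in _backward: <x, GG x_start> + c * sqrt(<x, GG x>)
    return x @ GG @ x_start + c * np.sqrt(x @ GG @ x + 1e-8)

res = minimize(objfn, x_start,
               bounds=[(0, 1)] * task_num,
               constraints={'type': 'eq', 'fun': lambda x: 1 - x.sum()})
w = res.x      # per-task weights on the simplex
gw = w @ grads
g = grads.mean(0) + (c / (np.linalg.norm(gw) + 1e-8)) * gw  # combined update
```

Printing w for real per-task gradient matrices from your run would show whether the solver is genuinely returning uniform weights or whether the gradients themselves are near-symmetric.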