huanghoujing commented on July 2, 2024

Hi, I'm not the author of the paper...

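Pay attention to the relative weighting of the classification loss and the KLD loss. In my experiments the ratio was 1:1. Also, I did not average the KLD over the feature dimension; I only averaged over samples.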
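
A minimal PyTorch sketch of the 1:1 weighting and per-sample averaging described above. The function name `mutual_classification_loss` and the `kl_weight` argument are hypothetical, not taken from the repository. `reduction="batchmean"` is the current PyTorch option that sums the KL terms and divides by the batch size only, so the KL is not averaged over the feature (class) dimension.

```python
import torch
import torch.nn.functional as F


def mutual_classification_loss(logits, peer_logits, labels, kl_weight=1.0):
    """Cross-entropy plus KL(peer || self), weighted 1:1 by default (hypothetical helper)."""
    ce = F.cross_entropy(logits, labels)
    # The peer's probabilities act as a fixed target: no gradient flows into the peer.
    target = F.softmax(peer_logits.detach(), dim=1)
    # Sum over classes, divide by batch size only.
    kl = F.kl_div(F.log_softmax(logits, dim=1), target, reduction="batchmean")
    return ce + kl_weight * kl, ce, kl


torch.manual_seed(0)
logits = torch.randn(8, 10, requires_grad=True)
peer = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
total, ce, kl = mutual_classification_loss(logits, peer, labels)
total.backward()
```

Keeping `kl_weight` as an explicit argument makes it easy to try ratios other than 1:1 if the KL term turns out to be too weak or too strong.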

from alignedreid-re-production-pytorch.

lawpdas commented on July 2, 2024

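What do you mean by averaging over the feature dimension? Is that the size_average parameter? Does it make a big difference? I don't quite understand.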

huanghoujing commented on July 2, 2024

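Yes, it's the size_average parameter. The probabilities computed for one batch have shape [N, D], where D is the feature dimension (for classification probabilities, the number of classes). With size_average=True, the loss is averaged over both the feature and the sample dimensions, i.e. divided by N*D, which makes the KLD loss weight too small to have any effect. What I do is set size_average=False, so the loss is [N, D]; I sum it over the D dimension and then divide by N.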
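
To make the scale difference concrete, the sketch below compares the reductions in current PyTorch, where `size_average` is deprecated in favor of `reduction`. `reduction="mean"` divides by N*D, while `reduction="batchmean"` divides by N only, which matches the sum-over-D-then-divide-by-N scheme. N = 4 and D = 751 are arbitrary illustrative values. PyTorch itself emits a warning when `kl_div` is used with `reduction="mean"`, because it divides by both the batch size and the number of classes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 4, 751  # hypothetical batch size and number of classes
log_p1 = F.log_softmax(torch.randn(N, D), dim=1)
p2 = F.softmax(torch.randn(N, D), dim=1)

# Element-wise KL terms, shape [N, D].
elementwise = F.kl_div(log_p1, p2, reduction="none")
# Sum over D, then average over N: the scheme described above.
per_sample = elementwise.sum(dim=1).mean()

# "mean" divides the total by N * D; "batchmean" divides it by N only.
mean_loss = F.kl_div(log_p1, p2, reduction="mean")
batchmean_loss = F.kl_div(log_p1, p2, reduction="batchmean")

print(elementwise.shape)                              # torch.Size([4, 751])
print(torch.allclose(per_sample, batchmean_loss))     # True
print(torch.allclose(mean_loss * D, batchmean_loss))  # True
```

With D in the hundreds, the "mean" version is hundreds of times smaller than the per-sample KL, which is why it barely contributes next to the cross-entropy term.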

lawpdas commented on July 2, 2024

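I found that after changing it to False (size_average=False) the loss hardly decreases, and changing the loss weight doesn't help either. Also, I'm not very familiar with PyTorch and don't fully understand your code. Is there anything wrong with the way I wrote it?

    import torch.nn.functional as F

    # Feat1/FC1, Feat2/FC2, optimizer1/optimizer2, inputs and labels are
    # defined elsewhere in the training loop.

    # compute p1 and p2
    logits1 = FC1(Feat1(inputs).view(-1, 256))
    logits2 = FC2(Feat2(inputs).view(-1, 256))

    # update net1: cross-entropy + KL(p2 || p1), with p2 as a fixed target
    l1 = F.cross_entropy(logits1, labels)
    # reduction="sum" is the current spelling of size_average=False;
    # dividing by N averages over samples only
    kl1 = F.kl_div(F.log_softmax(logits1, dim=1),
                   F.softmax(logits2.detach(), dim=1),
                   reduction="sum") / logits1.shape[0]
    loss1 = l1 + kl1

    optimizer1.zero_grad()
    loss1.backward()
    optimizer1.step()

    # recompute p1 with the updated net1
    logits1 = FC1(Feat1(inputs).view(-1, 256))

    # update net2: cross-entropy + KL(p1 || p2), with p1 as a fixed target
    l2 = F.cross_entropy(logits2, labels)
    kl2 = F.kl_div(F.log_softmax(logits2, dim=1),
                   F.softmax(logits1.detach(), dim=1),
                   reduction="sum") / logits2.shape[0]
    loss2 = l2 + kl2

    optimizer2.zero_grad()
    loss2.backward()
    optimizer2.step()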
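
For anyone wanting to test the two-step update above in isolation, here is a self-contained toy version. The small `nn.Sequential` networks, the helpers `kl_to_peer` and `dml_step`, the random data, and the SGD settings are all hypothetical stand-ins for `Feat1`/`FC1`, `Feat2`/`FC2`, and the real optimizers. `reduction="batchmean"` gives the same value as `reduction="sum"` divided by N, and the recomputation of net1's output is wrapped in `torch.no_grad()` because only its detached probabilities are used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def kl_to_peer(logits, peer_logits):
    # KL(peer || self): summed over classes, averaged over the batch.
    return F.kl_div(F.log_softmax(logits, dim=1),
                    F.softmax(peer_logits.detach(), dim=1),
                    reduction="batchmean")


def dml_step(net1, net2, opt1, opt2, inputs, labels):
    """One alternating mutual-learning step: update net1, then net2."""
    logits1, logits2 = net1(inputs), net2(inputs)

    loss1 = F.cross_entropy(logits1, labels) + kl_to_peer(logits1, logits2)
    opt1.zero_grad()
    loss1.backward()
    opt1.step()

    # Recompute net1's prediction after its update; no graph is needed.
    with torch.no_grad():
        logits1 = net1(inputs)

    loss2 = F.cross_entropy(logits2, labels) + kl_to_peer(logits2, logits1)
    opt2.zero_grad()
    loss2.backward()
    opt2.step()
    return loss1.item(), loss2.item()


torch.manual_seed(0)
# Hypothetical toy networks with deliberately different architectures.
net1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
net2 = nn.Sequential(nn.Linear(32, 16), nn.Tanh(), nn.Linear(16, 10))
opt1 = torch.optim.SGD(net1.parameters(), lr=0.05, momentum=0.9)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.05, momentum=0.9)

inputs = torch.randn(64, 32)
labels = torch.randint(0, 10, (64,))
history = [dml_step(net1, net2, opt1, opt2, inputs, labels) for _ in range(50)]
```

Reusing `logits2` from before net1's update is valid: net2's parameters do not change until `opt2.step()`, and net1's loss only sees `logits2.detach()`, so net2's graph is still intact when `loss2.backward()` runs.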

huanghoujing commented on July 2, 2024

Hi, looking at this part alone I don't see any problem, so I can't tell what's causing it either.

huanghoujing commented on July 2, 2024

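Hi, I previously tested the classification-probability mutual loss on its own with SGD and a staircase learning-rate schedule, and it did improve results. I haven't tried other optimizers or learning-rate schedules.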
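
A sketch of an SGD plus staircase (step-wise) learning-rate setup in PyTorch, using `torch.optim.lr_scheduler.MultiStepLR`. The model, learning rate, momentum, weight decay, milestones, and decay factor are hypothetical placeholders; the comment above does not give the exact values used.

```python
import torch
import torch.nn as nn

# Hypothetical model and hyperparameters, for illustration only.
model = nn.Linear(128, 751)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Staircase decay: multiply the learning rate by 0.1 at epochs 40 and 60.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40, 60], gamma=0.1)

lrs = []
for epoch in range(70):
    # ... one epoch of training would go here ...
    optimizer.step()  # placeholder step so the scheduler follows an optimizer step
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()
```

`torch.optim.lr_scheduler.StepLR` gives the same staircase shape with a fixed interval between drops instead of explicit milestones.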

lawpdas commented on July 2, 2024

Thanks. I later computed the KL following the DML code, and now it converges normally. I also see an improvement on CIFAR-100.

    # pred1, pred2: softmax probabilities of net1 and net2, shape [N, D] (assumed)
    kl1 = torch.sum(pred2.detach() * torch.log(1e-8 + pred2.detach() / (pred1 + 1e-8)), 1).mean()

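This is also effectively size_average=False (sum over classes, mean over samples), so it shouldn't differ from my earlier code, but I don't know why the loss simply wouldn't decrease before.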
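
A quick numerical check that the DML-style expression and `F.kl_div` with summation over classes divided by N compute the same quantity, KL(p2 || p1) averaged over samples. The random logits and the 16 x 100 shape (100 classes, as in CIFAR-100) are illustrative; `pred1` and `pred2` are assumed to be softmax outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits1 = torch.randn(16, 100)
logits2 = torch.randn(16, 100)
pred1 = F.softmax(logits1, dim=1)
pred2 = F.softmax(logits2, dim=1)

# DML-style KL(p2 || p1): sum over classes, mean over samples.
kl_dml = torch.sum(pred2 * torch.log(1e-8 + pred2 / (pred1 + 1e-8)), 1).mean()

# The same quantity via F.kl_div on log-probabilities, summed and divided by N.
kl_builtin = F.kl_div(F.log_softmax(logits1, dim=1), pred2, reduction="sum") / logits1.shape[0]

print(torch.allclose(kl_dml, kl_builtin, atol=1e-4))  # True
```

The epsilons only matter for probabilities close to zero; at this scale the two values agree closely, consistent with the expectation above that the two versions should not differ mathematically.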

michuanhaohao commented on July 2, 2024

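@lawpdas Hi, I'm the first author of the paper. For the mutual-learning branch, the original paper used ResNet50 and ResNet-Xception50. Mutual learning is somewhat like model ensembling: the two networks need to be as different as possible, so two ResNet50s both initialized from ImageNet may not give much benefit.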

lawpdas commented on July 2, 2024

Thanks, that's useful information.

zxy14120448 commented on July 2, 2024

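> Hello, I followed Deep Mutual Learning to reproduce mutual learning for classification, but it never had any effect. I happened to see this experiment in your paper and wanted to ask: is there anything to watch out for when doing mutual learning for classification alone?

Hi, how did you design your network? Could you share a GitHub link?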

ZQSIAT commented on July 2, 2024

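> Hello, I followed Deep Mutual Learning to reproduce mutual learning for classification, but it never had any effect. I happened to see this experiment in your paper and wanted to ask: is there anything to watch out for when doing mutual learning for classification alone?

Could you share your reproduction code? Thanks! jack.zhaoqingsong[at]gmail.com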
