Implementation of loss function about cdan · OPEN · 10 comments

thuml commented on August 22, 2024
Implementation of loss function

Comments (10)

caozhangjie commented on August 22, 2024

We use a trick to reverse the gradient before it back-propagates from the discriminator to the feature extractor, so we do not need to multiply the discriminator loss by -1 to train the feature extractor.

There is no input from the discriminator to the classification loss. By PyTorch's autograd principle, no gradient therefore flows from the classification loss to the discriminator, even when we back-propagate from the sum of the classification loss and the transfer loss.
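For reference, a gradient-reversal hook of this kind takes only a few lines. Below is a minimal sketch of the idea (the name grl_hook matches the thread; the exact body is an assumption, not the repository's verbatim code):

    import torch

    def grl_hook(coeff):
        # The hook receives the gradient flowing back into the tensor and
        # returns its replacement: the same gradient scaled by -coeff.
        def hook(grad):
            return -coeff * grad.clone()
        return hook

    # Tiny demonstration: the gradient of y.sum() w.r.t. x is all ones,
    # but the hook flips it to all minus ones before it reaches x.
    x = torch.ones(3, requires_grad=True)
    y = x * 1.0                    # new tensor, so the hook stays branch-local
    y.register_hook(grl_hook(1.0))
    y.sum().backward()
    print(x.grad)                  # tensor([-1., -1., -1.])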

CN-BiGLiu commented on August 22, 2024

Thanks for your answer. The trick is x.register_hook(grl_hook(coeff)), is that right?

caozhangjie commented on August 22, 2024

Yes

MeLonJ10 commented on August 22, 2024

What about the TensorFlow version? Where is the gradient-reversal trick there?
Thanks a lot!

sy565612345 commented on August 22, 2024

> What about the TensorFlow version? Where is the gradient-reversal trick there?
> Thanks a lot!

The TensorFlow version is under implementation.
The gradient-reversal trick is at pytorch/network.py line 388. The grl_hook inserts a GRL (gradient reversal layer) between the ResNet feature extractor and the domain discriminator, which lets the two adversarial players be updated in a single forward and backward pass.
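To make the placement concrete, here is a minimal sketch of a discriminator whose forward pass attaches the hook, with a DANN-style coefficient schedule (the class layout, layer sizes, and schedule constants are assumptions for illustration, not the repository's exact code):

    import math
    import torch
    import torch.nn as nn

    def grl_hook(coeff):
        return lambda grad: -coeff * grad.clone()

    def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
        # Smoothly ramps the reversal coefficient from low to high over
        # training, as in the usual DANN/CDAN annealing schedule.
        return 2.0 * (high - low) / (1.0 + math.exp(-alpha * iter_num / max_iter)) - (high - low) + low

    class AdversarialNetwork(nn.Module):
        def __init__(self, in_feature, hidden_size=1024):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(in_feature, hidden_size), nn.ReLU(), nn.Dropout(0.5),
                nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.5),
                nn.Linear(hidden_size, 1), nn.Sigmoid())
            self.iter_num = 0

        def forward(self, x):
            # Assumes x requires grad (training); register_hook fails otherwise.
            if self.training:
                self.iter_num += 1
            coeff = calc_coeff(self.iter_num)
            x = x * 1.0                        # new tensor, keeps the hook off the original
            x.register_hook(grl_hook(coeff))   # gradients leaving the discriminator are reversed
            return self.layers(x)

A single loss.backward() then updates both players: the discriminator descends its own gradient, while the reversed gradient trains the feature extractor to confuse it.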

sxwawa commented on August 22, 2024

In the DANN model, after inserting a GRL between the generator and the discriminator, the gradient of the domain loss w.r.t. the feature extractor F is multiplied by -1. But in the CDAN model, the input to the discriminator is the tensor product of the feature vector and the predicted probability vector, so during back-propagation the domain loss has gradients with respect to both the feature extractor F and the classifier G. May I know how your algorithm computes the gradient of the domain loss w.r.t. the predicted probabilities output by classifier G? Will the grl_hook also reverse the gradient of the domain loss w.r.t. the classifier G? Thanks a lot!

sy565612345 commented on August 22, 2024

> In the DANN model, after inserting a GRL between the generator and the discriminator, the gradient of the domain loss w.r.t. the feature extractor F is multiplied by -1. But in the CDAN model, the input to the discriminator is the tensor product of the feature vector and the predicted probability vector, so during back-propagation the domain loss has gradients with respect to both the feature extractor F and the classifier G. May I know how your algorithm computes the gradient of the domain loss w.r.t. the predicted probabilities output by classifier G? Will the grl_hook also reverse the gradient of the domain loss w.r.t. the classifier G? Thanks a lot!

In pytorch/loss.py line 22: softmax_output = input_list[1].detach()
This detaches G from the domain loss during back-propagation, so the domain loss is not used to update classifier G.
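Concretely, the conditioning tensor is an outer product built from the detached predictions, so the domain loss reaches F through the features but not G through the predictions. A minimal sketch (the function name and argument layout are assumptions based on the thread's references to loss.py):

    import torch
    import torch.nn as nn

    def cdan_loss(feature, softmax_output, ad_net):
        # Detach the classifier predictions: gradients from the domain loss
        # flow back through `feature` (into F) but not into classifier G.
        softmax_output = softmax_output.detach()
        # Multilinear map: outer product of predictions and features,
        # flattened into the discriminator's input.
        op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
        ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
        # First half of the batch assumed source (label 1), second half target (label 0).
        half = softmax_output.size(0) // 2
        dc_target = torch.cat((torch.ones(half, 1), torch.zeros(half, 1))).to(ad_out.device)
        return nn.BCELoss()(ad_out, dc_target)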

xyqfountain commented on August 22, 2024

I cannot understand two things; I would appreciate it if you could explain. (1) pytorch/loss.py line 33: entropy.register_hook(grl_hook(coeff)). Why does the entropy need this *-1 hook? The gradients passed back from the domain discriminator to the feature extractor have already been reversed by x.register_hook(grl_hook(coeff)), so registering a *-1 hook on the entropy confuses me. (2) I noticed that you use softmax_output = input_list[1].detach(), which blocks the gradients from the discriminator to the classifier, but the entropy is obtained by loss_func.Entropy(softmax_output), resulting in entropy.requires_grad=True. This means gradients can be back-propagated to the classifier through the entropy (am I right?). What is this for?
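For context on (1): in the CDAN+E variant, the entropy H(g) of each prediction weights that example's domain loss by w = 1 + e^(-H(g)), so confident examples count more. Since the weights depend on the classifier output, one reading of the *-1 hook on the entropy is that it makes the gradient flowing back through these weights adversarial as well. A minimal sketch of such a weighting (the helper names follow the thread; the wiring is an assumption, not the repository's verbatim code):

    import torch

    def grl_hook(coeff):
        return lambda grad: -coeff * grad.clone()

    def entropy_weight(softmax_out, coeff=1.0):
        # Shannon entropy of each prediction row.
        epsilon = 1e-5
        entropy = -torch.sum(softmax_out * torch.log(softmax_out + epsilon), dim=1)
        # Reverse gradients that flow back through the entropy, so the
        # weighting also takes part in the minimax game (only possible
        # when softmax_out was not detached).
        if entropy.requires_grad:
            entropy.register_hook(grl_hook(coeff))
        # CDAN+E weight: w = 1 + e^(-H(g)).
        return 1.0 + torch.exp(-entropy)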

buerzlh commented on August 22, 2024

> I cannot understand two things; I would appreciate it if you could explain. (1) pytorch/loss.py line 33: entropy.register_hook(grl_hook(coeff)). Why does the entropy need this *-1 hook? The gradients passed back from the domain discriminator to the feature extractor have already been reversed by x.register_hook(grl_hook(coeff)), so registering a *-1 hook on the entropy confuses me. (2) I noticed that you use softmax_output = input_list[1].detach(), which blocks the gradients from the discriminator to the classifier, but the entropy is obtained by loss_func.Entropy(softmax_output), resulting in entropy.requires_grad=True. This means gradients can be back-propagated to the classifier through the entropy (am I right?). What is this for?

I am also puzzled by question (1). Do you understand it now?

buerzlh commented on August 22, 2024

> In the DANN model, after inserting a GRL between the generator and the discriminator, the gradient of the domain loss w.r.t. the feature extractor F is multiplied by -1. But in the CDAN model, the input to the discriminator is the tensor product of the feature vector and the predicted probability vector, so during back-propagation the domain loss has gradients with respect to both the feature extractor F and the classifier G. May I know how your algorithm computes the gradient of the domain loss w.r.t. the predicted probabilities output by classifier G? Will the grl_hook also reverse the gradient of the domain loss w.r.t. the classifier G? Thanks a lot!

> In pytorch/loss.py line 22: softmax_output = input_list[1].detach()
> This detaches G from the domain loss during back-propagation, so the domain loss is not used to update classifier G.

But generally speaking, the domain loss still needs to optimize the feature extraction network, doesn't it?
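A toy check (not from the repository) shows exactly what the detach() blocks: the domain loss still produces gradients for the features, i.e. it does keep optimizing the feature extractor F; only the prediction branch into classifier G is cut off:

    import torch

    feature = torch.randn(4, 8, requires_grad=True)   # stands in for F's output
    logits = torch.randn(4, 3, requires_grad=True)    # stands in for G's output
    softmax_output = torch.softmax(logits, dim=1).detach()
    op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
    op_out.sum().backward()                           # stand-in for the domain loss
    print(feature.grad is not None)  # True: F is still trained by the domain loss
    print(logits.grad)               # None: G receives no gradient from it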
