pretrain the discriminator · nmt_gan (15 comments, closed)

wangyirui commented on May 27, 2024
pretrain the discriminator

from nmt_gan.

Comments (15)

ZhenYangIACAS commented on May 27, 2024

Yes, we dropped the validation step once we were familiar with the training process, since it wastes some time. The 0.82 accuracy is measured on the development set, which contains examples selected randomly from the training set.

wangyirui commented on May 27, 2024

Thanks! So the dev set is a (randomly selected) subset of the training data, NOT an unseen reserved set?

ZhenYangIACAS commented on May 27, 2024

Maybe I did not explain it clearly. By "training data" I mean the whole set available for each training run (it contains both positive and negative data). The whole set is divided into two parts: an evaluation set (200 positive samples and 200 negative samples) selected randomly from the whole set, and the rest, which is used for training. Hence, the dev set is an unseen reserved set.
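For concreteness, the split described above can be sketched as follows (the function name and the fixed seed are my own, not from the repo):

```python
import random

def split_whole_set(positive, negative, n_eval=200, seed=0):
    """Reserve n_eval positive and n_eval negative samples as a
    held-out dev set; everything else is used for training."""
    rng = random.Random(seed)
    pos, neg = positive[:], negative[:]
    rng.shuffle(pos)
    rng.shuffle(neg)
    dev = pos[:n_eval] + neg[:n_eval]
    train = pos[n_eval:] + neg[n_eval:]
    return train, dev
```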

wangyirui commented on May 27, 2024

OK. It's clear now. Thanks!

wangyirui commented on May 27, 2024

I re-implemented your discriminator in PyTorch and constructed the following Conv2d layers:

# One conv branch per kernel size: each kernel spans the full embedding
# width, so its output width is 1, and max-pooling over the remaining
# time steps leaves one feature per filter.
self.conv2d = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels=1,
                          out_channels=num_filters[i],
                          kernel_size=(kernel_sizes[i], args.decoder_embed_dim),
                          stride=1,
                          padding=0),
                nn.BatchNorm2d(num_filters[i]),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(self.max_len - kernel_sizes[i] + 1, 1))
            )
            for i in range(len(self.kernel_sizes))
        ])

I think this is similar to your discriminator CNN: it uses varying numbers of kernels with different window sizes to extract different features, concatenates them along the out_channel dimension, and then passes the result through a Highway Network and an FC layer to yield the scores.
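A minimal sketch of that combine step (filter counts and batch size are hypothetical; the Highway layer is the standard gated formulation, not necessarily the exact one in the paper):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Highway(nn.Module):
    # Standard highway layer: y = g * relu(W1 x) + (1 - g) * x,
    # where g = sigmoid(W2 x) is a learned transform gate.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, dim)

    def forward(self, x):
        g = torch.sigmoid(self.gate(x))
        return g * F.relu(self.proj(x)) + (1 - g) * x

# Hypothetical sizes, for illustration only.
num_filters = [64, 128]
feat_dim = sum(num_filters)
highway = Highway(feat_dim)
fc = nn.Linear(feat_dim, 2)  # binary real/fake scores

# Each pooled conv branch outputs (batch, n_filters, 1, 1).
pooled = [torch.randn(4, n, 1, 1) for n in num_filters]
# Drop the singleton spatial dims and concat on the channel axis.
features = torch.cat([p.squeeze(3).squeeze(2) for p in pooled], dim=1)
scores = fc(highway(features))  # (4, 2)
```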

However, my discriminator always predicts at chance (around 0.5 accuracy even after many epochs). It seems you only train the discriminator for 2 epochs in your code...
Are there any "tricks" I should know about?
My pre-trained generator (Luong's architecture) achieved about 28 BLEU on IWSLT14 (DE-EN).
Do you force your generator to translate within 50 words? Currently, I force my generator to finish the translation within 50 words. Do you think I should remove this constraint?
Thanks!!!

ZhenYangIACAS commented on May 27, 2024

@wangyirui Have you checked your input to the discriminator? In our experiments, the number of training epochs for the discriminator depends on the size of the training data. Anyway, it is not difficult to pre-train the discriminator to give right predictions.
We did not test our model on IWSLT14 (DE-EN), but to our knowledge, 28 BLEU points on the IWSLT test set seems too low. Have you compared with other implementations? In our experiments, we did remove training examples longer than 50 words.

wangyirui commented on May 27, 2024

Yes. I compared my implementation with Facebook's fairseq-py system (which yields 28.14). And I also removed the sentences that exceed 50 words. And in each batch, I sampled both positive and negative samples (each batch contains both classes at the same time). It seems your implementation also does that. I don't know why I cannot pre-train the discriminator... If you could release your dataset, I think it would be helpful.
BTW, I would like to confirm that the "image" we send to conv2d has the shape (seq_len as height, embed_size as width, 1 as in_channel), right?

ZhenYangIACAS commented on May 27, 2024

Yes, you are right about the input image. I do not think your problem is closely related to the training data, and I am sorry that I cannot release ours. My suggestion is that you recheck your input to the discriminator: are the sentences padded to the same length? Are the tags for the positive and negative samples right? ...

wangyirui commented on May 27, 2024

OK. Thanks for your suggestions, I will check that. In addition, is the embedding layer in the discriminator trained from scratch, not loaded from the pre-trained generator?

ZhenYangIACAS commented on May 27, 2024

@wangyirui The embedding layer in the discriminator is trained from scratch.

wangyirui commented on May 27, 2024

Thanks! I think I might have figured out the problem... When training the discriminator, a positive sample is the human translation of sentence A, and a negative sample is the machine translation of the same sentence A, right? I made a stupid mistake there...
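A hypothetical helper (my own naming, not from the repo) illustrating that labeling scheme: for each source sentence, the human reference is the positive example and the generator's output for the same source is the negative example.

```python
def make_discriminator_examples(sources, references, generated):
    """Pair each source with its human reference (label 1) and its
    machine translation (label 0)."""
    examples = []
    for src, ref, hyp in zip(sources, references, generated):
        examples.append((src, ref, 1))  # human translation -> positive
        examples.append((src, hyp, 0))  # machine translation -> negative
    return examples
```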

ZhenYangIACAS commented on May 27, 2024

Yes, you are right!

kellymarchisio commented on May 27, 2024

@wangyirui - was the issue simply reversing the positive/negative data? I have the same issue of the discriminator predicting at chance, even after many epochs. Wondering how you were able to make training succeed, as I wouldn't expect simply switching positive/negative data to do so. (I've tried the same, with no success).

For reference, my positive data is ~65K lines from my target training data, and my negative data is the corresponding negative predictions from running generate_sample.sh.

wangyirui commented on May 27, 2024

@kellymarchisio It turns out that my translations had some problems, which caused the discriminator to predict almost 1 for all sentences. Carefully checking your input may be helpful.

xixiddd commented on May 27, 2024

@wangyirui
Would you like to tell me what the exact problems with your translations were? I have the same problem as you: I cannot pre-train the discriminator, which always gets about 50% accuracy on the development set.
