Comments (15)
Yes, we skipped the validation step once we were familiar with the training process, since it takes extra time. The 0.82 accuracy is measured on the development set, which contains examples randomly selected from the training set.
from nmt_gan.
Thanks! So the dev set is a (randomly selected) subset of the training data, NOT an unseen reserved set?
Maybe I did not explain it clearly. By training data I mean the whole set available for each training run (it contains both the positive and negative data). This whole set is divided into two parts: the evaluation set (200 positive samples and 200 negative samples), which is randomly selected from the whole set, and the rest, which is used for training. Hence, the dev set is an unseen reserved set.
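For concreteness, the split described above can be sketched as follows (the helper name and seed handling are illustrative, not the repo's code):

```python
import random

def split_dev_set(positive, negative, n_dev=200, seed=0):
    # Hold out n_dev random samples per class as the dev set;
    # the remainder of both classes is used for training.
    rng = random.Random(seed)
    pos, neg = list(positive), list(negative)
    rng.shuffle(pos)
    rng.shuffle(neg)
    dev = pos[:n_dev] + neg[:n_dev]
    train = pos[n_dev:] + neg[n_dev:]
    return train, dev
```

Because the dev samples are removed before training, they are never seen by the model during training.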
OK. It's clear now. Thanks!
I re-implemented your discriminator in PyTorch, constructing the following Conv2d layers:
self.conv2d = nn.ModuleList([
    nn.Sequential(
        nn.Conv2d(in_channels=1,
                  out_channels=num_filters[i],
                  kernel_size=(kernel_sizes[i], args.decoder_embed_dim),
                  stride=1,
                  padding=0),
        nn.BatchNorm2d(num_filters[i]),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=(self.max_len - kernel_sizes[i] + 1, 1))
    )
    for i in range(len(kernel_sizes))
])
I think this is similar to your discriminator CNN: it uses several sets of kernels with different window sizes to extract different features, concatenates them along the channel dimension, and then passes the result through a Highway Network and a fully connected layer to yield the scores.
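The concat-plus-highway-plus-FC step might look roughly like this (a sketch with a single highway layer; the class name, sizes, and activation choices are assumptions, not the repo's exact code):

```python
import torch
import torch.nn as nn

class HighwayFC(nn.Module):
    # Concatenate the pooled CNN feature maps, pass them through one
    # highway layer, then a linear classifier. A sketch only; the
    # single-layer highway and layer sizes are assumptions.
    def __init__(self, total_filters, num_classes=2):
        super().__init__()
        self.transform = nn.Linear(total_filters, total_filters)
        self.gate = nn.Linear(total_filters, total_filters)
        self.fc = nn.Linear(total_filters, num_classes)

    def forward(self, pooled):
        # pooled: list of (batch, num_filters_i, 1, 1) tensors from max-pooling
        h = torch.cat([p.flatten(1) for p in pooled], dim=1)
        g = torch.sigmoid(self.gate(h))                      # highway gate
        h = g * torch.relu(self.transform(h)) + (1 - g) * h  # highway layer
        return self.fc(h)                                    # class scores
```

Here total_filters is the sum of num_filters[i] over all kernel sizes, since each branch's pooled output is flattened and concatenated.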
However, my discriminator always predicts at chance (around 0.5 accuracy, even after many epochs). It seems you only train the discriminator for 2 epochs in your code...
Are there any "tricks" I should be aware of?
My pre-trained generator (Luong's architecture) achieves about 28 BLEU on IWSLT14 (DE-EN).
Do you force your generator to translate within 50 words? Currently I force my generator to finish translations within 50 words. Do you think I should remove this constraint?
Thanks!!!
@wangyirui Have you checked your input to the discriminator? In our experiments, the number of training epochs for the discriminator depends on the size of the training data. Anyway, it is not difficult to pre-train the discriminator to give correct predictions.
We did not test our model on IWSLT14 (DE-EN), but to our knowledge, 28 BLEU points on the IWSLT test set seems too low. Have you compared with other implementations? In our experiments, we did remove training examples longer than 50 words.
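That length cutoff can be applied as a simple filter over (source, target) token-sequence pairs (an illustrative helper, not the repo's preprocessing script):

```python
def filter_by_length(pairs, max_len=50):
    # Keep only pairs where both sides are at most max_len tokens,
    # matching the 50-word cutoff mentioned above.
    return [(src, tgt) for src, tgt in pairs
            if len(src) <= max_len and len(tgt) <= max_len]
```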
Yes. I compared my implementation with Facebook's fairseq-py system (which yields 28.14). I also removed sentences exceeding 50 words, and in each batch I sampled both positive and negative examples (each batch contains the two classes at the same time). It seems your implementation does that as well. I don't know why I cannot pre-train the discriminator... If you could release your dataset, I think it would be helpful.
BTW, I would like to confirm: the "image" we feed to Conv2d has shape (height = seq_len, width = embed_size, in_channels = 1), right?
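For reference, a minimal shape check of that layout in PyTorch (in NCHW terms the Conv2d input is (batch, 1, seq_len, embed_dim); the sizes here are arbitrary):

```python
import torch
import torch.nn as nn

batch, seq_len, embed_dim, vocab = 4, 50, 256, 10000
emb = nn.Embedding(vocab, embed_dim)
tokens = torch.randint(0, vocab, (batch, seq_len))
x = emb(tokens)      # (batch, seq_len, embed_dim)
x = x.unsqueeze(1)   # (batch, 1, seq_len, embed_dim): NCHW for Conv2d
conv = nn.Conv2d(in_channels=1, out_channels=64,
                 kernel_size=(3, embed_dim))
out = conv(x)        # height shrinks to seq_len - 3 + 1, width to 1
```

Since the kernel width equals the embedding dimension, each filter spans the full word embedding and slides only along the sequence axis.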
Yes, you are right about the input image. I do not think your problem is closely related to the training data, and I am sorry that I cannot release ours. My suggestion is to recheck the input to the discriminator: are the sentences padded to the same length? Are the labels for the positive and negative samples correct? ...
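Those two checks can be automated with a small helper (hypothetical; the batch format of (token_sequence, label) pairs is an assumption):

```python
def check_discriminator_batch(batch):
    # batch: list of (token_sequence, label) pairs.
    lengths = {len(seq) for seq, _ in batch}
    assert len(lengths) == 1, f"sequences not padded to one length: {sorted(lengths)}"
    labels = {label for _, label in batch}
    assert labels == {0, 1}, f"expected both classes (0 and 1), got {labels}"
    return True
```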
OK. Thanks for your suggestions, I will check that. In addition, is the embedding layer in the discriminator trained from scratch, not loaded from the pre-trained generator?
@wangyirui The embedding layer in the discriminator is trained from scratch.
Thanks! I think I might have figured out the problem... When training the discriminator, a positive sample is the human translation of sentence A and a negative sample is the machine translation of sentence A, right? I made a stupid mistake there...
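In other words, positives and negatives are paired per source sentence; a sketch of constructing that labeled data (function and variable names are illustrative, not the repo's API):

```python
def build_discriminator_data(sources, references, hypotheses):
    # For each source sentence: the human reference is the positive
    # sample (label 1), the machine translation the negative (label 0).
    data = []
    for src, ref, hyp in zip(sources, references, hypotheses):
        data.append((src, ref, 1))
        data.append((src, hyp, 0))
    return data
```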
Yes, you are right!
@wangyirui - was the issue simply that the positive/negative data were reversed? I have the same issue of the discriminator predicting at chance, even after many epochs. I wonder how you made training succeed, as I wouldn't expect simply switching the positive/negative data to fix it (I've tried the same, with no success).
For reference, my positive data is ~65K lines from my target training data, and my negative data is the corresponding negative predictions from running generate_sample.sh.
@kellymarchisio It turned out that my translations had some problems, which caused the discriminator to predict almost 1 for every sentence. Carefully checking your input may help.
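A quick way to spot that failure mode is to look at the fraction of positive predictions rather than accuracy alone (an illustrative diagnostic, not part of the repo):

```python
def positive_rate(predictions):
    # Fraction of samples predicted positive (predictions are 0/1).
    # Values near 0.0 or 1.0 suggest a degenerate discriminator
    # or mislabeled inputs; a healthy classifier on a balanced dev
    # set should sit near 0.5 while its accuracy rises above 0.5.
    return sum(predictions) / len(predictions)
```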
@wangyirui
Would you mind telling me what exactly the problems with your translations were? I have the same issue: I cannot pre-train the discriminator, which always gets about 50% accuracy on the development set.