
zhenyangiacas / nmt_gan

generative adversarial nets for neural machine translation

License: Apache License 2.0

Languages: Shell 0.86%, Python 91.32%, Perl 7.82%
Topics: nmt, gan, naacl

nmt_gan's Introduction

NMT_GAN

This is an open-source implementation of our framework for improving NMT with conditional sequence generative adversarial nets, which is described in the following paper:

Yang, Z., Chen, W., Wang, F., & Xu, B. Improving neural machine translation with conditional sequence generative adversarial nets. (NAACL 2018)

Requirements: TensorFlow 1.2.0, Python 2.x

Usage:

Pre-train the discriminator: sh discriminator_pretrain.sh

Pre-train the generator: sh train.sh

Generate the samples: sh generate_sample.sh

Run the GAN training: sh gan_train.sh

nmt_gan's People

Contributors

kellymarchisio, zhenyangiacas


nmt_gan's Issues

Mistake in vocab.py?

Should

        worddict['<PAD>'] = 0
        worddict['<UNK>'] = 1
        worddict['<S>'] = 1
        worddict['</S>'] = 1

be

        worddict['<PAD>'] = 0
        worddict['<UNK>'] = 1
        worddict['<S>'] = 2
        worddict['</S>'] = 3
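
For reference, a minimal sketch of how a word index with distinct reserved special-token IDs is usually built (the token names follow the snippet above; the counting logic is illustrative and not the repository's vocab.py):

    from collections import Counter

    def build_worddict(sentences, vocab_size):
        # Reserve distinct indices 0-3 for the special tokens, then add frequent words.
        worddict = {'<PAD>': 0, '<UNK>': 1, '<S>': 2, '</S>': 3}
        counts = Counter(w for sent in sentences for w in sent.split())
        for word, _ in counts.most_common(vocab_size - len(worddict)):
            worddict.setdefault(word, len(worddict))
        return worddict

    print(build_worddict(["a b b c", "b c d"], vocab_size=8))

Giving each special token its own index matters because <S>, </S> and <UNK> would otherwise be indistinguishable to the embedding layer.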

Release RNNSearch model?

Thanks for releasing this - I've enjoyed playing around with it. Do you intend to release your RNNSearch generator code? I imagine you've released only the Transformer code because it got better results (thank you!), but I'd enjoy comparing against your other published results obtained with RNNSearch. Thanks!

Hi, where is the training data? Thanks.

Sorry, I hadn't read the questions raised by others before. Could you add a description of the data format to the original project? The full data is not needed; a small number of example lines showing the format would be enough. Thank you very much!

Dropout

Could you explain the dropout used? I see it is hardcoded to 1.0 (which means it is not used at all). Do you set it to something else when you run? Do you have it in one of the config files?
Thanks

the g_loss in function gan_output (in model.py)

Hi Zhen Yang, after reading through your paper and Yu's SeqGAN paper, I am confused about the g_loss in the function gan_output (line 662 in model.py). Should this line

g_loss = -tf.reduce_sum(
be the following code instead?

g_loss = -tf.reduce_sum(
                tf.reduce_sum(
                    tf.one_hot(tf.reshape(Y, [-1]), self.config.dst_vocab_size, 1.0, 0.0) * \
                    tf.log(
                        tf.reshape(probs, [-1, self.config.dst_vocab_size])
                    ), 1) * tf.reshape(reward, [-1])
            )

i.e., according to the \nabla J(\theta) in your paper, the log-probabilities rather than the raw probabilities should be used.
Is there any problem with this understanding? Thank you for your reply.
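
For reference, the REINFORCE-style policy gradient that the issue appeals to (the standard SeqGAN form, written here from memory rather than copied from the paper) is

    \nabla_{\theta} J(\theta) \approx \sum_{t} \nabla_{\theta} \log p_{\theta}(y_t \mid y_{<t}, x) \, R(y_{1:T})

so the surrogate loss is built from the log-probability of each sampled token weighted by its reward, which is consistent with wrapping probs in tf.log as proposed above.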

the difference between generating sample and sentence translation

Hi, I notice that when generating samples, you just build a model for generation, feed data into it, and compute "generate_samples" directly. However, when translating sentences at the beginning of evaluation, you use beam search to translate instead of using the above model. I wonder why you use such different methods, since in my opinion generating samples and translation are the same task. Thanks!

Training data size for the generator and discriminator

I'm sorry to bother you after so many years:

  1. I would be interested to know the size of the data used to pre-train the generator and the discriminator on the zh-en data in the paper.
  2. And whether the data referenced under train:, generator: and discriminator: during adversarial training in config_gan_train.yaml is the same as the data used in the earlier pre-training of the generator and discriminator.
    Looking forward to your reply

Preprocessing of en2de data

Hi,
Can you upload your pre-processing script or add information about which scripts you used?
For example, did you use true-casing, tokenization, removal of unusual characters, or checks of the sentence alignments?

Thanks!

NoneType

Hi,
python2 /NMT_GAN-master/discriminator_pretrain.py

Traceback (most recent call last):
  File "/NMT_GAN-master/discriminator_pretrain.py", line 73, in <module>
    config = AttrDict(yaml.load(open(args.config)))
TypeError: coercing to Unicode: need string or buffer, NoneType found

Looking forward to your advice or answers.
Best regards,
Thank you very much!
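
For reference, this TypeError means open() received None, i.e. no config path reached the script. A minimal guard, assuming an argparse-style config flag (the actual flag name used by discriminator_pretrain.sh may differ):

    import argparse
    import sys
    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', help='path to the YAML config file')
    args = parser.parse_args()

    if args.config is None:
        sys.exit('No config given, e.g.: python discriminator_pretrain.py -c config_discriminator.yaml')

    config = yaml.load(open(args.config))

In practice the error usually just means the shell script was invoked without (or with a wrong) config path.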

Clarification on files in config_discriminator.yaml

I'm hoping for clarification on the files passed into config_discriminator.yaml.

As I understand it:

  • dis_positive_data and dis_source_data are 1 million random lines from the original target (o_dst) and source training files (o_src), manually padded to length 50 (I wrote my own script to do this; a sketch is included at the end of this issue). These files are created from the same files used in generator pretraining. Must we pad the sentences ourselves as I have?
  • dis_negative_data (negative_bpe_u8.txt) is the output from sh generate_sample.sh predicted using o_src
  • dis_src_vocab and dis_dst_vocab are BPE vocabulary lists corresponding to the vocabulary of all BPEs used in o_dst and o_src. The vocabulary lists are created by running o_dst and o_src through vocab.py
    • alternatively, since I am running the en-de translation, should I expect to use vocab.bpe.32000 for src_vocab and dst_vocab instead of separate vocab files with 38831/34432 words, as in the demo config?

Is my understanding correct? Can you please provide clarification on how the files in the discriminator pretraining are created?

For context, I am trying to solve an issue where, when running the discriminator pretraining, my loss falls to 2-3 but the accuracy always oscillates around 0.5, even after 700K steps.
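
For what it's worth, the manual padding step mentioned above can be done with a short script like the following (my own sketch: it assumes whitespace tokenization, a literal <PAD> token, and hypothetical file names, none of which are confirmed by the repository):

    def pad_line(line, max_len=50, pad_token='<PAD>'):
        # Truncate or pad a whitespace-tokenized sentence to exactly max_len tokens.
        tokens = line.split()[:max_len]
        tokens += [pad_token] * (max_len - len(tokens))
        return ' '.join(tokens)

    with open('train.tok.de') as fin, open('dis_positive_data.txt', 'w') as fout:
        for line in fin:
            fout.write(pad_line(line) + '\n')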

when to stop pretraining generator

Hi, after pretraining the generator for 200 epochs, its accuracy almost reaches 1 and the loss is extremely small. However, when pretraining the discriminator, the accuracy stays at 0.5 from beginning to end. I guess the generator is so strong that the negative samples it generates are very close to the positive samples, and the GAN training can't move on under this circumstance. So I wonder when I should stop pretraining the generator (at about 70% accuracy, or something else?). Thanks.

GAN training is too slow.

After pretraining the generator and discriminator, the GAN training is very slow, even with only 1K training sentences.

gan_train.py line 102

where get_reward takes most of the time.

The step is
rewards = generator.get_reward(x=x,
                               x_to_maxlen=x_to_maxlen,
                               y_sample=y_sample_dealed,
                               y_sample_mask=y_sample_mask,
                               rollnum=config.rollnum,
                               disc=discriminator,
                               max_len=config.discriminator.dis_max_len,
                               bias_num=config.bias_num,
                               data_util=du)

May I ask why this happens ?

Thank you in advance.
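
For context, SeqGAN-style rewards computed with Monte Carlo rollouts are inherently expensive: for every target position, rollnum full continuations are sampled and each is scored by the discriminator, so a single get_reward call costs on the order of sentence-length × rollnum forward passes. A schematic of the general technique (a sketch, not the repository's code; rollout and score are hypothetical helpers):

    import random

    def monte_carlo_rewards(y_sample, rollnum, rollout, score):
        # rollout(prefix): samples a full continuation of `prefix`.
        # score(sentence): discriminator probability that `sentence` is human-written.
        rewards = []
        for t in range(1, len(y_sample) + 1):      # one reward per target position
            acc = 0.0
            for _ in range(rollnum):               # rollnum rollouts per position
                completed = rollout(y_sample[:t])  # sample the rest of the sentence
                acc += score(completed)            # one discriminator forward pass
            rewards.append(acc / rollnum)
        return rewards

    dummy_rollout = lambda prefix: prefix + ['<PAD>'] * (10 - len(prefix))
    dummy_score = lambda sent: random.random()
    print(monte_carlo_rewards(['a', 'b', 'c'], rollnum=4,
                              rollout=dummy_rollout, score=dummy_score))

Reducing rollnum or the maximum sentence length is the usual way to trade reward variance for speed.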

main process and puzzles about config_generator_train

  1. Since all of my train and evaluation data are positive, is this the right order in which to run the files: vocab.py -> train.py -> generate_samples.py -> discriminator_pretrain.py -> gan_train.py -> evaluate.py?

  2. I'm a little confused about the function of "evaluation.py". What is the meaning of these keys in config_generator_train:

    test:
      src_path: '/data3/jeicy/data/eval/eval/eval_jieba.zh'
      dst_path: '/data3/jeicy/data/eval/eval/eval_jieba.zh'
      ori_dst_path: '/data/zhyang/dl4mt/corpus/data_450w_en_de/transformer/lf_50_for_gan_train/bleuTest/newstest2013/newstest2013.tok.de'
      output_path:

decreasing bleu score after gan

Hi, the accuracy keeps between 0.4-0.6 from epoch 4 to nearly epoch 100 when pretraining the generator. BLEU score on dev set is 19.9 right after pretraining the generator. When training the whole model with gan, BLEU score on dev set is decreasing to 19.7. So I've got some questions:

  1. Why does the accuracy of my generator hardly improve during pretraining? And judging from the BLEU score, it does not look good either, right?
  2. Why does the BLEU score decrease instead? Is that because the initial states of the generator and discriminator are not well synchronized, so the discriminator has no idea how to guide the generator?
  3. When generating samples, nearly 10% of the words are "UNK"; I wonder what you do with these "UNK" words?
    Thanks

Can't find data

When I run the code with the config file, the error says there are no such files.
Do you know how to solve this problem? Thank you.

list index out of range

When using a vocabulary generated from BPE-processed data, the following error occurs, but there is no problem when using a vocabulary generated from data that has not been processed with BPE. How should I fix this?

Traceback (most recent call last):
  File "train.py", line 83, in <module>
    train(config)
  File "train.py", line 18, in train
    du.load_vocab()
  File "/home/sxq/NMT_GAN-master/utils.py", line 56, in load_vocab
    self.src2idx, self.idx2src = load_vocab_(self.config.src_vocab, self.config.src_vocab_size)
  File "/home/sxq/NMT_GAN-master/utils.py", line 43, in load_vocab_
    vocab = [line.split()[0] for line in codecs.open(path, 'r', 'utf-8')]
IndexError: list index out of range

Looking forward to your reply.
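
The IndexError means line.split()[0] hit a line with no tokens, i.e. the BPE-generated vocabulary file contains an empty or whitespace-only line. A minimal fix for load_vocab_ that skips blank lines (a sketch; the truncation to vocab_size is my guess at the surrounding logic):

    import codecs

    def load_vocab_(path, vocab_size):
        # Skip blank lines so line.split()[0] never indexes an empty list.
        vocab = [line.split()[0]
                 for line in codecs.open(path, 'r', 'utf-8')
                 if line.strip()][:vocab_size]
        word2idx = dict((word, idx) for idx, word in enumerate(vocab))
        idx2word = dict((idx, word) for idx, word in enumerate(vocab))
        return word2idx, idx2word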

generate blank space when generating samples

Hi, after pretraining the generator for 200 epochs, I started generating negative samples. However, I found that 90% of the outputs are blank. I wonder what kind of problem may cause this output? Thanks!

Clarification on files in config_gan_train.yaml

Hi,

I am working on en2de data.
I have some questions regarding the gan train step:

  • What is the difference between train.src_path, train.dst_path and generator.src_path, generator.dst_path?

  • And what is the size of discriminator.dis_positive_data, discriminator.dis_negative_data, discriminator.dis_source_data? Are these the same files as in the discriminator pretraining (1 million sentences), or the whole data set this time (4.5 million)?

  • Which config variable represents the λ in your paper?

  • When does the GAN training stop? I didn't find an evaluation part in the GAN training code.

Thanks!

InvalidArgumentError: Assign requires shapes of both tensors to match.

After pretraining both the generator and the discriminator, I started running gan_train.sh, but it threw this error: "InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [7753,256] rhs shape= [9069,256]".
After some investigation, I found that the problem was in reloading the discriminator model, and I figured out the cause: in discriminator_pretrain.py only the vocab_size parameter is present, but in gan_train.py both vocab_size and vocab_size_s are passed. When I commented out vocab_size_s, the error was resolved.
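
For anyone hitting the same thing: the mismatch (lhs [7753,256] vs rhs [9069,256]) means a variable in the freshly built graph and the same-named variable in the checkpoint were created with different vocabulary sizes. A quick way to see which checkpoint variables have which shapes is the standard TF 1.x checkpoint reader (the path below is a placeholder):

    import tensorflow as tf

    reader = tf.train.NewCheckpointReader('/path/to/discriminator/checkpoint')
    for name, shape in sorted(reader.get_variable_to_shape_map().items()):
        print(name, shape)

Comparing that list against the shapes the graph builds makes the offending vocab-size parameter obvious.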

How to get the vocab files?

Hi, for the src_vocab and dst_vocab files with 38831/34432 words: how much data was fed to vocab.py, and were the source and target vocabularies generated separately?
The output of vocab.py has two columns; should I take the first column as the vocabulary file?
Thanks~

dis_saveto

Hi, what file corresponds to dis_saveto?
Looking forward to your reply.

file not found

Thank you very much for helping me solve the last question!
When I run "sh discriminator_pretrain.sh", it gives the following error:
"Not found: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for ../experience/data_180w/dis_pretrain/model-180w_zh2en_bpe/dis_180w_bpe"

What should I do?
Please give me some suggestions, thank you!

pretrain the discriminator

I found that the validation step in the discriminator training code is commented out. Does that mean we don't do validation after each training epoch? In addition, in the paper you mentioned that the best performance is achieved when the discriminator is pretrained to an accuracy of 0.82. Does this accuracy mean the training accuracy? Thanks!

About discriminator

Hi,
Could you tell me the difference between build_train_model and build_discriminator_model in cnn_discriminator.py? What is the reason for having both of them?

GPU out of memory

Hi, I use a GPU with 24 GB of memory to pretrain the generator and generate samples. However, when pretraining the generator, it stops because of OOM (out of memory) after only 2 epochs. And when I try to generate samples with the last saved model, it soon goes OOM again. So I wonder how much GPU memory is needed for running this project, thanks a lot!
By the way, I have reduced the batch size from the 256 in your code to 100. I don't know whether that helps.
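
As a general TensorFlow 1.x note (not specific to this repository): letting the session allocate GPU memory on demand makes it easier to see how much a given batch size actually needs, and batch size and maximum sentence length are the usual levers for avoiding OOM:

    import tensorflow as tf

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand instead of all at once
    with tf.Session(config=config) as sess:
        pass  # build and run the graph here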
