
guxd / dialogwae


Source Code for DialogWAE: Multimodal Response Generation with Conditional Wasserstein Autoencoder (https://arxiv.org/abs/1805.12352)

License: Other

Python 0.24% OpenEdge ABL 99.76%

dialogwae's Introduction

DialogWAE

This is a PyTorch implementation of the DialogWAE model described in DialogWAE: Multimodal Response Generation with Conditional Wasserstein Auto-Encoder.

Dependencies

  • PyTorch 0.4.0
  • Python 3.6
  • NLTK
pip install -r requirements.txt

Train

  • Use pre-trained word embeddings: download the GloVe embeddings glove.twitter.27B.200d.txt from https://nlp.stanford.edu/projects/glove/ and save the file to the ./data folder. By default, the model uses 200-dimensional word embeddings trained on Twitter.

  • Modify the arguments at the top of train.py

  • Train the model with

      python train.py --visual
    

The logs and temporary results will be printed to stdout and saved in the ./output path.

  • Visualize the training status in Tensorboard
      tensorboard --logdir output
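The embedding step in the Train instructions amounts to reading the GloVe text file, where each line holds a word followed by its vector components, into a matrix indexed by the model's vocabulary. A minimal sketch, assuming a hypothetical `load_glove` helper and a vocabulary given as a list of words; the repository's own loading code may differ, and the random initialization for words missing from GloVe is an assumption:

```python
import numpy as np

def load_glove(lines, vocab, dim=200, seed=0):
    """Hypothetical helper: map GloVe text lines ("word v1 ... v_dim") onto vocab.

    Words absent from GloVe keep a small random initialization (an assumption).
    Returns the (len(vocab), dim) matrix and the number of vocab words found.
    """
    word2idx = {w: i for i, w in enumerate(vocab)}
    rng = np.random.default_rng(seed)
    emb = rng.normal(0.0, 0.1, size=(len(vocab), dim)).astype(np.float32)
    found = set()
    for line in lines:
        parts = line.rstrip().split(" ")
        word, values = parts[0], parts[1:]
        # Skip words outside the vocabulary and lines with the wrong dimension.
        if word in word2idx and len(values) == dim:
            emb[word2idx[word]] = np.asarray(values, dtype=np.float32)
            found.add(word)
    return emb, len(found)

# Usage with the file named in the instructions above (vocab is hypothetical):
# with open("./data/glove.twitter.27B.200d.txt", encoding="utf-8") as f:
#     emb, n_found = load_glove(f, vocab)
```

Taking any iterable of lines rather than a path keeps the helper easy to test and lets the caller control how the file is opened.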
    

Evaluation

Modify the arguments at the bottom of sample.py

Run model testing with:

    python sample.py

The outputs will be printed to stdout and generated responses will be saved at results.txt in the ./output path.

References

If you use any source code included in this toolkit in your work, please cite the following paper:

@article{gu2018dialogwae,
      title={Dialog{WAE}: Multimodal Response Generation with Conditional Wasserstein Auto-Encoder},
      author={Gu, Xiaodong and Cho, Kyunghyun and Ha, Jung-Woo and Kim, Sunghun},
      journal={arXiv preprint arXiv:1805.12352},
      year={2018}
}

LICENSE

Copyright 2018 NAVER Corp. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

  1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

  3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and IDIAP Research Institute nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

dialogwae's People

Contributors

guxd

dialogwae's Issues

Confused about the evaluation of inter-dist metrics.

Hi, thanks for your insightful work!
I'm confused about how the inter-dist metrics are evaluated in sample.py, shown here in simplified form:

while True:
    batch = test_loader.next_batch()  # batch_size is 1 here
    ...
    intra_dist1, intra_dist2, inter_dist1, inter_dist2 = metrics.div_distinct(sample_words, sample_lens)
    inter_dist1s.append(inter_dist1)
    inter_dist2s.append(inter_dist2)
    ...
inter_dist1 = float(np.mean(inter_dist1s))
inter_dist2 = float(np.mean(inter_dist2s))
print("inter_dist1 %f, inter_dist2 %f" % (inter_dist1, inter_dist2))
print("Done testing")

I understand that inter_dist is computed over the #n_samples predictions in metrics.div_distinct(). However, shouldn't inter_dist be calculated over the entire test set?
That is, inter_dist1 = #distinct_unigram/#total_unigram. From the code, it seems that inter_dist only measures a single batch and then averages the results. Moreover, the batch_size is 1 here, so inter_dist only measures #n_samples predictions (#n_samples=5 in the code). If #n_samples == 1, isn't inter_dist then equivalent to intra_dist?

Looking forward to your reply. Thanks in advance!
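To make the distinction in this question concrete, here is a sketch contrasting distinct-n averaged over per-context groups of samples with distinct-n computed over the pooled test set. The function names are hypothetical and this is not the repository's metrics.div_distinct; it only assumes the ratio #distinct n-grams / #total n-grams described above.

```python
def distinct_n(sentences, n):
    """Ratio of distinct n-grams to total n-grams over a list of token lists."""
    ngrams = [tuple(s[i:i + n]) for s in sentences for i in range(len(s) - n + 1)]
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0

def averaged_distinct_n(groups, n):
    """Distinct-n per group (e.g. the samples for one context), then averaged."""
    return sum(distinct_n(g, n) for g in groups) / len(groups)

def corpus_distinct_n(groups, n):
    """Distinct-n over all samples of the whole test set pooled together."""
    return distinct_n([s for g in groups for s in g], n)

# Two contexts that each receive the identical single response:
groups = [[["i", "am", "fine"]], [["i", "am", "fine"]]]
# averaged_distinct_n(groups, 1) is 1.0, while corpus_distinct_n(groups, 1) is 0.5,
# because repetition across contexts is only visible when the samples are pooled.
```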

Code explanation


Hi, I want to ask about your code.
What do backward(one) and backward(minus_one) do here?
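The screenshot is not preserved, so this assumes the common WGAN-style pattern where `one` and `minus_one` are tensors holding +1 and -1. Passing a tensor to `Tensor.backward(gradient)` seeds the backward pass with that value, so every resulting gradient is multiplied by it: `loss.backward(minus_one)` yields the same gradients as `(-loss).backward()`. An optimizer that minimizes will therefore push `loss` up. A minimal check with a toy loss (the helper `grad_of` is hypothetical):

```python
import torch

# Seeds for the backward pass, as 0-dim tensors so their shape matches a scalar loss.
one = torch.tensor(1.0)
minus_one = -one

def grad_of(seed):
    """Return d(loss)/dw after calling loss.backward(seed) on a toy loss."""
    w = torch.tensor([2.0, -3.0], requires_grad=True)
    loss = (w ** 2).sum()   # scalar loss; its true gradient is 2 * w
    loss.backward(seed)     # the seed multiplies the gradient of loss
    return w.grad.clone()

g_plus = grad_of(one)        # same as loss.backward()
g_minus = grad_of(minus_one) # same as (-loss).backward(): gradient ascent on loss
```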

Warning when run sample.py:RNN module weights are not part of single contiguous chunk of memory.

After training for 100 epochs, I tried to test the performance of the models using sample.py.
The first question is: which metric should I use to select the best model? BLEU precision keeps decreasing, while recall reaches its highest point at the 88th epoch. I am new to GANs.

Then I decided to find the best performer by testing the models. Here comes the problem:

/search/odin/hejunqing/DialogWAE/modules.py:86: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
hids, h_n = self.rnn(inputs, init_hidden)
/search/odin/hejunqing/DialogWAE/modules.py:141: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
hids, h_n = self.rnn(utt_floor_encs, init_hidden)
/search/odin/hejunqing/DialogWAE/modules.py:289: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().
decoder_output, decoder_hidden = self.rnn(decoder_input, decoder_hidden)

These warnings appear when I run sample.py with PyTorch 0.4.0.
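The warning itself names the remedy, `flatten_parameters()`. One common placement is at the start of the module's `forward`, right before the RNN call, so the cuDNN weights are re-compacted after the model has been loaded or moved; PyTorch documents the call as a no-op unless the module is on the GPU with cuDNN enabled. A minimal sketch with a hypothetical encoder (not the repository's modules.py), reusing the call shape from the warning:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Hypothetical GRU encoder illustrating where flatten_parameters() goes."""

    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)

    def forward(self, inputs, init_hidden=None):
        # Re-compact the RNN weights into one contiguous chunk before the call;
        # this silences the UserWarning on GPU and does nothing on CPU.
        self.rnn.flatten_parameters()
        hids, h_n = self.rnn(inputs, init_hidden)
        return hids, h_n
```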

Do the losses of both the generator and the discriminator collapse?

Hello, I am running your code on the SWDA dataset, but the losses look like this:

train_loss_AE:2.7721 train_loss_G:298.7391 train_loss_D:-300.4698
DialogWAE_GMP-basic|SWDA@gpu0 epo:[84/100] iter:[1200/1279] step_time:56s elapsed:0:11:33<0:0:46

train_loss_AE:3.0337 train_loss_G:203.4530 train_loss_D:-84.7853

Valid begins with 41 batches with 28 left over samples
Validation valid_loss_AE:3.0522 valid_loss_G:353.3996 valid_loss_D:-353.3996
Valid begins with 5165 batches with 0 left over samples

I did not change any configuration, but I haven't checked the generated responses yet. Is the loss supposed to look like this, or has training already collapsed?

Thank you for the help!
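For context on reading such numbers: under the textbook WGAN objectives (an assumption here, since the repository's exact loss definitions are not shown above), the critic loss is a difference of mean critic scores, so a negative discriminator loss is expected, and both losses scale with the magnitude of the critic's outputs. The sketch below only illustrates that scaling with toy scores; it cannot tell whether this particular run collapsed, which inspecting the generated responses would show more directly.

```python
import numpy as np

def wgan_losses(critic_real, critic_fake):
    """Textbook WGAN objectives from critic scores (assumed, not the repo's code)."""
    loss_d = np.mean(critic_fake) - np.mean(critic_real)  # critic minimizes this
    loss_g = -np.mean(critic_fake)                        # generator minimizes this
    return loss_d, loss_g

rng = np.random.default_rng(0)
real_scores = rng.normal(1.0, 0.1, size=64)  # toy critic scores on one sample set
fake_scores = rng.normal(0.5, 0.1, size=64)  # toy critic scores on the other set

d1, g1 = wgan_losses(real_scores, fake_scores)
d100, g100 = wgan_losses(100 * real_scores, 100 * fake_scores)  # same critic, outputs x100
# d100 equals 100 * d1 and g100 equals 100 * g1 (up to rounding): the raw magnitude
# of these losses depends on the critic's output scale, not only on sample quality.
```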

Code explanation about data preprocessing

Hello, thank you for open-sourcing your code. I am trying to understand it, but the data preprocessing in data.py confuses me.

In building the vocabulary:

print("Load corpus with train size %d, valid size %d, "
              "test size %d raw vocab size %d vocab size %d at cut_off %d OOV rate %f"
              % (len(self.train_corpus), len(self.valid_corpus), len(self.test_corpus),
                 raw_vocab_size, len(vocab_count), vocab_count[-1][1], float(discard_wc) / len(all_words)))

What do train size, valid size, and test size mean?
Their values are all 2, since each corpus is a tuple of length 2.

Do you mean that the vocabulary is built from the training, testing, and validation data?
However, the code only uses the training data to build the vocabulary.

In formatting the dialogues:
Is it essential to add [<s>, <d>, </s>] at the start of each dialogue?
Can I leave it out?

Thank you.
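The quantities in that print statement can be illustrated with a sketch of a frequency-cut vocabulary built from training data only. This is a hypothetical reconstruction that reuses the printed names (raw_vocab_size, vocab_count, discard_wc, all_words) under the assumption that they mean what their names suggest; the nested dialog/utterance/token layout is also assumed, and data.py may differ:

```python
from collections import Counter

def build_vocab(train_dialogs, max_vocab_size):
    """Hypothetical sketch: keep the most frequent training words as the vocabulary."""
    # Flatten dialogs -> utterances -> tokens into one list of training tokens.
    all_words = [w for dialog in train_dialogs for utt in dialog for w in utt]
    counts = Counter(all_words).most_common()          # (word, count), most frequent first
    raw_vocab_size = len(counts)                       # distinct words before the cut
    vocab_count = counts[:max_vocab_size]              # words that are kept
    discard_wc = sum(c for _, c in counts[max_vocab_size:])  # tokens that become OOV
    cut_off = vocab_count[-1][1]                       # frequency of the rarest kept word
    oov_rate = float(discard_wc) / len(all_words)      # share of training tokens discarded
    return [w for w, _ in vocab_count], raw_vocab_size, cut_off, oov_rate

train = [[["hi", "there"], ["hi"]], [["bye", "there", "hi"]]]
vocab, raw_size, cut_off, oov_rate = build_vocab(train, max_vocab_size=2)
# vocab == ["hi", "there"], raw_size == 3, cut_off == 2, oov_rate == 1/6
```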

Wasserstein distance between prior and posterior

Hi,

I'm trying to find the part of the code that computes the Wasserstein distance between the prior and the posterior (as in Eq. 5 of your ICLR paper), but couldn't find it. Could you please point me to the part of the code that computes this distance?

Moreover, I found that the latent variables are computed directly by the model (e.g., by a fully connected layer) rather than by predicting μ and σ and then sampling from that distribution, as stated in Eq. 3 and Eq. 4. Could you please clarify this?

Thanks
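For reference, the sampling described in the question (predict μ and σ, then sample) is usually implemented with the reparameterization trick, and the sample can then be passed through a further network. A minimal sketch with hypothetical layer names and sizes, not taken from the repository; in WGAN-style models the distance itself is typically estimated by a critic network trained adversarially rather than computed in closed form:

```python
import torch
import torch.nn as nn

class GaussianSampler(nn.Module):
    """Hypothetical sketch: predict mu and log-variance, then sample via reparameterization."""

    def __init__(self, in_dim, noise_dim):
        super().__init__()
        self.mu = nn.Linear(in_dim, noise_dim)
        self.logvar = nn.Linear(in_dim, noise_dim)

    def forward(self, h):
        mu = self.mu(h)
        std = torch.exp(0.5 * self.logvar(h))
        eps = torch.randn_like(std)  # noise that does not depend on the parameters
        return mu + std * eps        # differentiable with respect to mu and std

torch.manual_seed(0)
sampler = GaussianSampler(in_dim=4, noise_dim=3)
generator = nn.Sequential(nn.Linear(3, 3), nn.Tanh())  # hypothetical transform of the noise
h = torch.randn(2, 4)                                   # stand-in for a conditioning vector
z = generator(sampler(h))
z.sum().backward()  # gradients reach the mu and logvar layers through the sample
```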

SWDA seems to produce a lot of repetition in the sampled responses for the test data.

Context 4-1: ('yeah ', 2)
Context 5-0: ("and the people in the city were saying well why should i go do that make the government do that that ' s not my job ", 27)
Context 6-1: ("right they ' ve got a lot of adjustments to make with coming out of what they ' ve been through ", 22)
Context 7-1: ('now and ', 3)
Context 8-1: ("they don ' t understand that to make that work they ' ve got to take some responsibility for themselves it ' s not just the government ' s responsibility anymore ", 32)
Target >> you can't just blame it on the government when they give you the freedom to take care of yourself then that puts some responsibility on you as well

Sample 0 >> it the their their their their their their the she she she she she she she she'she'i she'i she'she'i she'she'i'' she'
Sample 1 >> yeah
Sample 2 >> it is is is but she is
Sample 3 >> the the high school is high high school system
Sample 4 >> and but it is
Sample 5 >> these are just
Sample 6 >> in their of their high of their life is worth of an life life life life life life life life life life life life life life life life life life life life life life life life life life life life
Sample 7 >> it in their their their something is something is something is something something something her life life something something something something something something something something something something something something something something something something something something something something something something something
Sample 8 >> but but but i but i i'i i
Sample 9 >> but and but of their name of their life and their life just never just

Is there some way to avoid this?
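One generic decoding-time mitigation, not something the repository is known to implement, is n-gram blocking: at each step, forbid any token that would repeat an n-gram already present in the output. A minimal sketch in plain Python with a hypothetical helper; a decoder would set the scores of the returned tokens to negative infinity before choosing the next word:

```python
def banned_next_tokens(prefix, n):
    """Tokens that would complete an n-gram already present in prefix (hypothetical helper)."""
    if n < 1 or len(prefix) < n - 1:
        return set()
    context = tuple(prefix[len(prefix) - (n - 1):])  # the last n-1 generated tokens
    banned = set()
    for i in range(len(prefix) - n + 1):
        # Wherever the same (n-1)-token context occurred before, ban the token that followed it.
        if tuple(prefix[i:i + n - 1]) == context:
            banned.add(prefix[i + n - 1])
    return banned

# With bigram blocking, "high school" cannot be produced a second time:
# banned_next_tokens("the high school is high".split(), 2) == {"school"}
```

Blocking is a blunt tool: it also forbids legitimate repeats, so small values of n (2 or 3) are the usual starting point.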

Stop training the context encoder during train_G/train_D?

I find it reasonable to stop training the utterance encoder when training the generator and the discriminator.

However, why don't you freeze the context encoder as well?

Is there a specific reason for this?

Thanks.
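For reference, a module can be excluded from a training phase by turning off gradients for its parameters. A minimal sketch with hypothetical stand-in modules (the repository's actual encoders and training loop are not shown here):

```python
import torch
import torch.nn as nn

def set_trainable(module: nn.Module, flag: bool) -> None:
    """Enable or disable gradient computation for all parameters of a module."""
    for p in module.parameters():
        p.requires_grad_(flag)

# Hypothetical stand-ins for the utterance encoder, context encoder and generator.
utt_encoder = nn.Linear(4, 4)
ctx_encoder = nn.Linear(4, 4)
generator = nn.Linear(4, 4)

# Freeze both encoders for the generator/discriminator phase.
set_trainable(utt_encoder, False)
set_trainable(ctx_encoder, False)

x = torch.randn(2, 4)
loss = generator(ctx_encoder(utt_encoder(x))).pow(2).mean()
loss.backward()  # only the generator's parameters receive gradients
```

The encoders need `set_trainable(..., True)` again before the autoencoder phase. An alternative that leaves `requires_grad` untouched is `ctx_encoder(...).detach()`, which stops gradients at the encoder output for that phase only.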

Unable to reproduce the published results on DailyDialog

Hi,
I am trying to retrain your model as a baseline. So far, SWDA gives results in line with the paper (actually slightly better). But for the DailyDialog dataset, even after multiple runs, the best we got is the following (row 1 on the validation set, row 2 on the test set):

A, E, G are the sim_bow embedding scores (Average, Extrema, Greedy).

Split      | BLEU-R | BLEU-P | F1    | A     | E     | G
Validation | 0.305  | 0.170  | 0.218 | 0.940 | 0.609 | 0.857
Test       | 0.298  | 0.163  | 0.211 | 0.940 | 0.605 | 0.857

The paper, however, reports higher best results on DailyDialog.


Were any changes made to the code relative to the configuration in the paper? I couldn't find any discrepancy. Could you point me to what might be the issue?
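As a consistency check on the reported numbers, F1 matches the harmonic mean of BLEU-P and BLEU-R (e.g. 2 × 0.170 × 0.305 / (0.170 + 0.305) ≈ 0.218). The sketch below assumes the common multi-sample definitions (recall: average over references of the best sample's score; precision: average over samples of the best reference's score); the similarity function is a toy placeholder, not the BLEU implementation used by the repository, and all names are hypothetical:

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0

def multi_sample_scores(refs, hyps, sim):
    """Assumed recall/precision/F1 over several references and sampled responses."""
    scores = [[sim(r, h) for h in hyps] for r in refs]        # scores[i][j]: ref i vs sample j
    recall = sum(max(row) for row in scores) / len(refs)      # best sample per reference
    precision = sum(max(col) for col in zip(*scores)) / len(hyps)  # best reference per sample
    return recall, precision, f1(precision, recall)

def unigram_jaccard(a, b):
    """Toy placeholder similarity between two token lists."""
    sa, sb = set(a), set(b)
    return len(sa & sb) / len(sa | sb) if sa | sb else 0.0
```

Because precision averages over every sample, drawing more (and more varied) samples tends to lower BLEU-P while raising BLEU-R, so the two should be compared at the same number of samples as in the paper.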
