Comments (9)
This code runs on Python 3 only. Could you please upload your error message?
As for num_layers, you can print the model directly to inspect it; there should be an LSTM module with num_layers=2.
from pytorch-nce.
The error just says that the expected size of hidden[0] does not match what it got.
In index_gru.py there is:
self.rnn = nn.GRU(self.ninp, self.nhid, num_layers=1, dropout=args.dropout, batch_first=True)
I changed the 1 to args.num_layers, and also changed the view(1, -1, ...) on batched_rnn_output in get_noise_score to view(args.num_layers, -1, ...):
def get_noise_score(self, noise_idx, rnn_output):
    """Get the score of noise given supervised context

    Args:
        - noise_idx: (B, N, N_r) the noise word index
        - rnn_output: output of rnn model

    Return:
        - noise_score: (B, N, N_r) score for noise word index
    """
    noise_emb = self.encoder(noise_idx.view(-1))
    noise_ratio = noise_idx.size(2)
    # rnn_output of </s> is useless for sentence scoring
    batched_rnn_output = rnn_output[:, :-1].unsqueeze(2).expand(
        -1, -1, noise_ratio, -1
    ).contiguous().view(1, -1, self.nhid)
    noise_output, _last_hidden = self.rnn(
        noise_emb.view(-1, 1, self.nhid),
        batched_rnn_output,
    )
    noise_score = self.scorer(noise_output).view_as(noise_idx)
    return noise_score
For example, it gives me an error like this:
expected hidden[0] size (2, 500, 200), but got (2, 250, 200)
I don't know where I should change the code to fix it.
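A minimal standalone repro of that mismatch (a sketch, not the repo's code): a 2-layer GRU expects its initial hidden state to be (num_layers, batch, nhid), so changing the view from view(1, -1, ...) to view(args.num_layers, -1, ...) spreads the same flattened vectors across the layers and halves the batch dimension, instead of giving each layer its own copy.

```python
import torch
import torch.nn as nn

nhid = 200
gru = nn.GRU(nhid, nhid, num_layers=2, batch_first=True)

# 500 flattened context vectors, like the expand/view in get_noise_score
flat_hidden = torch.randn(500, nhid)
inputs = torch.randn(500, 1, nhid)  # (B, seq_len=1, nhid)

# view(num_layers, -1, nhid) splits the 500 vectors across the 2 layers,
# so the hidden batch shrinks to 250 while the input batch stays 500
h0 = flat_hidden.view(2, -1, nhid)
print(h0.shape)  # torch.Size([2, 250, 200])

try:
    gru(inputs, h0)
except RuntimeError as err:
    print(err)  # the same hidden-size (2, 500, 200) vs (2, 250, 200) mismatch
```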
Well, the GRU version supports only one layer. That's because cuDNN's GRU only gives the hidden states of the last layer, while this kind of contrasting needs the hidden states across all layers.
See Line 39 in 0f4cf44
Actually, we could stack cuDNN's built-in GRUs to get all the hidden states we need. Nice hint!
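The stacking idea could look roughly like this (a hypothetical sketch, not code from pytorch-nce): run several single-layer built-in GRUs in sequence and keep every layer's output, instead of relying on the multi-layer cuDNN module that exposes only the last layer's sequence.

```python
import torch
import torch.nn as nn

class StackedGRU(nn.Module):
    """Stack single-layer GRUs so every layer's states stay accessible.

    A sketch of the idea above, not code from pytorch-nce.
    """

    def __init__(self, ninp, nhid, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.GRU(ninp if i == 0 else nhid, nhid, batch_first=True)
            for i in range(num_layers)
        )

    def forward(self, x, h0):
        # h0: (num_layers, B, nhid), one initial state per layer
        outputs, finals = [], []
        for i, gru in enumerate(self.layers):
            x, h_n = gru(x, h0[i:i + 1].contiguous())
            outputs.append(x)   # this layer's full output sequence
            finals.append(h_n)
        # cuDNN's multi-layer GRU would expose only outputs[-1];
        # here every layer's sequence is kept
        return outputs, torch.cat(finals, dim=0)

stack = StackedGRU(ninp=16, nhid=16, num_layers=2)
outs, h_n = stack(torch.randn(4, 5, 16), torch.zeros(2, 4, 16))
```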
Sorry, I didn't read the code carefully.
Thanks a lot.
I'm going to raise a warning for this situation until the multi-layered version is ready.
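Such a guard could be as simple as the following sketch (the function and argument names are hypothetical, not the repo's actual code):

```python
import warnings

def check_gru_layers(num_layers):
    # hypothetical guard: the index-GRU module only supports one layer for now
    if num_layers > 1:
        warnings.warn(
            "the index-GRU NCE module supports only num_layers=1; "
            "extra layers are ignored until a multi-layered version is ready"
        )
        return 1
    return num_layers
```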
Hi, I'm sorry to take your time again; I have another question.
After running the code with my small changes, it gives a pretty high PPL: val loss is 6-7 or greater, and PPL is in the hundreds or thousands, so I thought maybe I did something wrong.
I also ran your code without any change on wikitext-2, with the same params as pytorch/example (which reaches a perplexity of 110.44 after 6 epochs). NCE + GRU gave a worse result: after 18 epochs the val loss is 5.4 and the PPL is 222.15, and the val loss basically remains the same.
Do you have any suggestions for dealing with this?
Well, since the actual PPL of the index GRU is hard to compute, the printed loss is simply the NCE loss, which is not comparable with the CrossEntropy loss.
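The distinction matters when comparing numbers: perplexity is exp(cross-entropy), so it is only meaningful for a loss that is a per-word negative log-likelihood. A toy illustration with generic PyTorch (random scores, not the repo's training loop):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10000)             # (batch, vocab) scores
targets = torch.randint(0, 10000, (8,))

ce = F.cross_entropy(logits, targets)      # per-word negative log-likelihood
ppl = ce.exp()                             # exp(CE) is the usual perplexity

# An NCE loss instead sums binary logistic terms for the target word and
# k noise words, so exp(nce_loss) is NOT a perplexity, and the raw values
# of the two losses live on different scales.
print(float(ce), float(ppl))
```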
Yeah, Index_linear works well, though still a little higher than pytorch/example: after 20 epochs the PPL is 165. But based on that code, after changing how batches are generated and controlling the vocab size, the model's val loss is hard to reduce and it cannot train. Quite strange.
Hi, Eric. I failed to reproduce the PPL of 165 on my server. Could you please delete data/penn/vocab.pkl and run again, to see if it happens again? I suspect this bug is caused by the vocabulary building.
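Resetting that cache before the next run could look like this small sketch (only the data/penn/vocab.pkl path comes from the comment above; the rest is generic Python):

```python
import os

vocab_cache = "data/penn/vocab.pkl"

# delete the cached vocabulary so the next run rebuilds it from the corpus
if os.path.exists(vocab_cache):
    os.remove(vocab_cache)
```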