Hi @tongye98 ,
Thank you for your question!
TL;DR: because `trg_mask` is used for loss computation. In the loss computation we are interested in what the model actually predicted, not in what was prompted before the prediction.
The model predicts auto-regressively, which means the model takes the previous trg tokens as input. Let's say the model processes the sequence `<s> hallo welt </s>` on the src side (German) and `<s> hello world </s>` on the trg side (English).
`self.trg_input` holds the sequence with BOS `<s>` prepended, namely `<s> hello world`.
In the first timestep, the model takes the src sequence `<s> hallo welt </s>` and the trg start token `<s>` as input, and tries to predict the next trg token `hello`. Assume that the model wrongly predicted the first token, i.e. `hi` instead of `hello`. Still, we feed the ground-truth prefix to the model in the second step, `<s> hallo welt </s>` and `<s> hello`, so that the model can predict the second token `world` accurately. That's why we need the sequence shifted one token ahead (and don't need `</s>`, because there is no next token for the model to predict after `</s>`) for teacher forcing, and `self.trg_input` serves this purpose. (BTW, at inference time the model doesn't have this guidance. So the model will predict the second token based on `<s> hallo welt </s>` and `<s> hi`, and can suffer from its own mistakes.)
`self.trg` holds the sequence with EOS `</s>` appended, namely `hello world </s>`.
We assumed that the model wrongly predicted `hi` in the first step. We further assume that the model predicted the second token `world` correctly, but missed the third token again, i.e. `!` instead of `</s>`. Then we compute the loss between the ground truth `hello world </s>` and the prediction `hi world !`, and ignore anything after the third token in the loss computation. `self.trg_mask` serves this masking purpose; therefore, we create `self.trg_mask` from `self.trg`, not `self.trg_input`.
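As a minimal PyTorch sketch of this split (the token indices `PAD`, `BOS`, `EOS`, `HELLO`, `WORLD` are made up for illustration, and joeynmt's actual batch handling differs in detail):

```python
import torch

PAD, BOS, EOS = 0, 1, 2       # hypothetical special-token indices
HELLO, WORLD = 5, 6           # hypothetical word indices

# "<s> hello world </s>" padded to a fixed length
full = torch.tensor([[BOS, HELLO, WORLD, EOS, PAD]])

trg_input = full[:, :-1]      # <s> hello world </s>    -> decoder input (teacher forcing)
trg = full[:, 1:]             # hello world </s> <pad>  -> gold labels, shifted by one
trg_mask = trg != PAD         # True where a real token must be predicted

print(trg_input.tolist())     # [[1, 5, 6, 2]]
print(trg.tolist())           # [[5, 6, 2, 0]]
print(trg_mask.tolist())      # [[True, True, True, False]]
```

Position `i` of `trg` is exactly the token the model should produce after reading `trg_input[:, :i+1]`, which is why the two tensors are shifted against each other by one.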
I hope this helps you understand the concept. I'm happy to receive further questions :)
PS. Why do we need `self.trg_mask` at all? Because we handle sequences of different lengths within a single batch.
Let's say we have a batch with 3 sentences:

```
hello world </s> <pad> <pad>
i am a student </s>
mt is fun </s> <pad>
```
We iterate the batch over timesteps column-wise; that is, the model predicts `hi`, `i` and `mt` in the first step, `world`, `was` and `is` in the second step, and so on. So the `<pad>` positions are also generated, something like

```
hi world ! ! </s>
i was student </s> a
mt is fun </s> for
```

but we just ignore everything after `</s>` in the ground truth. It doesn't matter whether the model predicted correctly there or not.
The `trg_mask` indicates at which positions to compute the loss:

```
True True True False False
True True True True True
True True True True False
```
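For the three-sentence batch above, this mask is just a comparison against the pad index (a sketch with made-up token indices; pad index 0 is an assumption):

```python
import torch

PAD = 0
# the three padded target sequences from above, encoded as made-up indices
trg = torch.tensor([
    [10, 11,  2,  0,  0],   # hello world </s> <pad> <pad>
    [12, 13, 14, 15,  2],   # i am a student </s>
    [16, 17, 18,  2,  0],   # mt is fun </s> <pad>
])
trg_mask = trg != PAD       # loss is computed only where this is True

print(trg_mask.int().tolist())
# [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]
```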
Hi @may- ,
Thank you for your detailed answer!
Yes, I understand and agree with what you said above. But I still have a puzzle: `trg_mask` is also used in the trg-trg attention in the decoder:
```python
h1, _ = self.trg_trg_att(x, x, x, mask=trg_mask)
```

(joeynmt/joeynmt/transformer_layers.py, line 361 in 0f57e93)
So, at training time, if the target-side sentence is `<s> hello world </s> <pad> <pad>` for example (`<pad>` here for batching), `trg_input` and `trg` are:

```
trg_input = <s> hello world </s> <pad>
trg       = hello world </s> <pad> <pad>
```

If the `trg_mask` is generated from `trg`, the `trg_mask` will be `True True True False False`.
After the `subsequent_mask`, the final `trg_mask` is

```
True False False False False
True True False False False
True True True False False
True True True False False
True True True False False
```
If the `trg_mask` is generated from `trg_input`, the `trg_mask` will be `True True True True False`. After the `subsequent_mask`, the final `trg_mask` is

```
True False False False False
True True False False False
True True True False False
True True True *True* False
True True True True False
```
Let's look at the difference between the two `trg_mask`s. With the latter, in the trg-trg attention, the token `</s>` would also attend to the token `</s>` itself (the `*True*` position). Wouldn't that make more sense?
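To make the comparison concrete, here is a small sketch that builds both variants (the inline `subsequent_mask` helper and the token indices are illustrative, not joeynmt's exact code):

```python
import torch

PAD = 0

def subsequent_mask(size: int) -> torch.Tensor:
    # lower-triangular matrix: position i may attend to positions <= i
    return torch.tril(torch.ones(size, size, dtype=torch.bool))

trg       = torch.tensor([[5, 6, 2, 0, 0]])  # hello world </s> <pad> <pad>
trg_input = torch.tensor([[1, 5, 6, 2, 0]])  # <s> hello world </s> <pad>

mask_from_trg       = (trg != PAD).unsqueeze(1) & subsequent_mask(5)
mask_from_trg_input = (trg_input != PAD).unsqueeze(1) & subsequent_mask(5)

# row 3 is where the two masks differ (the *True* position)
print(mask_from_trg[0, 3].tolist())        # [True, True, True, False, False]
print(mask_from_trg_input[0, 3].tolist())  # [True, True, True, True, False]
```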
Additionally: you said that `trg_mask` is used for loss computation. But in the source code, the inputs to the loss function are only the `logits` and `kwargs['trg']` (plus `pad_index` and `smoothing` in the `__init__` of `XentLoss`), not the `trg_mask`. This is also a little confusing to me.
(Line 76 in 0f57e93)
Hi @tongye98 ,
Ah, my explanation was somewhat misleading, sorry.
First of all, the actual loss function implemented in PyTorch decides the positions to mask based on an index, not on a masking matrix. We specify the pad index to ignore them here:
(Line 23 in 0f57e93)
And you are right, the actual usage of `trg_mask` in the code is to mask out the positions not to attend to. (I just meant that the overall purpose of using `trg_mask` in the self-attention is that we want to ignore those positions in the loss calculation.)
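A generic sketch of this index-based ignoring with PyTorch's built-in `ignore_index` (not joeynmt's exact `XentLoss` code; the vocab size and token indices are made up):

```python
import torch
import torch.nn as nn

PAD, VOCAB = 0, 8
criterion = nn.CrossEntropyLoss(ignore_index=PAD)  # pad positions contribute no loss

torch.manual_seed(0)
logits = torch.randn(1, 5, VOCAB)                  # batch=1, 5 timesteps
trg = torch.tensor([[5, 6, 2, PAD, PAD]])          # hello world </s> <pad> <pad>

loss = criterion(logits.view(-1, VOCAB), trg.view(-1))

# identical to averaging the loss over only the three non-pad positions
manual = nn.CrossEntropyLoss()(logits.view(-1, VOCAB)[:3], trg.view(-1)[:3])
assert torch.allclose(loss, manual)
```

Because the pad index alone identifies the positions to skip, no separate `trg_mask` has to be passed into the loss function.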
The important point is that the self-attention mechanism never attends to BOS, so the masking pattern after the subsequent mask below

```
True False False False False
True True False False False
True True True False False
True True True False False
True True True False False
```

actually implies

```
hello False False False False    # first token can attend to `hello`
hello world False False False    # second token can attend to `hello` or `world`
hello world </s> False False     # third token can attend to `hello`, `world` or `</s>`
hello world </s> False False     # and so on ...
hello world </s> False False
```

in our example. The self-attention mechanism attends to the token positions to be predicted. The start token is always given; it's not the job of the model to predict BOS...
I recommend the article The Illustrated Transformer for further details on self-attention. This issue #151 may help, too.
If you mean that attending to future tokens makes more sense, then yes, there is research in that direction, e.g. Attending to Future Tokens for Bidirectional Sequence Generation, but we most probably won't include such techniques in our minimalistic codebase.
Please feel free to reopen the issue if you still have questions.