I've reviewed the CLM masking code, and I'm a little confused about this line here:
What is the purpose of removing the last padding item when the input is padded? Shouldn't the last (non-padded) item be removed instead of the features of the last padding item?
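To make the question concrete, here is a minimal sketch of the shift step: `inputs[:, :-1]` followed by concatenating a zero vector, as in the lines commented out of the modified method below. It uses hypothetical toy tensors and assumes right padding, with padded positions holding zero vectors. The sketch shows that this step only zeroes the final column. On a padded session, that final column is a padding slot, so the last real item stays visible.

```python
import torch

# Hypothetical toy batch: 2 sessions, max length 5, embedding dim 3.
# Session 0 is full; session 1 has 3 real items followed by 2 padding slots
# (assumption: right padding, padded positions are zero vectors).
inputs = torch.arange(1, 31, dtype=torch.float32).reshape(2, 5, 3)
inputs[1, 3:] = 0.0

# The "shift" step: drop the last position, then append a zero vector
# so the sequence length is restored.
shifted = torch.cat(
    [inputs[:, :-1], torch.zeros(inputs.shape[0], 1, inputs.shape[2], dtype=inputs.dtype)],
    dim=1,
)

# Positions 0..L-2 are untouched; only the final column becomes zeros.
assert torch.equal(shifted[:, :-1], inputs[:, :-1])
assert torch.count_nonzero(shifted[:, -1]) == 0

# For the padded session the final column was already padding, so nothing changes
# and the last real item (index 2) is still present in the input.
assert torch.equal(shifted[1], inputs[1])
assert torch.count_nonzero(shifted[1, 2]) > 0
```

In other words, the step behaves like "zero the last column" rather than "hide the last real item", and the two only coincide for sessions that fill the whole sequence length.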
from transformers4rec.
Anyway, I changed some code in `apply_mask_to_inputs` of the CLM masking:
```python
import torch


# Modified version of the CLM masking's apply_mask_to_inputs (`self` is the masking module)
def apply_mask_to_inputs_CLM(
    self,
    inputs: torch.Tensor,
    mask_schema: torch.Tensor,
    training: bool = False,
    testing: bool = False,
) -> torch.Tensor:
    if not training and not testing:
        # Replacing the inputs corresponding to padded items with a trainable embedding
        # To mimic training and evaluation masking strategy
        inputs = torch.where(
            mask_schema.unsqueeze(-1).bool(),
            inputs,
            self.masked_item_embedding.to(inputs.dtype),
        )
        return inputs
    # # shift sequence of interaction embeddings
    # pos_emb_inp = inputs[:, :-1]
    # # Adding a masked item in the sequence to return to the initial sequence.
    # pos_emb_inp = torch.cat(  # type: ignore
    #     [
    #         pos_emb_inp,
    #         torch.zeros(
    #             (pos_emb_inp.shape[0], 1, pos_emb_inp.shape[2]),
    #             dtype=pos_emb_inp.dtype,
    #         ).to(inputs.device),
    #     ],
    #     axis=1,
    # )
    pos_emb_inp = inputs
    pos_emb_inp_new = pos_emb_inp.clone()
    # Iterate over each row in the boolean tensor
    for i in range(mask_schema.shape[0]):
        # Find the index of the last True value in the row
        # If there's no True value, idx will be -1
        idx = (mask_schema[i].nonzero(as_tuple=True)[0]).max().item() if mask_schema[i].any() else -1
        # Replace the corresponding item in the feature tensor with a zero vector
        if idx != -1:
            pos_emb_inp_new[i, idx] = torch.zeros(pos_emb_inp.shape[2], dtype=pos_emb_inp.dtype).to(inputs.device)
    pos_emb_inp = pos_emb_inp_new
    # Replacing the inputs corresponding to padded items with a trainable embedding
    pos_emb_inp = torch.where(
        mask_schema.unsqueeze(-1).bool(),
        pos_emb_inp,
        self.masked_item_embedding.to(pos_emb_inp.dtype),
    )
    return pos_emb_inp
```
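The per-row Python loop in the modified method can be vectorized. Below is a minimal sketch under the following assumptions: `mask_schema` is a `[batch, seq_len]` boolean-like tensor, and `inputs` is `[batch, seq_len, dim]`. The helper names `zero_last_true_position` and `zero_last_true_position_loop` are hypothetical. The sketch reproduces only the zeroing step, not the `masked_item_embedding` replacement that follows it.

```python
import torch


def zero_last_true_position_loop(inputs: torch.Tensor, mask_schema: torch.Tensor) -> torch.Tensor:
    """Reference behaviour: the per-row loop from the modified method above."""
    out = inputs.clone()
    for i in range(mask_schema.shape[0]):
        idx = mask_schema[i].nonzero(as_tuple=True)[0].max().item() if mask_schema[i].any() else -1
        if idx != -1:
            out[i, idx] = torch.zeros(inputs.shape[2], dtype=inputs.dtype, device=inputs.device)
    return out


def zero_last_true_position(inputs: torch.Tensor, mask_schema: torch.Tensor) -> torch.Tensor:
    """Vectorized sketch: zero the embedding at the last True position of each row."""
    mask = mask_schema.bool()
    batch_size, seq_len = mask.shape
    positions = torch.arange(seq_len, device=mask.device).expand(batch_size, seq_len)
    # Position index where the mask is True, -1 elsewhere; the row-wise max is the
    # index of the last True value, or -1 if the row has no True value.
    last_idx = torch.where(mask, positions, torch.full_like(positions, -1)).max(dim=1).values
    rows = torch.nonzero(last_idx >= 0, as_tuple=True)[0]
    out = inputs.clone()
    out[rows, last_idx[rows]] = 0  # each selected [dim] vector is set to zeros
    return out


# Toy example: batch of 3, seq_len 4, dim 2; the last row has no True value.
inputs = torch.arange(1, 25, dtype=torch.float32).reshape(3, 4, 2)
mask_schema = torch.tensor(
    [[True, True, False, False],
     [True, True, True, True],
     [False, False, False, False]]
)
result = zero_last_true_position(inputs, mask_schema)
assert torch.equal(result, zero_last_true_position_loop(inputs, mask_schema))
```

Taking the row-wise maximum of the masked position indices avoids relying on `argmax` tie-breaking. Rows without any True value keep index -1 and are skipped, which matches the loop's `idx != -1` check.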
Interestingly, with this modification the metrics of XLNet in the CLM setting have decreased compared to before, which seems more reasonable. I've also noticed that #719 and #746 mention a similar issue. Additionally, I observed that the outputs of the `predict` and `evaluate` functions have become similar:
```
# same inputs as before!
=========inference===============
PredictionOutput(predictions=(array([[ 4, 3, 7, 5, 6],
[ 5, 7, 30, 3, 22],
[17, 7, 6, 26, 11],
[ 3, 4, 25, 7, 18],
[71, 26, 24, 4, 3],
[ 4, 5, 3, 7, 22]]), array([[ 8.456369 , 5.8312187, 5.6498675, 5.1875997, 4.956415 ],
[10.017425 , 6.9653053, 6.9261403, 6.617892 , 6.447962 ],
[ 9.492199 , 8.0066 , 6.5381365, 6.097307 , 5.9652367],
[ 6.56517 , 6.2221594, 5.269926 , 5.1904283, 5.092497 ],
[ 6.0176606, 5.6072693, 5.431714 , 5.273334 , 5.039026 ],
[ 8.518238 , 6.537841 , 5.880717 , 5.6688414, 5.0900187]],
dtype=float32)), label_ids=None, metrics={'test_runtime': 2.0936, 'test_samples_per_second': 2.866, 'test_steps_per_second': 1.433})
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.08it/s]
=========evaluation===============
PredictionOutput(predictions=(array([[ 4, 7, 3, 6, 30],
[ 5, 7, 30, 3, 4],
[17, 7, 6, 26, 4],
[ 3, 4, 25, 39, 7],
[ 4, 24, 71, 26, 14],
[ 4, 30, 5, 3, 7]]), array([[7.4311004, 4.9409456, 4.876472 , 4.848051 , 4.7532825],
[7.4449983, 6.3823347, 6.3242016, 6.160487 , 5.6924496],
[7.9830294, 7.4659915, 6.0589595, 5.7209487, 5.57537 ],
[6.721603 , 5.4623756, 5.102635 , 4.9805446, 4.9469085],
[4.9136906, 4.843173 , 4.796401 , 4.7077065, 4.5187464],
[7.7905493, 5.1441355, 5.0995426, 5.0602183, 5.031133 ]],
dtype=float32)), label_ids=array([ 4, 9, 7, 43, 23, 91]), metrics={'eval_/next-item/ndcg_at_5': 0.27182161808013916, 'eval_/next-item/ndcg_at_10': 0.27182161808013916, 'eval_/next-item/recall_at_5': 0.3333333432674408, 'eval_/next-item/recall_at_10': 0.3333333432674408, 'eval_/next-item/avg_precision_at_5': 0.25, 'eval_/next-item/avg_precision_at_10': 0.25, 'eval_/loss': 4.302186489105225, 'eval_runtime': 0.6978, 'eval_samples_per_second': 8.599, 'eval_steps_per_second': 4.299})
```
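One way to quantify "similar" is to compare the two top-5 lists row by row and to recompute the evaluation metrics from the printed values. The sketch below copies the item IDs and labels from the logs above. It assumes the i-th row of both outputs corresponds to the same session, as the `# same inputs as before!` note suggests. It also assumes the standard single-target definitions of recall@k and NDCG@k with binary relevance. The helper names are hypothetical.

```python
import math

# Top-5 item IDs and labels copied from the two PredictionOutput logs above.
predict_top5 = [[4, 3, 7, 5, 6], [5, 7, 30, 3, 22], [17, 7, 6, 26, 11],
                [3, 4, 25, 7, 18], [71, 26, 24, 4, 3], [4, 5, 3, 7, 22]]
evaluate_top5 = [[4, 7, 3, 6, 30], [5, 7, 30, 3, 4], [17, 7, 6, 26, 4],
                 [3, 4, 25, 39, 7], [4, 24, 71, 26, 14], [4, 30, 5, 3, 7]]
labels = [4, 9, 7, 43, 23, 91]


def overlap_at_k(a, b, k=5):
    """Fraction of items shared by two top-k lists, ignoring order."""
    return len(set(a[:k]) & set(b[:k])) / k


def recall_at_k(top, label, k=5):
    """1.0 if the single target item appears in the top-k list, else 0.0."""
    return 1.0 if label in top[:k] else 0.0


def ndcg_at_k(top, label, k=5):
    """Binary relevance, single target: 1 / log2(rank + 1) with a 1-based rank."""
    if label in top[:k]:
        rank = top[:k].index(label) + 1
        return 1.0 / math.log2(rank + 1)
    return 0.0


overlaps = [overlap_at_k(p, e) for p, e in zip(predict_top5, evaluate_top5)]
top1_agree = sum(p[0] == e[0] for p, e in zip(predict_top5, evaluate_top5))
recall = sum(recall_at_k(e, y) for e, y in zip(evaluate_top5, labels)) / len(labels)
ndcg = sum(ndcg_at_k(e, y) for e, y in zip(evaluate_top5, labels)) / len(labels)
print(overlaps, top1_agree, round(recall, 4), round(ndcg, 4))
```

With these numbers, every row shares 4 of its 5 items between `predict` and `evaluate`, and the top-1 item matches in 5 of the 6 rows. Recall@5 comes out at about 0.3333 and NDCG@5 at about 0.2718, consistent with the reported `eval_/next-item/recall_at_5` and `eval_/next-item/ndcg_at_5`.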
I'm not sure if my changes are correct, and I strongly recommend that you pay attention to this issue. Thanks!
Related Issues
- [QST] ValueError: For masking a categorical_module is required including an item_id.
- [QST] Projecting inputs of NextItemPredictionTask to'64' As weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'
- [QST] Cross-entropy and pairwise losses are supported in Next Item Prediction
- [QST] How to print metrics while training?
- RuntimeError: CUDF failure at: /__w/cudf/cudf/cpp/src/io/parquet/reader_impl_helpers.cpp:379: Invalid rowgroup index[BUG]
- Génerating predictions
- [BUG] CausalLanguageModeling masking error on last item only condition
- [QST] Help with creating two tower model with transformers.
- [FEA] Post context fusion using T4rec api
- [BUG] CausalLanguageModeling do not mask last input item
- [QST] Extracting User Representation Vectors from Pre-trained Next Item Prediction Model
- [BUG] AttributeError: 'list' object has no attribute 'output_node'"
- Model is not generating accurate recommandations [QST]
- [BUG] RuntimeError: PyTorch execute failure: Expected Tensor but got GenericList
- [QST] Problem with defining input module, item embedding table.
- [QST] examples/tutorial/02-ETL-with-NVTabular.ipynb
- [BUG] examples/tutorial/01-preprocess.ipynb: Convert timestamp from datetime - NotImplementedError: cuDF does not yet support timezone-aware datetimes
- [QST] Prediction Output Length Not Matching Input Length
- Compound Tags.ITEM_ID