Comments (14)
Hi @gatoniel
So after a longer investigation, this turned out to be a general problem: due to some changes in torchmetrics (not related to the retrieval metrics), passing a metric object directly to self.log in lightning resulted in that error.
This PR on the lightning side will solve the error:
Lightning-AI/pytorch-lightning#7055
from torchmetrics.
Hi! Thanks for your contribution, great first issue!
Actually, the above example does not work properly with the mentioned MeanAbsoluteError either: it produces NaNs.
The MeanAbsoluteError metric produces NaNs in compute if it has been reset to its default values. The default value of self.idx in RetrievalMAP is an empty list. I assume that in trainer.run_evaluation (https://github.com/PyTorchLightning/pytorch-lightning/blob/a72a7992a283f2eb5183d129a8cf6466903f1dc8/pytorch_lightning/trainer/trainer.py#L636) the metrics are reset before they are logged in this line: https://github.com/PyTorchLightning/pytorch-lightning/blob/a72a7992a283f2eb5183d129a8cf6466903f1dc8/pytorch_lightning/trainer/trainer.py#L719
But I do not understand the pytorch-lightning code base well enough to see where the metrics are actually reset to their default state, nor do I know whether this is rather an issue for pytorch-lightning.
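To illustrate the suspected mechanism, here is a toy re-implementation of the accumulate/compute/reset lifecycle (illustrative only, not the actual torchmetrics code). In PyTorch, dividing a zero-tensor error sum by a zero-sample count silently yields NaN; this plain-Python sketch mimics that with float("nan"):

```python
class ToyMeanAbsoluteError:
    """Toy sketch of a torchmetrics-style metric lifecycle
    (not the real MeanAbsoluteError implementation)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Default state: no accumulated error, no observed samples.
        self.sum_abs_error = 0.0
        self.total = 0

    def update(self, preds, targets):
        self.sum_abs_error += sum(abs(p - t) for p, t in zip(preds, targets))
        self.total += len(preds)

    def compute(self):
        # With torch tensors, tensor(0.0) / tensor(0) silently gives NaN;
        # we mimic that here instead of raising ZeroDivisionError.
        if self.total == 0:
            return float("nan")
        return self.sum_abs_error / self.total


metric = ToyMeanAbsoluteError()
metric.update([1.0, 2.0], [0.0, 0.0])
print(metric.compute())  # 1.5

metric.reset()           # back to default state, as suspected in the trainer
print(metric.compute())  # nan
```

If the trainer resets the metric before logging, compute() runs on the default state and returns NaN, which matches the behavior described above.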
The following changes to the above code do work. However, according to the last note in https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html, the above code should also work.
def validation_step_end(self, outputs):
    self.metric(outputs["indexes"], outputs["preds"], outputs["targets"])

def validation_epoch_end(self, validation_step_outputs):
    self.log("val_MAP", self.metric.compute())
Which version of pytorch-lightning are you using? I cannot reproduce your error when installing torchmetrics both from your commit and from master, with pytorch-lightning==1.2.6. Retrieval metrics had problems with fast_dev_run, because with the skip action there could be cases with empty tensors. But now they should be solved (and tested).
That's odd, I have pytorch-lightning version 1.2.6 installed.
Cannot reproduce either using the latest master of both torchmetrics and pytorch-lightning.
I am sorry, but I just created a new conda environment with Python 3.8 and installed torch and the latest branches of torchmetrics and pytorch-lightning via pip. Then I copied the above code into a new file and ran it. I got the same error as reported above.
Is this function correct:
def validation_step_end(self, outputs):
    self.metric(outputs["indexes"], outputs["preds"], outputs["targets"])
    self.log("val_MAP", self.metric)
or should self.log("val_MAP", self.metric) not be called here?
Please check whether outputs["indexes"], outputs["preds"] and outputs["targets"] are not empty.
def validation_step_end(self, outputs):
    print(outputs)
    self.metric(outputs["indexes"], outputs["preds"], outputs["targets"])
    self.log("val_MAP", self.metric)
Hi, they are not empty.
As I said, the metric collects the values and calculates them when calling self.metric.compute() in the validation_epoch_end function.
@gatoniel can I ask why you are actually using validation_step_end and not putting all the code in validation_step?
I wanted to prepare the metric for use in data-parallel mode, as explained in the last note here: https://torchmetrics.readthedocs.io/en/stable/pages/lightning.html
This gives the same error:
def validation_step(self, batch, batch_idx):
    x, y = batch
    x = x.view(x.size(0), -1)
    preds = self.encoder(x).squeeze()
    indexes = torch.randint(100, size=preds.size())
    targets = torch.randint(2, size=preds.size()).to(bool)
    print(indexes, preds, targets)
    self.metric(indexes, preds, targets)
    self.log("val_MAP", self.metric)
But this runs without problems:
def validation_step(self, batch, batch_idx):
    x, y = batch
    x = x.view(x.size(0), -1)
    preds = self.encoder(x).squeeze()
    indexes = torch.randint(100, size=preds.size())
    targets = torch.randint(2, size=preds.size()).to(bool)
    print(indexes, preds, targets)
    self.metric(indexes, preds, targets)

def validation_epoch_end(self, outputs):
    self.log("val_MAP", self.metric.compute())
Then again, using only
def validation_epoch_end(self, outputs):
    self.log("val_MAP", self.metric)
gives the error.
Hi @gatoniel. Sorry, but I still cannot reproduce your error. What I believe is that you are calling .compute() on a metric that didn't receive an update yet (it was not provided with data). Please try again after installing from the master branch, because we pushed a lot of updates yesterday.
Btw, I do not understand why you are using the indexes in this way. Did you set them to random values only for testing purposes? Otherwise I suggest waiting for the release of 0.3.0 to see examples in the documentation. indexes are used to group queries, because IR metrics work at the query level and not at the example level.
For example, if your batch is composed of 5 examples:
(query 1, document 1) -> NN -> relevance score
(query 1, document 2) -> NN -> relevance score
(query 2, document 1) -> NN -> relevance score
(query 2, document 2) -> NN -> relevance score
(query 2, document 3) -> NN -> relevance score
you should use indexes = tensor([0, 0, 1, 1, 1]).
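The grouping described above can be sketched in plain Python (this is an illustrative re-derivation of mean average precision, not the torchmetrics implementation; the function name and example values are made up for this sketch):

```python
from collections import defaultdict

def mean_average_precision(indexes, preds, targets):
    """Group (pred, target) pairs by query index, compute average
    precision per query, then average the per-query values."""
    groups = defaultdict(list)
    for idx, p, t in zip(indexes, preds, targets):
        groups[idx].append((p, t))

    ap_values = []
    for items in groups.values():
        # Rank this query's documents by descending relevance score.
        ranked = sorted(items, key=lambda x: x[0], reverse=True)
        hits, precision_sum = 0, 0.0
        for rank, (_, target) in enumerate(ranked, start=1):
            if target:
                hits += 1
                precision_sum += hits / rank
        if hits:
            ap_values.append(precision_sum / hits)
    return sum(ap_values) / len(ap_values)

# Two queries: query 0 -> first two examples, query 1 -> last three.
indexes = [0, 0, 1, 1, 1]
preds   = [0.9, 0.2, 0.3, 0.8, 0.5]
targets = [1, 0, 1, 0, 0]
print(mean_average_precision(indexes, preds, targets))
# 0.666... (query 0 AP = 1.0, query 1 AP = 1/3)
```

Without the indexes, all five examples would be ranked in a single list and the metric would measure something different from per-query retrieval quality.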
Hi, yes, I just used random indexes for testing purposes.
I will try to reproduce with the boring model on Colab and then share the notebook.
Yes, I think the problem is that the metric did not receive updates. But I think this is because the metric was reset to its defaults in the epoch_end call, so that all values it received during the epoch are erased.
It could be an explanation, but it does not explain why we cannot reproduce this error with the same pytorch-lightning and torchmetrics versions. Using Colab to reproduce is a fantastic idea :).
P.s. torchmetrics==0.3.0rc0 is out.