Comments (14)

SkafteNicki commented on September 22, 2024

Hi @gatoniel
After a longer investigation, this appears to be a general problem: due to some changes in torchmetrics (not related to the retrieval metrics), passing a metric object directly to self.log in lightning results in that error.
This PR on the lightning side will solve the error:
Lightning-AI/pytorch-lightning#7055
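
For illustration, a minimal sketch of the logging pattern that triggered the error, assuming a LightningModule holding a torchmetrics metric as an attribute (the model and metric names here are illustrative, not from this issue):

    import torch
    import pytorch_lightning as pl
    from torchmetrics import MeanAbsoluteError

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)
            self.metric = MeanAbsoluteError()

        def validation_step(self, batch, batch_idx):
            x, y = batch
            preds = self.layer(x).squeeze()
            self.metric(preds, y)
            # passing the metric object itself to self.log triggered the error
            self.log("val_mae", self.metric)
            # workaround until the linked PR landed: log the computed value
            # self.log("val_mae", self.metric.compute())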

github-actions commented on September 22, 2024

Hi! Thanks for your contribution! Great first issue!

gatoniel commented on September 22, 2024

Actually, the above example does not work properly with the mentioned MeanAbsoluteError either. It produces NaNs.

The MeanAbsoluteError metric produces NaNs in compute if it is set back to its default values.
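
This is easy to see in isolation. A minimal sketch, assuming the torchmetrics version discussed in this thread (newer releases may warn or raise instead of silently returning NaN):

    import torch
    from torchmetrics import MeanAbsoluteError

    mae = MeanAbsoluteError()
    mae(torch.tensor([1.0, 2.0]), torch.tensor([1.5, 2.5]))
    print(mae.compute())  # tensor(0.5000)

    mae.reset()           # back to default state: error sum 0, update count 0
    print(mae.compute())  # 0 / 0 -> NaN on the versions discussed here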

The default value of self.idx in RetrievalMAP is an empty list. I assume that in trainer.run_evaluation (https://github.com/PyTorchLightning/pytorch-lightning/blob/a72a7992a283f2eb5183d129a8cf6466903f1dc8/pytorch_lightning/trainer/trainer.py#L636) the metrics are reset before they are logged in this line: https://github.com/PyTorchLightning/pytorch-lightning/blob/a72a7992a283f2eb5183d129a8cf6466903f1dc8/pytorch_lightning/trainer/trainer.py#L719

But I do not understand the pytorch-lightning code base well enough to see where the metrics are actually reset to their default state, nor do I know whether this is rather an issue for pytorch-lightning.

The following changes to the above code do work. However, according to the last note in https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html, the above code should also work.

    def validation_step_end(self, outputs):
        # only update the metric state here
        self.metric(outputs["indexes"], outputs["preds"], outputs["targets"])

    def validation_epoch_end(self, validation_step_outputs):
        # log the value computed from the whole epoch's accumulated state
        self.log("val_MAP", self.metric.compute())

lucadiliello commented on September 22, 2024

Which version of pytorch-lightning are you using? I cannot reproduce your error when installing torchmetrics either from your commit or from master, together with pytorch-lightning==1.2.6. Retrieval metrics had problems with fast_dev_run because, with the skip action, there could have been cases with empty tensors. But now they should be solved (and tested).

gatoniel commented on September 22, 2024

That's odd, I have pytorch-lightning version 1.2.6 installed.

SkafteNicki commented on September 22, 2024

I cannot reproduce it either, using the latest master of both torchmetrics and pytorch-lightning.

gatoniel commented on September 22, 2024

I am sorry, but I just created a new conda environment with Python 3.8 and installed torch and the latest branches of torchmetrics and pytorch-lightning via pip. Then I copied the above code into a new file and ran it. I got the same error as reported above.

Is this function correct

    def validation_step_end(self, outputs):
        self.metric(outputs["indexes"], outputs["preds"], outputs["targets"])
        self.log("val_MAP", self.metric)

or should self.log("val_MAP", self.metric) not be called here?

lucadiliello commented on September 22, 2024

Please check that outputs["indexes"], outputs["preds"], and outputs["targets"] are not empty:

    def validation_step_end(self, outputs):
        print(outputs)
        self.metric(outputs["indexes"], outputs["preds"], outputs["targets"])
        self.log("val_MAP", self.metric)

gatoniel commented on September 22, 2024

Hi, they are not empty.

As I said, the metric collects the values and computes the result when self.metric.compute() is called in the validation_epoch_end function.

SkafteNicki commented on September 22, 2024

@gatoniel can I ask why you are using validation_step_end instead of putting all the code in validation_step?

gatoniel commented on September 22, 2024

I wanted to prepare the metric for use in data-parallel mode, as explained in the last note here: https://torchmetrics.readthedocs.io/en/stable/pages/lightning.html
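
For context, a minimal sketch of that documented pattern: validation_step returns the tensors, and validation_step_end updates the metric, so that in DP mode the update happens once on the outputs gathered from all devices (the dict keys follow the snippets in this thread):

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        preds = self.encoder(x).squeeze()

        indexes = torch.randint(100, size=preds.size())
        targets = torch.randint(2, size=preds.size()).to(bool)

        # return everything so lightning can gather it across DP replicas
        return {"indexes": indexes, "preds": preds, "targets": targets}

    def validation_step_end(self, outputs):
        # runs once per step on the gathered outputs
        self.metric(outputs["indexes"], outputs["preds"], outputs["targets"])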

This gives the same error:

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        preds = self.encoder(x).squeeze()

        indexes = torch.randint(100, size=preds.size())
        targets = torch.randint(2, size=preds.size()).to(bool)

        print(indexes, preds, targets)
        self.metric(indexes, preds, targets)
        self.log("val_MAP", self.metric)  # logging the metric object itself fails

But this runs without problems:

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        preds = self.encoder(x).squeeze()

        indexes = torch.randint(100, size=preds.size())
        targets = torch.randint(2, size=preds.size()).to(bool)

        print(indexes, preds, targets)
        self.metric(indexes, preds, targets)  # update only; no logging here

    def validation_epoch_end(self, outputs):
        self.log("val_MAP", self.metric.compute())  # logging the computed value works

Then again using only

    def validation_epoch_end(self, outputs):
        self.log("val_MAP", self.metric)

gives the error.

lucadiliello commented on September 22, 2024

Hi @gatoniel. Sorry, but I still cannot reproduce your error. What I believe is that you are calling .compute() on a metric that hasn't received an update yet (it was not provided with data). Please try again after installing from the master branch, because we pushed a lot of updates yesterday.

Btw, I do not understand why you are using the indexes in this way. Did you set them to random values only for testing purposes? Otherwise, I suggest waiting for the release of 0.3.0 to see examples in the documentation. indexes are used to group predictions by query, because IR metrics work at the query level and not at the example level.

For example, if your batch is composed of 5 examples:

(query 1, document 1) -> NN -> relevance score
(query 1, document 2) -> NN -> relevance score
(query 2, document 1) -> NN -> relevance score
(query 2, document 2) -> NN -> relevance score
(query 2, document 3) -> NN -> relevance score

You should use indexes = tensor([0, 0, 1, 1, 1]).
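
As a self-contained sketch of that grouping (the positional argument order here follows the snippets in this thread; newer torchmetrics versions instead expect preds and target first, with indexes as a keyword argument):

    import torch
    from torchmetrics import RetrievalMAP

    # five (query, document) pairs: two for query 0, three for query 1
    indexes = torch.tensor([0, 0, 1, 1, 1])
    preds = torch.tensor([0.9, 0.2, 0.4, 0.8, 0.1])  # relevance scores
    target = torch.tensor([True, False, False, True, False])

    metric = RetrievalMAP()
    metric(indexes, preds, target)  # newer versions: metric(preds, target, indexes=indexes)
    print(metric.compute())  # AP per query, averaged over the two queries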

gatoniel commented on September 22, 2024

Hi, yes, I just used random indexes for testing purposes.

I will try to reproduce with the boring model on colab and then share the notebook.

Yes, I think the problem is that the metric did not receive updates. But I think that is because the metric was reset to its defaults in the epoch_end call, so all the values it received during the epoch were erased.

lucadiliello commented on September 22, 2024

That could be an explanation, but it does not explain why we cannot reproduce this error with the same pytorch-lightning and torchmetrics versions. Using Colab to reproduce is a fantastic idea :).
P.S. torchmetrics==0.3.0rc0 is out.
