Comments (16)

maximsch2 commented on June 21, 2024

This makes sense to me. Another nit/question: self._forward_cache doesn't necessarily belong on self, since it can also be a heavy object we don't want to store long-term. I would even suggest something like this:

def forward(self, *args, **kwargs):

    if not self.compute_on_step:
        with torch.no_grad():
            self.update(*args, **kwargs)
        return

    self._to_sync = self.dist_sync_on_step # why are we resetting this on every forward btw?

    # save context before switch
    cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

    # call reset, update, compute, on single batch
    self.reset()
    self.update(*args, **kwargs)
    result = self.compute()

    # merge new and old context without recomputing update
    for attr, val in cache.items():
        setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))

    return result

janvainer commented on June 21, 2024

Yes, I will make a PR :) Probably not today, but over the weekend.

github-actions commented on June 21, 2024

Hi! Thanks for your contribution, great first issue!

Borda commented on June 21, 2024

@PyTorchLightning/core-metrics thoughts?

janvainer commented on June 21, 2024

@maximsch2 I think the self._forward_cache is used in tests.

SkafteNicki commented on June 21, 2024

self._forward_cache is also used for internal logging in Lightning, see this file:
https://github.com/PyTorchLightning/pytorch-lightning/blob/555a6fea212e340f0b3a9684829e6027e4ba27c0/pytorch_lightning/core/step_result.py#L303
so that cannot be removed.
@janvainer I really like this suggestion. I cannot completely tell whether the change will work for all metrics (I am sure it will work for the majority). Could you try it locally and see if you can get all tests passing with this change?

janvainer commented on June 21, 2024

@SkafteNicki the current code is a bit problematic when it comes to the use of the dist_reduce_fx argument.
The reduction function should mimic what happens in the update function, but the current framework does not enforce that.
So for example, in the tests, the DummyMetric uses dist_reduce_fx=None, but in update there is self.x += x, so the reduction for distributed use and the update are not aligned. This is problematic if we want to use only a single update call in forward, because we don't know what reduction is used in update. Check out the draft PR #141.

For this to work, the update method should calculate a single-step value of the metric, and the internal state should be updated with the provided reduction function. What do you think? I am afraid this suggestion requires an API change, which will probably not be approved. Maybe there is some other way to accomplish what this issue suggests, but I don't see it yet. Any ideas?
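
For concreteness, a minimal sketch of the mismatch I mean; the class is illustrative, loosely modelled on the DummyMetric in the tests:

import torch
from torchmetrics import Metric

class DummyMetric(Metric):
    def __init__(self):
        super().__init__()
        # the state is registered without a reduction function ...
        self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx=None)

    def update(self, x):
        # ... but updated with an in-place sum, so the registered reduction
        # and the update operation are not aligned
        self.x += x

    def compute(self):
        return self.x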

SkafteNicki commented on June 21, 2024

@janvainer I agree that this is problematic. I am pretty sure that dist_reduce_fx is aligned with the update operations for all "real" implemented metrics (otherwise tests should fail), and the dummy metric is a special case where we have not kept them aligned. With that in mind we could make the change.

However, it is troublesome that we do not have an explicit check. To get this right, we would need an API change, the simplest being that instead of the user doing the state operations inside the update call, they return the per-step values and we then apply dist_reduce_fx. Something like:

# accuracy example
def __init__(self, ...):
    self.add_state("correct", torch.tensor(0), dist_reduce_fx = 'sum')
    self.add_state("nobs", torch.tensor(0), dist_reduce_fx = 'sum')

def update(self, preds, target):
    equal = (preds==target).sum()
    n_obs = preds.numel()
    return equal, n_obs  # need to return in the same order as the states are initialized

def _wrap_update(self, update):

    @functools.wraps(update)
    def wrapped_func(*args, **kwargs):
        self._computed = None
        out = update(*args, **kwargs)
        for state, step_val in zip(self._defaults.keys(), out):
            setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
        return None
    return wrapped_func
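
For intuition, a rough usage sketch of how the wrapped update would accumulate state under this proposal (numbers are illustrative, and it assumes compute returns correct / nobs):

acc = Accuracy()  # the rewritten Accuracy sketched above, not the current implementation
acc.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))  # inner update returns (2, 3); wrapper folds it in -> correct=2, nobs=3
acc.update(torch.tensor([1, 1]), torch.tensor([1, 1]))        # inner update returns (2, 2); wrapper folds it in -> correct=4, nobs=5
acc.compute()                                                 # 4 / 5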

janvainer commented on June 21, 2024

Thanks, what you are suggesting makes sense to me. This looks relatively nice :) The question is how big an API change this would be for existing users. Should I implement this?

def __init__(self, ...):
    self.add_state("l1_distances", torch.tensor(0.0), dist_reduce_fx="sum")
    self.add_state("numel", torch.tensor(0), dist_reduce_fx="sum")

def update(self, prediction, target, mask):
    distances = l1_dist(prediction, target)
    numel = mask.sum()
    return distances, numel

def forward(self, *args, **kwargs):
    ...
    update_results = self.update(*args, **kwargs)
    # Update old state with new results with reduction functions
    ...
    # return single-step metric value (potentially differentiable scalar)
    return self.compute(*update_results)

@staticmethod
def compute(sum_of_distances, total_numel):
    return sum_of_distances / total_numel

janvainer commented on June 21, 2024

OK, I like this solution. The API change seems quite big though. Is that OK?

SkafteNicki commented on June 21, 2024

@janvainer as this is a fundamental API change we need more input on it.
@PyTorchLightning/core-metrics any opinions?

maximsch2 commented on June 21, 2024

One potential issue here is that this prevents in-place updates, right?

In https://github.com/PyTorchLightning/metrics/pull/128/files#diff-a605698e7c4a7849117d5d944263ea2218cc58795426cd2c98165794dc31365eR70-R74 I'm going over each column separately to avoid constructing the full matrix, as it OOMs for us otherwise (very large number of classes).
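
For context, a rough sketch of the kind of per-column, in-place update I mean (the state and attribute names are illustrative, not taken from the linked PR):

def update(self, preds, target):
    # iterate over classes so the full (num_samples, num_classes) boolean
    # intermediates are never materialized; the states are mutated in place,
    # so there is no single per-step value to hand to a reduction function
    for c in range(preds.shape[1]):
        pred_pos = preds[:, c] >= self.threshold
        true_pos = target[:, c] == 1
        self.true_positives[c] += (pred_pos & true_pos).sum()
        self.target_positives[c] += true_pos.sum()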

SkafteNicki commented on June 21, 2024

@janvainer after some going back and forth internally, here is an API change that suits the needs without breaking backward compatibility:

new inplace arg

The basic idea is that the base Metric class will get a new fifth argument called inplace that defaults to True. The flag will indicate whether metric states are updated in place inside the update method, or instead are returned from the update method and we internally do the reduction using some variation of:

out = update(*args, **kwargs)
for state, step_val in zip(self._defaults.keys(), out):
    setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))

If the flag is True, nothing about the current code changes, so it is still backward compatible. The forward method would then look something like:

def forward(self, *args, **kwargs):
    if not self.inplace:
        return self.fast_forward(*args, **kwargs)
    # insert whatever is in forward now
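
A sketch of what fast_forward could look like, assuming the reduction loop above plus a compute that accepts the per-step values (an assumption based on the discussion, not a settled implementation):

def fast_forward(self, *args, **kwargs):
    # update returns the per-step state values instead of mutating the states
    step_vals = self.update(*args, **kwargs)

    # fold the per-step values into the existing states via the registered reductions
    for state, step_val in zip(self._defaults.keys(), step_vals):
        setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))

    # the single-batch metric value can be computed directly from the step values
    if self.compute_on_step:
        return self.compute(*step_vals)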

Positive:

  • faster computation for metrics that can be calculated with inplace=False
  • still supports metrics that require in-place updates, such as the one @maximsch2 mentions
  • backward compatible with all users' custom metrics

Negative:

  • more complex code base

The initial PR should implement the API changes in Metric and maybe redo a single metric, such as Accuracy, using the new faster API. Then, in follow-up PRs, we can begin migrating the remaining metrics.

janvainer commented on June 21, 2024

Hi @SkafteNicki, thanks for the follow-up. The suggested API changes make sense to me! Unfortunately I do not have the capacity to look into it in the upcoming weeks, so feel free to reassign if anyone wants to pick this up (it should be possible to continue from my draft PR). I will be able to work on this some time in June, I think.

SkafteNicki commented on June 21, 2024

@janvainer thanks for letting us know. I am going to unassign you and also close the PR you created; then we will see if someone feels like picking it up. Otherwise, feel free to ping me if you find the time to contribute again :]
Thanks!

stale commented on June 21, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
