Comments (16)
This makes sense to me. Another nit/question: `self._forward_cache` doesn't necessarily belong on `self`, as it can also be a heavy object we don't want to store long-term. I would suggest even something like this:
```python
def forward(self, *args, **kwargs):
    if not self.compute_on_step:
        with torch.no_grad():
            self.update(*args, **kwargs)
        return
    self._to_sync = self.dist_sync_on_step  # why are we resetting this on every forward btw?
    # save context before switch
    cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
    # call reset, update, compute on a single batch
    self.reset()
    self.update(*args, **kwargs)
    result = self.compute()
    # merge new and old context without recomputing update
    for attr, val in cache.items():
        setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))
    return result
```
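To make the mechanics of the cache-and-merge trick above concrete, here is a minimal pure-Python toy (no torch, and `TinySum` is an illustrative stand-in, not the TorchMetrics `Metric` class): state is cached, reset, updated on the single batch, and then the old state is folded back in with the per-state reduction function.

```python
class TinySum:
    """Toy stand-in for a metric with a single state reduced by addition."""

    def __init__(self):
        self._defaults = {"total": 0.0}
        self._reductions = {"total": lambda old, new: old + new}
        self.total = 0.0

    def reset(self):
        for attr, default in self._defaults.items():
            setattr(self, attr, default)

    def update(self, xs):
        self.total = self.total + sum(xs)

    def compute(self):
        return self.total

    def forward(self, xs):
        # save context before switch
        cache = {attr: getattr(self, attr) for attr in self._defaults}
        # reset, update, compute on the single batch
        self.reset()
        self.update(xs)
        result = self.compute()
        # merge old and new context without recomputing update
        for attr, val in cache.items():
            setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))
        return result


m = TinySum()
step = m.forward([1.0, 2.0])   # value for this batch alone: 3.0
m.forward([4.0])               # value for this batch alone: 4.0
running = m.compute()          # accumulated across batches: 7.0
```

Note that `forward` returns the single-batch value while the accumulated state still reflects all batches, which is exactly the behavior the snippet above is after.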
Yes, I will make a PR :) Probably not today, but over the weekend.
Hi! Thanks for your contribution, great first issue!
@PyTorchLightning/core-metrics thoughts?
@maximsch2 I think `self._forward_cache` is used in tests.
`self._forward_cache` is also used for internal logging in Lightning, see this file:
https://github.com/PyTorchLightning/pytorch-lightning/blob/555a6fea212e340f0b3a9684829e6027e4ba27c0/pytorch_lightning/core/step_result.py#L303
so it cannot be removed.
@janvainer I really like this suggestion. I cannot say for certain whether the change will work for all metrics (I am sure it will work for the majority). Could you try locally and see if you can get all tests passing with this change?
@SkafteNicki the current code is a bit problematic when it comes to the use of the `dist_reduce_fx` argument.
The reduction function should mimic what happens in the update function, but the current framework does not enforce that.
For example, in the tests, the `DummyMetric` uses `dist_reduce_fx=None`, but in `update` there is `self.x += x`, so the reduction for distributed use and the update are not aligned. This is problematic if we want to use only a single `update` call in `forward`, because we don't know what reduction is used in `update`. Check out the draft PR #141.
For this to work, the `update` method should calculate a single-step value of the metric, and the internal state should be updated with a provided reduction function. What do you think? I am afraid this suggestion requires an API change, which will probably not be approved. Maybe there is some other way to accomplish the suggestion in this issue, but I don't see it yet. Any ideas?
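A minimal illustration of the misalignment described above (pure Python, names illustrative only): the in-place update accumulates by addition, so any registered reduction other than summation disagrees with what `+=` implies.

```python
# update accumulates with addition (mirrors `self.x += x`)
def update(state, x):
    return state + x

# a dist_reduce_fx that is NOT aligned with the update, e.g. max
def dist_reduce_fx(shards):
    return max(shards)

# two processes each see one batch
shard_a = update(0, 5)   # process 0 state: 5
shard_b = update(0, 7)   # process 1 state: 7

synced = dist_reduce_fx([shard_a, shard_b])   # 7
true_total = update(shard_a, 7)               # 12 -- what += implies
# synced != true_total: the reduction and the update are not aligned
```

Since nothing in the current API ties `dist_reduce_fx` to the operation performed inside `update`, a mismatch like this compiles and runs but silently produces a wrong synced state.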
@janvainer I agree that this is problematic. I am pretty sure that `dist_reduce_fx` is aligned with the operations for all "real" implemented metrics (else tests should fail), and the dummy metric is a special case where we have not aligned them. With that in mind we could make the change.
However, it is troublesome that we do not have an explicit check. To get this correct, we would need an API change, the simplest being that instead of the user performing the operations in the `update` call, they should return them and then we call `dist_reduce_fx`. Something like:
```python
# accuracy example
def __init__(self, ...):
    self.add_state("correct", torch.tensor(0), dist_reduce_fx="sum")
    self.add_state("nobs", torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds, target):
    equal = (preds == target).sum()
    n_obs = preds.numel()
    return equal, n_obs  # need to return in the same order as states are initialized

def _wrap_update(self, update):
    @functools.wraps(update)
    def wrapped_func(*args, **kwargs):
        self._computed = None
        out = update(*args, **kwargs)
        for state, step_val in zip(self._defaults.keys(), out):
            setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
        return None
    return wrapped_func
```
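To sanity-check the wrapper mechanics, here is a self-contained toy harness (pure Python; `Toy` and its attribute layout are hypothetical stand-ins, not the real `Metric` internals). The user's `update` returns per-batch values, and the wrapper folds them into the accumulated states via the registered reduction functions.

```python
import functools


class Toy:
    """Toy harness for the returned-step-values update pattern."""

    def __init__(self):
        self._computed = None
        # states and their reduction functions, in registration order
        self._defaults = {"correct": 0, "nobs": 0}
        self._reductions = {"correct": sum, "nobs": sum}
        self.correct = 0
        self.nobs = 0
        # wrap update so returned step values are reduced into state
        self.update = self._wrap_update(self.update)

    def update(self, preds, target):
        equal = sum(p == t for p, t in zip(preds, target))
        n_obs = len(preds)
        return equal, n_obs  # same order as states were registered

    def _wrap_update(self, update):
        @functools.wraps(update)
        def wrapped_func(*args, **kwargs):
            self._computed = None
            out = update(*args, **kwargs)
            for state, step_val in zip(self._defaults, out):
                # note: setattr takes the object as its first argument
                setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
            return None
        return wrapped_func


t = Toy()
t.update([1, 0], [1, 1])   # 1 correct out of 2
t.update([1], [1])         # 1 correct out of 1
# t.correct == 2, t.nobs == 3
```

One detail worth flagging: the snippet above needs `setattr(self, state, ...)` with three arguments; a two-argument `setattr(state, ...)` would raise a `TypeError` at runtime.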
Thanks, what you are suggesting makes sense to me. This looks relatively nice :) The question is how big an API change this would be for existing users? Should I implement this?
```python
def __init__(...):
    self.add_state("l1_distances", torch.tensor(0.0), dist_reduce_fx="sum")
    self.add_state("numel", torch.tensor(0), dist_reduce_fx="sum")

def update(self, prediction, target, mask):
    distances = l1_dist(prediction, target)
    numel = mask.sum()
    return distances, numel

def forward(self, *args, **kwargs):
    ...
    update_results = self.update(*args, **kwargs)
    # update old state with new results via the reduction functions
    ...
    # return single-step metric value (potentially differentiable scalar)
    return self.compute(*update_results)

@staticmethod
def compute(sum_of_distances, total_numel):
    return sum_of_distances / total_numel
```
Ok, I like this solution. The API change seems quite big though. Is that ok?
@janvainer as this is a fundamental API change we need more input on it.
@PyTorchLightning/core-metrics any opinions?
One potential issue here is that this prevents in-place updates, right?
In https://github.com/PyTorchLightning/metrics/pull/128/files#diff-a605698e7c4a7849117d5d944263ea2218cc58795426cd2c98165794dc31365eR70-R74 I'm going over each column separately to avoid constructing the full matrix, as it OOMs for us otherwise (very large number of classes).
@janvainer after some going back and forth internally, here is an API change that suits the needs without breaking backwards compatibility:
New `inplace` arg
The basic idea is that the base `Metric` class will get a new fifth argument called `inplace`, defaulting to `True`. The flag will indicate whether metric states are updated in-place inside the `update` method, or instead are returned from `update`, in which case we internally do the reduction using some variation of:
```python
out = update(*args, **kwargs)
for state, step_val in zip(self._defaults.keys(), out):
    setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
```
If the flag is `True`, nothing about the current code changes, so it remains backwards compatible. The `forward` method should then look something like:
```python
def forward(self, *args, **kwargs):
    if not self.inplace:
        return self.fast_forward(*args, **kwargs)
    # insert whatever is in forward now
```
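The comment does not show what `fast_forward` would contain; a hedged sketch, using a pure-Python toy (`ToyAccuracy` and all names are hypothetical), assuming `update` returns per-batch state contributions in the same order as the registered states:

```python
class ToyAccuracy:
    """Toy metric illustrating a possible non-inplace fast path."""

    def __init__(self):
        # states and their reduction functions, in registration order
        self._defaults = {"correct": 0, "nobs": 0}
        self._reductions = {"correct": sum, "nobs": sum}
        self.correct = 0
        self.nobs = 0

    def update(self, preds, target):
        # return per-batch contributions instead of mutating state
        return sum(p == t for p, t in zip(preds, target)), len(preds)

    @staticmethod
    def compute(correct, nobs):
        return correct / nobs

    def fast_forward(self, *args, **kwargs):
        # single update call: get the per-batch state contributions
        step_vals = self.update(*args, **kwargs)
        # batch-level metric value from the step values alone
        result = self.compute(*step_vals)
        # fold step values into accumulated state via the reductions
        for state, step_val in zip(self._defaults, step_vals):
            setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
        return result


m = ToyAccuracy()
batch_acc = m.fast_forward([1, 0, 1], [1, 1, 1])  # 2/3 on this batch
m.fast_forward([1], [1])                          # 1/1 on this batch
running = m.compute(m.correct, m.nobs)            # 3/4 overall
```

The key property is that `update` runs exactly once per `fast_forward` call, avoiding the double-update cost of the current cache-based `forward`.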
Positive:
- faster computation for metrics which can be calculated with `inplace=False`
- still supports metrics that require in-place updates, such as @maximsch2 mentions
- backwards compatible with all users' custom metrics
Negative:
- more complex code base
An initial PR should implement the API changes in `Metric` and maybe redo a single metric such as `Accuracy` using the new faster API. Then in follow-up PRs we can begin changing the remaining metrics.
Hi @SkafteNicki, thanks for the follow-up. The suggested API changes make sense to me! I unfortunately do not have the capacity to look into it in the upcoming weeks, so feel free to reassign if anyone wants to pick this up (it should be possible to continue from my draft PR). I will be able to work on this some time in June, I think.
@janvainer thanks for letting us know. I am going to un-assign you and close the PR you created; then we will see if someone feels like picking it up. Feel free to ping me if you find the time to contribute again :]
Thanks!
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.