Comments (5)
I think that now, with zero_division having been added to JaccardIndex, there is no real need for MeanIoU at all. MeanIoU also doesn't follow the classical API.
The original motivation for MeanIoU was that JaccardIndex used to assign a score of 0 to absent and ignored classes, so you couldn't do class-wise and macro averaging correctly. Now you can simply set zero_division to NaN and average to None to get correct per-class scores, and from there take a nanmean to get the correct macro average.
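For instance, a minimal sketch of that workflow, assuming the new zero_division argument of MulticlassJaccardIndex accepts NaN as described above:

import torch
from torchmetrics.classification import MulticlassJaccardIndex

# Hypothetical 4-class example; class 3 never occurs in preds or target.
preds = torch.tensor([0, 0, 1, 2, 2, 1])
target = torch.tensor([0, 1, 1, 2, 2, 2])

# Per-class IoU, with absent classes reported as NaN instead of 0.
jaccard = MulticlassJaccardIndex(num_classes=4, average=None, zero_division=float("nan"))
per_class = jaccard(preds, target)  # expected: tensor([0.5000, 0.3333, 0.6667, nan])

# Macro average that ignores the absent class.
macro = per_class.nanmean()  # expected: tensor(0.5000)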
Hi, I also noticed that MeanIoU was wrong during my experiments.
I developed something that seems to work on my side and returns the same results as evaluate's mean IoU; however, based on @DimitrisMantas's comment, I wonder whether it is relevant to submit a PR. I am not familiar enough with the JaccardIndex implementation in torchmetrics.
For reference, here is the code I developed (it has not been rigorously tested yet). Let me know if a PR would be of interest; I'd gladly contribute to this repo.
from typing import Any, Literal

import torch
from torch import Tensor
from torchmetrics import Metric


def _compute_intersection_and_union(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index", "predictions"] = "index",
) -> tuple[Tensor, Tensor]:
    """Compute per-sample, per-class intersection and union counts."""
    if input_format in ["index", "predictions"]:
        if input_format == "predictions":
            # Logits/probabilities of shape (N, C, ...) -> index tensor of shape (N, ...).
            preds = preds.argmax(1)
        # Convert index tensors to one-hot with the class dimension last: (N, ..., C).
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
    if not include_background:
        # Zero out the background class so it contributes nothing to either count.
        preds[..., 0] = 0
        target[..., 0] = 0
    # Reduce over all spatial dimensions, keeping the batch and class dimensions.
    reduce_axis = list(range(1, preds.ndim - 1))
    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
    target_sum = torch.sum(target, dim=reduce_axis)
    pred_sum = torch.sum(preds, dim=reduce_axis)
    union = target_sum + pred_sum - intersection
    return intersection, union


class MeanIoU(Metric):
    """Mean IoU that accumulates intersection and union counts across batches."""

    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        per_class: bool = False,
        input_format: Literal["one-hot", "index", "predictions"] = "index",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.include_background = include_background
        self.per_class = per_class
        self.input_format = input_format
        # Running per-class counts, summed across processes in distributed settings.
        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        intersection, union = _compute_intersection_and_union(
            preds, target, self.num_classes, self.include_background, self.input_format
        )
        # Sum over the batch dimension before accumulating into the states.
        self.intersection += intersection.sum(0)
        self.union += union.sum(0)

    def compute(self) -> Tensor:
        # Classes that never appear in either preds or target get a NaN score.
        iou_valid = torch.gt(self.union, 0)
        iou = torch.where(
            iou_valid,
            torch.divide(self.intersection, self.union),
            torch.nan,
        )
        if self.per_class:
            return iou
        return torch.mean(iou[iou_valid])
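A quick usage sketch of the class above with "index" inputs (shapes and values are just illustrative):

preds = torch.randint(0, 3, (4, 32, 32))   # (batch, H, W) predicted class indices
target = torch.randint(0, 3, (4, 32, 32))  # (batch, H, W) ground-truth class indices

miou = MeanIoU(num_classes=3, include_background=True, per_class=False, input_format="index")
miou.update(preds, target)
print(miou.compute())  # scalar mean IoU over the classes present in the union

miou_per_class = MeanIoU(num_classes=3, per_class=True, input_format="index")
miou_per_class.update(preds, target)
print(miou_per_class.compute())  # tensor of shape (num_classes,), NaN for absent classes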
Hi! Thanks for your contribution, great first issue!
Hi, I have noticed that MeanIoU is incorrect when it is updated via the `update` method, but if I update it via the `forward` method it works correctly. I have looked into it, and it is because, with the default MeanIoU parameters, the `forward` method calls the `_reduce_states` method, which updates `score` using the following formula:

reduced = ((self._update_count - 1) * global_state + local_state).float() / self._update_count

where `global_state` is the score accumulated over the previous steps and `local_state` is the score on the current batch. The same behavior is observed for all input formats, with both `per_class=True` and `per_class=False`.
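To make that reduction concrete, here is a small numeric sketch of the quoted formula (the per-batch scores are made up):

import torch

global_state = torch.tensor(0.0)  # accumulated score before any update
for update_count, local_state in enumerate([torch.tensor(1.0), torch.tensor(0.5)], start=1):
    # the running-mean reduction quoted above
    global_state = ((update_count - 1) * global_state + local_state).float() / update_count

print(global_state)  # tensor(0.7500), i.e. the mean of the per-batch scores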
Here is the code to reproduce the results:
import torch
from torchmetrics.segmentation import MeanIoU


def run():
    bs = 16
    num_classes = 3
    h = w = 128

    # one-hot, per_class=False
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)

    # test 1, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 2, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 3, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 4, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # one-hot, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)

    # test 5, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 6, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 7, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 8, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # index, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)

    # test 9, all ones
    img_pred = torch.ones(bs, h, w, dtype=torch.long)
    img_target = torch.ones(bs, h, w, dtype=torch.long)
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 10, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)
    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 50:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 11, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)
    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 65:80, 50:80] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 12, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)
    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 80:100, 80:100] = 1
    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")


if __name__ == "__main__":
    run()
Hi, I'm new to open-source contributions and I want to start by working on this issue. @Borda, could you please assign it to me?