DimitrisMantas commented on September 27, 2024

I think that now that zero_division has been added to JaccardIndex, there's no real need for MeanIoU at all. MeanIoU also doesn't follow the classical API.

The original motivation for MeanIoU was that JaccardIndex used to assign a score of 0 to absent and ignored classes, so you couldn't do class-wise and macro averaging correctly. Now you can just set zero_division to NaN and average to None to get correct per-class scores. From there, a nanmean gives the correct macro average.
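
The recipe above can be sketched without torchmetrics. This is a framework-free illustration, not the JaccardIndex API: it assumes per-class intersection and union counts are already available, and per_class_iou and nanmean are hypothetical helpers (the latter mirrors what torch.nanmean does).

```python
import math


def per_class_iou(intersection: list[int], union: list[int]) -> list[float]:
    """Per-class IoU; classes with an empty union (absent everywhere) get NaN instead of 0."""
    return [i / u if u > 0 else math.nan for i, u in zip(intersection, union)]


def nanmean(values: list[float]) -> float:
    """Mean over the non-NaN entries only."""
    valid = [v for v in values if not math.isnan(v)]
    return sum(valid) / len(valid) if valid else math.nan


# Made-up counts: class 2 never appears in preds or target, so its union is 0.
intersection = [50, 30, 0]
union = [100, 40, 0]

scores = per_class_iou(intersection, union)
print(scores)           # [0.5, 0.75, nan]
print(nanmean(scores))  # 0.625
# Scoring the absent class as 0 instead drags the macro average down (about 0.417):
print(sum(0.0 if math.isnan(s) else s for s in scores) / len(scores))
```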

juliendenize commented on September 27, 2024

Hi, I also noticed during my experiments that MeanIoU was giving wrong results.

I developed something that seems to work on my side and returns the same results as evaluate's mean IoU. However, given @DimitrisMantas's comment, I wonder whether it is worth submitting a PR. I am not familiar enough with the JaccardIndex implementation in torchmetrics.

For reference, here is the undocumented code I developed (it has not been rigorously tested yet). Let me know if a PR is something you would be interested in; I'd gladly contribute to this repo.

from typing import Any, Literal

import torch
from torch import Tensor
from torchmetrics import Metric

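def _compute_intersection_and_union(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index", "predictions"] = "index",
) -> tuple[Tensor, Tensor]:
    """Return per-sample, per-class intersection and union, each of shape (N, num_classes)."""
    if input_format in ["index", "predictions"]:
        if input_format == "predictions":
            # (N, C, ...) scores -> (N, ...) class indices
            preds = preds.argmax(1)
        # (N, ...) class indices -> (N, ..., C) one-hot, class axis last
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
    else:
        # One-hot inputs arrive as (N, C, ...): move the class axis last to match the
        # branch above, and clone so the in-place masking below does not modify the
        # caller's tensors.
        preds = preds.movedim(1, -1).clone()
        target = target.movedim(1, -1).clone()

    if not include_background:
        # Blank out class 0 so its intersection and union are both zero.
        preds[..., 0] = 0
        target[..., 0] = 0

    # Reduce over the spatial axes only, keeping the batch (first) and class (last) axes.
    reduce_axis = list(range(1, preds.ndim - 1))
    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
    target_sum = torch.sum(target, dim=reduce_axis)
    pred_sum = torch.sum(preds, dim=reduce_axis)
    union = target_sum + pred_sum - intersection

    return intersection, union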


class MeanIoU(Metric):
    def __init__(
        self,
        num_classes: int,
        include_background: bool = True,
        per_class: bool = False,
        input_format: Literal["one-hot", "index", "predictions"] = "index",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.include_background = include_background
        self.per_class = per_class
        self.input_format = input_format

        # Running per-class totals over every batch seen so far (summed across processes).
        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        intersection, union = _compute_intersection_and_union(
            preds, target, self.num_classes, self.include_background, self.input_format
        )
        # Sum over the batch axis, so the IoU is computed once from dataset-level totals.
        self.intersection += intersection.sum(0)
        self.union += union.sum(0)

    def compute(self) -> Tensor:
        # Classes with an empty union (absent from both preds and target, or the
        # excluded background) get NaN and are left out of the mean.
        iou_valid = torch.gt(self.union, 0)

        iou = torch.where(
            iou_valid,
            torch.divide(self.intersection, self.union),
            torch.nan,
        )

        if self.per_class:
            return iou
        else:
            return torch.mean(iou[iou_valid])
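
The accumulate-then-divide strategy of update() and compute() above can be mirrored without torch. This is a minimal sketch under simplifying assumptions (each batch is a pair of flat lists of class indices, the background is not excluded, and mean_iou_from_batches is a hypothetical name): because intersection and union are summed over all batches before dividing, the result should not depend on how the data is split into batches.

```python
import math


def mean_iou_from_batches(batches, num_classes, per_class=False):
    """Accumulate per-class intersection/union over all batches, then divide once.

    Each batch is a (preds, target) pair of equal-length lists of class indices.
    """
    intersection = [0] * num_classes
    union = [0] * num_classes
    for preds, target in batches:
        for c in range(num_classes):
            # Positions predicted as / labelled as class c in this batch.
            p = {i for i, v in enumerate(preds) if v == c}
            t = {i for i, v in enumerate(target) if v == c}
            intersection[c] += len(p & t)
            union[c] += len(p | t)
    # Classes that never appear get NaN and are skipped in the mean.
    iou = [i / u if u > 0 else math.nan for i, u in zip(intersection, union)]
    if per_class:
        return iou
    valid = [v for v in iou if not math.isnan(v)]
    return sum(valid) / len(valid) if valid else math.nan


batch_a = ([0, 0, 1, 1], [0, 1, 1, 1])
batch_b = ([1, 1, 0, 0], [1, 1, 0, 0])
combined = (batch_a[0] + batch_b[0], batch_a[1] + batch_b[1])

print(mean_iou_from_batches([batch_a, batch_b], 2, per_class=True))  # [0.75, 0.8]
print(mean_iou_from_batches([batch_a, batch_b], 2) == mean_iou_from_batches([combined], 2))  # True
```

For comparison, averaging the two per-batch mean IoUs of this toy data instead (about 0.583 and 1.0) gives about 0.792 rather than the pooled 0.775.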

github-actions commented on September 27, 2024

Hi! Thanks for your contribution! Great first issue!

vkinakh commented on September 27, 2024

Hi, I have noticed that MeanIoU returns incorrect results when it is updated by calling the update method, whereas updating it via the forward method works correctly. I looked into it, and the difference is that, with the default MeanIoU parameters, the forward method calls the _reduce_states method, which updates the score using the following formula:

reduced = ((self._update_count - 1) * global_state + local_state).float() / self._update_count,

where global_state is the score accumulated over previous steps and local_state is the score on the current batch.
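
The formula above is an incremental running mean: after n updates, the reduced state equals the plain average of the n per-batch scores. A minimal sketch, with a hypothetical function name and made-up (intersection, union) counts for a single class, of how that average can differ from an IoU computed over all the data pooled together:

```python
def running_mean(batch_scores):
    """Fold per-batch scores with reduced = ((n - 1) * global + local) / n."""
    global_state = 0.0
    for n, local_state in enumerate(batch_scores, start=1):
        global_state = ((n - 1) * global_state + local_state) / n
    return global_state


# Hypothetical (intersection, union) pixel counts for one class in three batches.
batches = [(10, 10), (0, 90), (45, 50)]

per_batch_iou = [i / u for i, u in batches]
print(per_batch_iou)                # [1.0, 0.0, 0.9]
print(running_mean(per_batch_iou))  # average of per-batch IoUs, about 0.633
# Pooled IoU from summed counts, 55 / 150, about 0.367:
print(sum(i for i, _ in batches) / sum(u for _, u in batches))
```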

The same behavior is observed for all input formats, with both per_class=True and per_class=False.

Here is the code to reproduce the results:

import torch
from torchmetrics.segmentation import MeanIoU


def run():
    bs = 16
    num_classes = 3
    h = w = 128

    # one-hot, per_class=False
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=False)

    # test 1, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 2, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 3, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 4, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # one-hot, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="one-hot", per_class=True)

    # test 5, all ones
    img_pred = torch.ones(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.ones(bs, num_classes, h, w, dtype=torch.long)

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 6, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 50:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 7, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 65:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 8, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, num_classes, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, num_classes, h, w, dtype=torch.long)

    img_pred[:, :, 50:80, 50:80] = 1
    img_target[:, :, 80:100, 80:100] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # index, per_class=True
    miou_update = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)
    miou_call = MeanIoU(num_classes=num_classes, input_format="index", per_class=True)

    # test 9, all ones
    img_pred = torch.ones(bs, h, w, dtype=torch.long)
    img_target = torch.ones(bs, h, w, dtype=torch.long)

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 10, square in the middle, 100% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)

    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 50:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 11, square in the middle, 50% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)

    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 65:80, 50:80] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")

    # test 12, square in the middle, 0% overlap
    img_pred = torch.zeros(bs, h, w, dtype=torch.long)
    img_target = torch.zeros(bs, h, w, dtype=torch.long)

    img_pred[:, 50:80, 50:80] = 1
    img_target[:, 80:100, 80:100] = 1

    miou_update.update(img_pred, img_target)
    miou_call(img_pred, img_target)
    print(f"Update: {miou_update.compute()}, {miou_update.score}, {miou_update.update_count}")
    print(f"Call: {miou_call.compute()}, {miou_call.score}, {miou_call.update_count}")


if __name__ == "__main__":
    run()
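
To sanity-check the printed numbers, the IoU of the square masks used in the "100%", "50%" and "0% overlap" tests can be worked out by hand. A small torch-free sketch with hypothetical helpers that treats each mask as a set of (row, col) pixels, using the same slice bounds as the script. In the one-hot tests every channel receives the same mask, so each class's per-image IoU should match these values; in the index tests only class 1 follows them, since class 0 covers the rest of the image.

```python
def square(rows: range, cols: range) -> set[tuple[int, int]]:
    """Pixels set by img[rows.start:rows.stop, cols.start:cols.stop] = 1."""
    return {(r, c) for r in rows for c in cols}


def iou(pred: set, target: set) -> float:
    """Intersection over union of two pixel sets (NaN if both are empty)."""
    union = pred | target
    return len(pred & target) / len(union) if union else float("nan")


pred = square(range(50, 80), range(50, 80))        # 30 x 30 = 900 pixels
half = square(range(65, 80), range(50, 80))        # 15 x 30 = 450 pixels, inside pred
disjoint = square(range(80, 100), range(80, 100))  # 20 x 20 = 400 pixels, no overlap

print(iou(pred, pred))      # 1.0 (100% overlap)
print(iou(pred, half))      # 0.5 (450 / 900)
print(iou(pred, disjoint))  # 0.0 (0 / 1300)
```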

rittik9 commented on September 27, 2024

Hi, I'm new to open-source contributions, and I want to start by working on this issue. @Borda, could you please assign it to me?
