
Comments (8)

github-actions commented on July 22, 2024

Hi! Thanks for your contribution, and great first issue!

from torchmetrics.

SkafteNicki commented on July 22, 2024

Hi @TuanDTr, thanks for reporting this issue.
I just tried different input tensors based on your description and I cannot reproduce the behavior either. I am pretty sure that the metric should not be able to output values larger than 1.
Would it be possible to share how you initialize the metric? Alternatively, what your training step, where you log the metric, looks like?


TuanDTr commented on July 22, 2024

@SkafteNicki Thank you for your quick response. Please find below the methods for the forward, training, and validation steps, as well as where I initialize the metrics. Basically, the metric is initialized in on_validation_model_eval, updated per step, and aggregated and reset in on_validation_epoch_end. I have used this setup for a while without any problems, but noticed this when evaluating 3D tensors (since I moved from 2D to 3D diffusion models).

    def training_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> torch.Tensor:
        if self.use_profiler:
            self.profiler.step()
            
        x = batch["t1c"]
        z = self.get_latent_code(x)

        z_cond = []

        for m in self.hparams.cond_modality:
            x_cond = batch[m]
            z_cond.append(self.get_latent_code(x_cond))

        z_cond = torch.cat(z_cond, dim=1)

        noise = torch.randn_like(z).to(self.device)
        timesteps = torch.randint(0, self.hparams.num_train_steps, (z.shape[0], ), device=self.device).long()
        noisy_z = self._scheduler.add_noise(original_samples=z, noise=noise, timesteps=timesteps)
        noise_pred = self._unet(torch.cat((noisy_z, z_cond), dim=1), timesteps=timesteps)
        loss = self.criterion(noise_pred, noise)

        self.log("train/noise_recons_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    @torch.inference_mode()
    def forward(self, z_cond):
        self._ema.ema_model.eval()
        z_dim = [z_cond.shape[0], self._unet.out_channels, *z_cond.shape[2:]]
        z = torch.randn(z_dim, device=self.device)

        self._scheduler.set_timesteps(num_inference_steps=self.hparams.num_inference_steps)
        for t in range(self.hparams.num_inference_steps):
            model_output = self._ema.ema_model(
                torch.cat((z, z_cond), dim=1),
                timesteps=torch.Tensor((t,)).to(self.device).long()
            )
            z, _ = self._scheduler.step(model_output, t, z)

        x = self.decode_from_latent_code(z)

        return x, z

    def on_validation_model_eval(self) -> None:
        """Prepare before validation."""

        self.metrics = {
            "PSNR": PeakSignalNoiseRatio(data_range=None).to(self.device),
            "SSIM": StructuralSimilarityIndexMeasure(data_range=None).to(self.device),
            "MAE_image": MeanAbsoluteError().to(self.device),
            "MAE_latent": MeanAbsoluteError().to(self.device)
        }
        super().on_validation_model_eval()

    def validation_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> None:
        x = batch["t1c"]
        if not self.hparams.preloaded_latent:
            z = self.get_latent_code(x)
        else:
            z = batch["latent_t1c"]

        z_cond = []

        for m in self.hparams.cond_modality:
            if not self.hparams.preloaded_latent:
                x_cond = batch[m]
                z_cond.append(self.get_latent_code(x_cond))
            else:
                z_cond.append(batch[f"latent_{m}"])

        z_cond = torch.cat(z_cond, dim=1)

        preds, latents = self.forward(z_cond)

        # Compute score
        self.metrics["PSNR"].update(preds, x)
        self.metrics["SSIM"].update(preds, x)
        self.metrics["MAE_image"].update(preds, x)
        self.metrics["MAE_latent"].update(latents, z)

        # Inverse transform
        inverse_transform = BatchInverseTransform(self.val_dataloader().dataset.transforms, self.val_dataloader())
        with allow_missing_keys_mode(self.val_dataloader().dataset.transforms):
            preds = inverse_transform({"latent_t1c": preds})
        
        self.save_to_h5_dataset(preds)

    def on_validation_epoch_end(self) -> None:
        psnr = self.metrics["PSNR"].compute()
        ssim = self.metrics["SSIM"].compute()
        mae_image = self.metrics["MAE_image"].compute()
        mae_latent = self.metrics["MAE_latent"].compute()

        self.log("val/psnr", psnr, on_epoch=True, logger=True, prog_bar=True)
        self.log("val/ssim", ssim, on_epoch=True, logger=True, prog_bar=True)
        self.log("val/mae_image", mae_image, on_epoch=True, logger=True, prog_bar=True)
        self.log("val/mae_latent", mae_latent, on_epoch=True, logger=True, prog_bar=True)

        self.metrics["PSNR"].reset()
        self.metrics["SSIM"].reset()
        self.metrics["MAE_image"].reset()
        self.metrics["MAE_latent"].reset()
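For context, the scheduler call in training_step performs standard DDPM-style forward noising. A minimal NumPy sketch of that step, assuming the generic DDPM formulation (the names betas, alpha_bar, and add_noise, and the linear schedule, are illustrative, not necessarily what this scheduler implements):

```python
import numpy as np

# Generic DDPM forward noising, as done by scheduler.add_noise:
#   noisy_z = sqrt(alpha_bar_t) * z + sqrt(1 - alpha_bar_t) * noise
betas = np.linspace(1e-4, 0.02, 1000)  # linear noise schedule (illustrative)
alpha_bar = np.cumprod(1.0 - betas)    # cumulative product of (1 - beta_t)

def add_noise(z, noise, t):
    # Interpolate between the clean latent and pure Gaussian noise.
    return np.sqrt(alpha_bar[t]) * z + np.sqrt(1.0 - alpha_bar[t]) * noise
```

At t = 0 the output is almost the clean latent; at the final timestep alpha_bar is near zero, so the output is almost pure noise, which is what the denoising loop in forward then reverses.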


SkafteNicki commented on July 22, 2024

Hi @TuanDTr,
I tried to reproduce the error again, without success. The code you sent also looks fine; I see nothing wrong there.
I am going to assume the reason is that your inputs have different scaling. The underlying assumption in SSIM is that both inputs are scaled in the same way.
Otherwise, you would have to send me the full metric state when the error happens:

def on_validation_epoch_end(self) -> None:
    ssim = self.metrics["SSIM"].compute()
    if ssim > 1:
        torch.save(self.metrics["SSIM"].metric_state, "ssim_state.pt")


TuanDTr commented on July 22, 2024

Hi @SkafteNicki, here is the state when the error happens:

{'similarity': metatensor(105.5943, device='cuda:0'),
 'total': tensor(101., device='cuda:0')}
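For reference, with the default reduction SSIM's compute() returns the running mean of these two states, so this state directly yields a value above 1:

```python
# SSIM's compute() with the default reduction returns the running mean:
# accumulated similarity divided by the number of observations.
# Numbers taken from the metric state reported above.
similarity = 105.5943
total = 101.0
ssim = similarity / total
print(ssim)  # ~1.0455, above SSIM's valid upper bound of 1
```

So the question is why the accumulated similarity exceeds the observation count at all, since each per-batch score should be at most 1.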

I'll further scale all input tensors to the same range and see if this still occurs. I will follow up with you.


SkafteNicki commented on July 22, 2024

@TuanDTr I have been trying to further debug the issue on my end and I am still unable to reproduce the problem. From the output you sent, it is very clear to me that 105/101 > 1, but not how the similarity gets to be higher than 101.
Have you tried rescaling the range to see if it helped?


TuanDTr commented on July 22, 2024

Hi @SkafteNicki, I have tried rescaling the range of the inputs to [0, 1] (see below), but I still encounter SSIM > 1. I am setting data_range to None, which I believe will eventually set the data range to 1, right? I will assess the range of the saved outputs; hopefully that can shed light on something else. I am sorry for the late response, as I am working on other things at the moment, but I am still following up on this.

def forward(self, z_cond):
    ...
    return x.clamp(0, 1), z


TuanDTr commented on July 22, 2024

@SkafteNicki Hello, and I am sorry for the late update. I might have an idea why SSIM is larger than 1. I inspected my evaluation script and found that the tensors were in float16. If I change them to float32, I get the correct results. However, I cannot reproduce this issue outside my training script. I think my setup, which uses mixed-precision training, could have something to do with this. Do you have any idea how to inspect this further? Thanks!

