Comments (8)
Hi! Thanks for your contribution, great first issue!
Hi @TuanDTr, thanks for reporting this issue.
I just tried different input tensors based on your description and I cannot reproduce the behavior either. I am pretty sure that the metric should not be able to output values larger than 1.
Would it be possible to share how you initialize the metric? Alternatively, could you show what the training step where you log the metric looks like?
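For reference, this is the kind of minimal check I ran without being able to trigger the bug (a sketch; the 3D shapes are only illustrative):

```python
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

# Random 3D volumes in [0, 1]; shapes are made up for the example
preds = torch.rand(2, 1, 32, 32, 32)
target = torch.rand(2, 1, 32, 32, 32)

ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
print(ssim(preds, target))  # never exceeds 1 in this setting
```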
@SkafteNicki Thank you for your quick response. Please find below the `forward`, training step, and validation step methods, including where I initialize the metrics. Basically, the metric is initialized in `on_validation_model_eval`, updated per step, and aggregated and reset in `on_validation_epoch_end`. I have used this setup for a while without any problem, but noticed this when evaluating 3D tensors (since I moved from 2D to 3D diffusion models).
```python
def training_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> torch.Tensor:
    if self.use_profiler:
        self.profiler.step()
    x = batch["t1c"]
    z = self.get_latent_code(x)
    z_cond = []
    for m in self.hparams.cond_modality:
        x_cond = batch[m]
        z_cond.append(self.get_latent_code(x_cond))
    z_cond = torch.cat(z_cond, dim=1)
    noise = torch.randn_like(z).to(self.device)
    timesteps = torch.randint(0, self.hparams.num_train_steps, (z.shape[0],), device=self.device).long()
    noisy_z = self._scheduler.add_noise(original_samples=z, noise=noise, timesteps=timesteps)
    noise_pred = self._unet(torch.cat((noisy_z, z_cond), dim=1), timesteps=timesteps)
    loss = self.criterion(noise_pred, noise)
    self.log("train/noise_recons_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

@torch.inference_mode()
def forward(self, z_cond):
    self._ema.ema_model.eval()
    z_dim = [z_cond.shape[0], self._unet.out_channels, *z_cond.shape[2:]]
    z = torch.randn(z_dim, device=self.device)
    self._scheduler.set_timesteps(num_inference_steps=self.hparams.num_inference_steps)
    for t in range(self.hparams.num_inference_steps):
        model_output = self._ema.ema_model(
            torch.cat((z, z_cond), dim=1),
            timesteps=torch.Tensor((t,)).to(self.device).long()
        )
        z, _ = self._scheduler.step(model_output, t, z)
    x = self.decode_from_latent_code(z)
    return x, z

def on_validation_model_eval(self) -> None:
    """Prepare before validation."""
    self.metrics = {
        "PSNR": PeakSignalNoiseRatio(data_range=None).to(self.device),
        "SSIM": StructuralSimilarityIndexMeasure(data_range=None).to(self.device),
        "MAE_image": MeanAbsoluteError().to(self.device),
        "MAE_latent": MeanAbsoluteError().to(self.device),
    }
    super().on_validation_model_eval()

def validation_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> None:
    x = batch["t1c"]
    if not self.hparams.preloaded_latent:
        z = self.get_latent_code(x)
    else:
        z = batch["latent_t1c"]
    z_cond = []
    for m in self.hparams.cond_modality:
        if not self.hparams.preloaded_latent:
            x_cond = batch[m]
            z_cond.append(self.get_latent_code(x_cond))
        else:
            z_cond.append(batch[f"latent_{m}"])
    z_cond = torch.cat(z_cond, dim=1)
    preds, latents = self.forward(z_cond)
    # Compute scores
    self.metrics["PSNR"].update(preds, x)
    self.metrics["SSIM"].update(preds, x)
    self.metrics["MAE_image"].update(preds, x)
    self.metrics["MAE_latent"].update(latents, z)
    # Inverse transform
    inverse_transform = BatchInverseTransform(self.val_dataloader().dataset.transforms, self.val_dataloader())
    with allow_missing_keys_mode(self.val_dataloader().dataset.transforms):
        preds = inverse_transform({"latent_t1c": preds})
    self.save_to_h5_dataset(preds)

def on_validation_epoch_end(self) -> None:
    psnr = self.metrics["PSNR"].compute()
    ssim = self.metrics["SSIM"].compute()
    mae_image = self.metrics["MAE_image"].compute()
    mae_latent = self.metrics["MAE_latent"].compute()
    self.log("val/psnr", psnr, on_epoch=True, logger=True, prog_bar=True)
    self.log("val/ssim", ssim, on_epoch=True, logger=True, prog_bar=True)
    self.log("val/mae_image", mae_image, on_epoch=True, logger=True, prog_bar=True)
    self.log("val/mae_latent", mae_latent, on_epoch=True, logger=True, prog_bar=True)
    self.metrics["PSNR"].reset()
    self.metrics["SSIM"].reset()
    self.metrics["MAE_image"].reset()
    self.metrics["MAE_latent"].reset()
```
Hi @TuanDTr,
I tried to reproduce the error again, without success. The code you sent also looks fine; I don't see anything wrong there.
I am going to assume the reason is that your inputs have different scaling. The underlying assumption in SSIM is that both inputs are scaled the same way.
Otherwise, you would have to send me the full metric state when the error happens:
```python
def on_validation_epoch_end(self) -> None:
    ssim = self.metrics["SSIM"].compute()
    if ssim > 1:
        torch.save(self.metrics["SSIM"].metric_state, "ssim_state.pt")
```
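In the meantime, a quick sanity check you can run (a sketch; `preds` and `x` as in your validation_step) is to compare the ranges of the two inputs, since with `data_range=None` the metric, as far as I recall, derives the range from the data itself:

```python
# Sketch: confirm predictions and targets share the same scale;
# with data_range=None, SSIM derives the range from the tensors it sees
print("preds range:", preds.min().item(), preds.max().item())
print("target range:", x.min().item(), x.max().item())
```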
Hi @SkafteNicki, here is the state when the error happens:
```python
{'similarity': metatensor(105.5943, device='cuda:0'),
 'total': tensor(101., device='cuda:0')}
```
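Loading the saved state back and dividing reproduces the bad value (a quick check, assuming compute() is similarity / total):

```python
state = torch.load("ssim_state.pt")
print(state["similarity"] / state["total"])  # ~1.0455, i.e. SSIM > 1
```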
I'll further scale all input tensors to the same range and see if this still occurs. I will follow up with you.
@TuanDTr I have been trying to further debug the issue on my end and I am still unable to reproduce the problem. From the output you sent it is very clear to me that 105/101 > 1, but not how the accumulated similarity gets to be higher than 101.
Have you tried rescaling the range to see if it helps?
Hi @SkafteNicki, I have tried rescaling the inputs to [0, 1] (see below), but I still encounter SSIM > 1. I am setting `data_range` to None, which I believe will eventually set it to 1, right? I will assess the range of the saved outputs; hopefully that can shed light on something else. I am sorry for the late response, as I am working on other things at the moment, but I am still following up on this.
```python
def forward(self, z_cond):
    ...
    return x.clamp(0, 1), z
```
@SkafteNicki Hello, and I am sorry for the late update. I might have an idea why SSIM is larger than 1. I inspected my evaluation script and found that the tensors were in float16. If I change them to float32, I get the correct results. However, I cannot reproduce this issue outside my training script. I think my setup, which uses mixed-precision training, could have something to do with this. Do you have any idea how to inspect this further? Thanks!
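For now I am considering casting the inputs to float32 before each metric update as a workaround (a sketch of the change in validation_step):

```python
# Sketch: cast out of float16 before the metrics see the tensors,
# so SSIM's internal convolutions and reductions run in float32
self.metrics["PSNR"].update(preds.float(), x.float())
self.metrics["SSIM"].update(preds.float(), x.float())
```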