
usuyama / pytorch-unet

822 stars · 10 watchers · 226 forks · 365 KB

Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation

Home Page: https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb

License: MIT License

Languages: Jupyter Notebook 97.77%, Python 2.23%
Topics: image-segmentation, unet, fully-convolutional-networks, semantic-segmentation

pytorch-unet's Introduction

UNet/FCN PyTorch Open In Colab

This repository contains simple PyTorch implementations of U-Net and FCN, which are deep learning segmentation methods proposed by Ronneberger et al. and Long et al.

Synthetic images/masks for training

First clone the repository and cd into the project directory.

import matplotlib.pyplot as plt
import numpy as np
import helper
import simulation

# Generate some random images
input_images, target_masks = simulation.generate_random_data(192, 192, count=3)

for x in [input_images, target_masks]:
    print(x.shape)
    print(x.min(), x.max())

# Convert the input images to uint8 for display with matplotlib
input_images_rgb = [x.astype(np.uint8) for x in input_images]

# Map each channel (i.e. class) of the mask to a distinct color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]

# Left: Input image (black and white), Right: Target mask (6ch)
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])

Left: Input image (black and white), Right: Target mask (6ch)


Prepare Dataset and DataLoader

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class SimDataset(Dataset):
    def __init__(self, count, transform=None):
        self.input_images, self.target_masks = simulation.generate_random_data(192, 192, count=count)
        self.transform = transform

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        image = self.input_images[idx]
        mask = self.target_masks[idx]
        if self.transform:
            image = self.transform(image)

        return [image, mask]

# use the same transformations for train/val in this example
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
])

train_set = SimDataset(2000, transform=trans)
val_set = SimDataset(200, transform=trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

batch_size = 25

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

Check the outputs from DataLoader

import torchvision.utils

def reverse_transform(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    inp = (inp * 255).astype(np.uint8)

    return inp

# Get a batch of training data
inputs, masks = next(iter(dataloaders['train']))

print(inputs.shape, masks.shape)

plt.imshow(reverse_transform(inputs[3]))
torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])


Create the UNet module

import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

Model summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNetUNet(n_class=6)
model = model.to(device)

# check keras-like model summary using torchsummary
from torchsummary import summary
summary(model, input_size=(3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 112, 112]           9,408
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
         MaxPool2d-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]           4,096
      BatchNorm2d-10           [-1, 64, 56, 56]             128
             ReLU-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]          16,384
      BatchNorm2d-16          [-1, 256, 56, 56]             512
           Conv2d-17          [-1, 256, 56, 56]          16,384
      BatchNorm2d-18          [-1, 256, 56, 56]             512
             ReLU-19          [-1, 256, 56, 56]               0
       Bottleneck-20          [-1, 256, 56, 56]               0
           Conv2d-21           [-1, 64, 56, 56]          16,384
      BatchNorm2d-22           [-1, 64, 56, 56]             128
             ReLU-23           [-1, 64, 56, 56]               0
           Conv2d-24           [-1, 64, 56, 56]          36,864
      BatchNorm2d-25           [-1, 64, 56, 56]             128
             ReLU-26           [-1, 64, 56, 56]               0
           Conv2d-27          [-1, 256, 56, 56]          16,384
      BatchNorm2d-28          [-1, 256, 56, 56]             512
             ReLU-29          [-1, 256, 56, 56]               0
       Bottleneck-30          [-1, 256, 56, 56]               0
           Conv2d-31           [-1, 64, 56, 56]          16,384
      BatchNorm2d-32           [-1, 64, 56, 56]             128
             ReLU-33           [-1, 64, 56, 56]               0
           Conv2d-34           [-1, 64, 56, 56]          36,864
      BatchNorm2d-35           [-1, 64, 56, 56]             128
             ReLU-36           [-1, 64, 56, 56]               0
           Conv2d-37          [-1, 256, 56, 56]          16,384
      BatchNorm2d-38          [-1, 256, 56, 56]             512
             ReLU-39          [-1, 256, 56, 56]               0
       Bottleneck-40          [-1, 256, 56, 56]               0
           Conv2d-41          [-1, 128, 56, 56]          32,768
      BatchNorm2d-42          [-1, 128, 56, 56]             256
             ReLU-43          [-1, 128, 56, 56]               0
           Conv2d-44          [-1, 128, 28, 28]         147,456
      BatchNorm2d-45          [-1, 128, 28, 28]             256
             ReLU-46          [-1, 128, 28, 28]               0
           Conv2d-47          [-1, 512, 28, 28]          65,536
      BatchNorm2d-48          [-1, 512, 28, 28]           1,024
           Conv2d-49          [-1, 512, 28, 28]         131,072
      BatchNorm2d-50          [-1, 512, 28, 28]           1,024
             ReLU-51          [-1, 512, 28, 28]               0
       Bottleneck-52          [-1, 512, 28, 28]               0
           Conv2d-53          [-1, 128, 28, 28]          65,536
      BatchNorm2d-54          [-1, 128, 28, 28]             256
             ReLU-55          [-1, 128, 28, 28]               0
           Conv2d-56          [-1, 128, 28, 28]         147,456
      BatchNorm2d-57          [-1, 128, 28, 28]             256
             ReLU-58          [-1, 128, 28, 28]               0
           Conv2d-59          [-1, 512, 28, 28]          65,536
      BatchNorm2d-60          [-1, 512, 28, 28]           1,024
             ReLU-61          [-1, 512, 28, 28]               0
       Bottleneck-62          [-1, 512, 28, 28]               0
           Conv2d-63          [-1, 128, 28, 28]          65,536
      BatchNorm2d-64          [-1, 128, 28, 28]             256
             ReLU-65          [-1, 128, 28, 28]               0
           Conv2d-66          [-1, 128, 28, 28]         147,456
      BatchNorm2d-67          [-1, 128, 28, 28]             256
             ReLU-68          [-1, 128, 28, 28]               0
           Conv2d-69          [-1, 512, 28, 28]          65,536
      BatchNorm2d-70          [-1, 512, 28, 28]           1,024
             ReLU-71          [-1, 512, 28, 28]               0
       Bottleneck-72          [-1, 512, 28, 28]               0
           Conv2d-73          [-1, 128, 28, 28]          65,536
      BatchNorm2d-74          [-1, 128, 28, 28]             256
             ReLU-75          [-1, 128, 28, 28]               0
           Conv2d-76          [-1, 128, 28, 28]         147,456
      BatchNorm2d-77          [-1, 128, 28, 28]             256
             ReLU-78          [-1, 128, 28, 28]               0
           Conv2d-79          [-1, 512, 28, 28]          65,536
      BatchNorm2d-80          [-1, 512, 28, 28]           1,024
             ReLU-81          [-1, 512, 28, 28]               0
       Bottleneck-82          [-1, 512, 28, 28]               0
           Conv2d-83          [-1, 256, 28, 28]         131,072
      BatchNorm2d-84          [-1, 256, 28, 28]             512
             ReLU-85          [-1, 256, 28, 28]               0
           Conv2d-86          [-1, 256, 14, 14]         589,824
      BatchNorm2d-87          [-1, 256, 14, 14]             512
             ReLU-88          [-1, 256, 14, 14]               0
           Conv2d-89         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-90         [-1, 1024, 14, 14]           2,048
           Conv2d-91         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-92         [-1, 1024, 14, 14]           2,048
             ReLU-93         [-1, 1024, 14, 14]               0
       Bottleneck-94         [-1, 1024, 14, 14]               0
           Conv2d-95          [-1, 256, 14, 14]         262,144
      BatchNorm2d-96          [-1, 256, 14, 14]             512
             ReLU-97          [-1, 256, 14, 14]               0
           Conv2d-98          [-1, 256, 14, 14]         589,824
      BatchNorm2d-99          [-1, 256, 14, 14]             512
            ReLU-100          [-1, 256, 14, 14]               0
          Conv2d-101         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-102         [-1, 1024, 14, 14]           2,048
            ReLU-103         [-1, 1024, 14, 14]               0
      Bottleneck-104         [-1, 1024, 14, 14]               0
          Conv2d-105          [-1, 256, 14, 14]         262,144
     BatchNorm2d-106          [-1, 256, 14, 14]             512
            ReLU-107          [-1, 256, 14, 14]               0
          Conv2d-108          [-1, 256, 14, 14]         589,824
     BatchNorm2d-109          [-1, 256, 14, 14]             512
            ReLU-110          [-1, 256, 14, 14]               0
          Conv2d-111         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-112         [-1, 1024, 14, 14]           2,048
            ReLU-113         [-1, 1024, 14, 14]               0
      Bottleneck-114         [-1, 1024, 14, 14]               0
          Conv2d-115          [-1, 256, 14, 14]         262,144
     BatchNorm2d-116          [-1, 256, 14, 14]             512
            ReLU-117          [-1, 256, 14, 14]               0
          Conv2d-118          [-1, 256, 14, 14]         589,824
     BatchNorm2d-119          [-1, 256, 14, 14]             512
            ReLU-120          [-1, 256, 14, 14]               0
          Conv2d-121         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-122         [-1, 1024, 14, 14]           2,048
            ReLU-123         [-1, 1024, 14, 14]               0
      Bottleneck-124         [-1, 1024, 14, 14]               0
          Conv2d-125          [-1, 256, 14, 14]         262,144
     BatchNorm2d-126          [-1, 256, 14, 14]             512
            ReLU-127          [-1, 256, 14, 14]               0
          Conv2d-128          [-1, 256, 14, 14]         589,824
     BatchNorm2d-129          [-1, 256, 14, 14]             512
            ReLU-130          [-1, 256, 14, 14]               0
          Conv2d-131         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-132         [-1, 1024, 14, 14]           2,048
            ReLU-133         [-1, 1024, 14, 14]               0
      Bottleneck-134         [-1, 1024, 14, 14]               0
          Conv2d-135          [-1, 256, 14, 14]         262,144
     BatchNorm2d-136          [-1, 256, 14, 14]             512
            ReLU-137          [-1, 256, 14, 14]               0
          Conv2d-138          [-1, 256, 14, 14]         589,824
     BatchNorm2d-139          [-1, 256, 14, 14]             512
            ReLU-140          [-1, 256, 14, 14]               0
          Conv2d-141         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-142         [-1, 1024, 14, 14]           2,048
            ReLU-143         [-1, 1024, 14, 14]               0
      Bottleneck-144         [-1, 1024, 14, 14]               0
          Conv2d-145          [-1, 512, 14, 14]         524,288
     BatchNorm2d-146          [-1, 512, 14, 14]           1,024
            ReLU-147          [-1, 512, 14, 14]               0
          Conv2d-148            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-149            [-1, 512, 7, 7]           1,024
            ReLU-150            [-1, 512, 7, 7]               0
          Conv2d-151           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-152           [-1, 2048, 7, 7]           4,096
          Conv2d-153           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-154           [-1, 2048, 7, 7]           4,096
            ReLU-155           [-1, 2048, 7, 7]               0
      Bottleneck-156           [-1, 2048, 7, 7]               0
          Conv2d-157            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-158            [-1, 512, 7, 7]           1,024
            ReLU-159            [-1, 512, 7, 7]               0
          Conv2d-160            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-161            [-1, 512, 7, 7]           1,024
            ReLU-162            [-1, 512, 7, 7]               0
          Conv2d-163           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-164           [-1, 2048, 7, 7]           4,096
            ReLU-165           [-1, 2048, 7, 7]               0
      Bottleneck-166           [-1, 2048, 7, 7]               0
          Conv2d-167            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-168            [-1, 512, 7, 7]           1,024
            ReLU-169            [-1, 512, 7, 7]               0
          Conv2d-170            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-171            [-1, 512, 7, 7]           1,024
            ReLU-172            [-1, 512, 7, 7]               0
          Conv2d-173           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-174           [-1, 2048, 7, 7]           4,096
            ReLU-175           [-1, 2048, 7, 7]               0
      Bottleneck-176           [-1, 2048, 7, 7]               0
          Conv2d-177           [-1, 1024, 7, 7]       2,098,176
            ReLU-178           [-1, 1024, 7, 7]               0
        Upsample-179         [-1, 1024, 14, 14]               0
          Conv2d-180          [-1, 512, 14, 14]         524,800
            ReLU-181          [-1, 512, 14, 14]               0
          Conv2d-182          [-1, 512, 14, 14]       7,078,400
            ReLU-183          [-1, 512, 14, 14]               0
        Upsample-184          [-1, 512, 28, 28]               0
          Conv2d-185          [-1, 512, 28, 28]         262,656
            ReLU-186          [-1, 512, 28, 28]               0
          Conv2d-187          [-1, 512, 28, 28]       4,719,104
            ReLU-188          [-1, 512, 28, 28]               0
        Upsample-189          [-1, 512, 56, 56]               0
          Conv2d-190          [-1, 256, 56, 56]          65,792
            ReLU-191          [-1, 256, 56, 56]               0
          Conv2d-192          [-1, 256, 56, 56]       1,769,728
            ReLU-193          [-1, 256, 56, 56]               0
        Upsample-194        [-1, 256, 112, 112]               0
          Conv2d-195         [-1, 64, 112, 112]           4,160
            ReLU-196         [-1, 64, 112, 112]               0
          Conv2d-197        [-1, 128, 112, 112]         368,768
            ReLU-198        [-1, 128, 112, 112]               0
        Upsample-199        [-1, 128, 224, 224]               0
          Conv2d-200         [-1, 64, 224, 224]         110,656
            ReLU-201         [-1, 64, 224, 224]               0
          Conv2d-202          [-1, 6, 224, 224]             390
================================================================
Total params: 40,549,382
Trainable params: 40,549,382
Non-trainable params: 0
----------------------------------------------------------------

Define the main training loop

from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss

def print_metrics(metrics, epoch_samples, phase):
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))

    print("{}: {}".format(phase, ", ".join(outputs)))

def train_model(model, optimizer, scheduler, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
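
The dice_loss used in calc_loss above comes from loss.py, which is not reproduced in this README. A minimal soft Dice loss consistent with how it is called (per-channel, on sigmoid probabilities) is sketched below; the smooth term and the mean reduction are assumptions, so the repository's implementation may differ in detail.

import torch

def dice_loss(pred, target, smooth=1.0):
    # pred and target have shape (N, C, H, W); pred is already sigmoid-activated
    pred = pred.contiguous()
    target = target.contiguous()

    # Per-image, per-channel overlap and total mass
    intersection = (pred * target).sum(dim=(2, 3))
    denominator = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))

    # Soft Dice loss: 1 - 2|A ∩ B| / (|A| + |B|), smoothed to avoid division by zero
    loss = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
    return loss.mean()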

Training

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 6
model = ResNetUNet(num_class).to(device)

# freeze backbone layers
#for l in model.base_layers:
#    for param in l.parameters():
#        param.requires_grad = False

optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=60)
cuda:0
Epoch 0/59
----------
LR 0.0001
train: bce: 0.070256, dice: 0.856320, loss: 0.463288
val: bce: 0.014897, dice: 0.515814, loss: 0.265356
saving best model
0m 51s
Epoch 1/59
----------
LR 0.0001
train: bce: 0.011369, dice: 0.309445, loss: 0.160407
val: bce: 0.003790, dice: 0.113682, loss: 0.058736
saving best model
0m 51s
Epoch 2/59
----------
LR 0.0001
train: bce: 0.003480, dice: 0.089928, loss: 0.046704
val: bce: 0.002525, dice: 0.067604, loss: 0.035064
saving best model
0m 51s

(Omitted)

Epoch 57/59
----------
LR 1e-05
train: bce: 0.000523, dice: 0.010289, loss: 0.005406
val: bce: 0.001558, dice: 0.030965, loss: 0.016261
0m 51s
Epoch 58/59
----------
LR 1e-05
train: bce: 0.000518, dice: 0.010209, loss: 0.005364
val: bce: 0.001548, dice: 0.031034, loss: 0.016291
0m 51s
Epoch 59/59
----------
LR 1e-05
train: bce: 0.000518, dice: 0.010168, loss: 0.005343
val: bce: 0.001566, dice: 0.030785, loss: 0.016176
0m 50s
Best val loss: 0.016171

Use the trained model

import math

model.eval()   # Set model to the evaluation mode

# Create another simulation dataset for test
test_dataset = SimDataset(3, transform=trans)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)

# Get the first batch
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

# Predict
pred = model(inputs)
# The loss functions include the sigmoid function.
pred = torch.sigmoid(pred)
pred = pred.data.cpu().numpy()
print(pred.shape)

# Change channel-order and make 3 channels for matplot
input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]
pred_rgb = [helper.masks_to_colorimg(x) for x in pred]

helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
(3, 6, 192, 192)

Left: Input image, Middle: Correct mask (ground truth), Right: Predicted mask


pytorch-unet's People

Contributors

karanchahal, usuyama

pytorch-unet's Issues

Image dimensions do not match

I've just downloaded your jupyter notebook and got an error that the dimensions do not match. I'm talking about the very first cell in pytorch_fcn.ipynb.

Printed output: "(3,192,192,3) (3,6,192,192)"

Thanks in advance for your help. 👍

Dimensions not working

Dimensions don't match up.

/usr/local/lib/python2.7/dist-packages/torch/nn/functional.py:1474: UserWarning: Using a target size (torch.Size([1966080])) that is different to the input size (torch.Size([3010560])) is deprecated. Please ensure they have the same size.
"Please ensure they have the same size.".format(target.size(), input.size()))

I am using BCELoss and I have tried with 192×192 dimensions as well, and it is not working.

Grayscale Images

Hi, how do I change the model to take grayscale images as input instead of 3-channel images?
Thanks
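
Not an answer from the repository itself, but a common workaround is to repeat the single grayscale channel three times so that the pretrained 3-channel ResNet backbone and conv_original_size0 can be used unchanged; the tensor names and sizes below are illustrative assumptions.

import torch

# Hypothetical grayscale batch; repeating the channel avoids changing the model
gray = torch.randn(4, 1, 192, 192)      # (N, 1, H, W)
rgb_like = gray.repeat(1, 3, 1, 1)      # (N, 3, H, W)

model = ResNetUNet(n_class=6)           # class defined earlier in this README
out = model(rgb_like)
print(out.shape)                        # torch.Size([4, 6, 192, 192])

Alternatively, the first convolution of the backbone and conv_original_size0 could be replaced with 1-channel versions, at the cost of losing their pretrained weights.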

Use ConvTranspose2d instead

Here in the U-Net code you have used an upsampling layer. Instead, you should be using ConvTranspose2d.

Upsampling is just a simple scaling-up of the feature map using nearest-neighbour or bilinear interpolation, so nothing is learned; the advantage is that it is cheap.

ConvTranspose2d is a convolution whose kernel is learned during training (just like a normal Conv2d). It also upsamples its input, but the key difference is that the model learns the best upsampling for the job.
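
For reference, the two options discussed above compare roughly as follows in PyTorch; the channel count and kernel size here are illustrative assumptions, not values taken from this repository.

import torch
import torch.nn as nn

x = torch.randn(1, 512, 7, 7)

# Fixed (non-learned) bilinear upsampling, as used in ResNetUNet above
up_fixed = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
print(up_fixed(x).shape)       # torch.Size([1, 512, 14, 14])

# Learned upsampling: a stride-2 transposed convolution also doubles H and W,
# but its kernel weights are trained along with the rest of the model
up_learned = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
print(up_learned(x).shape)     # torch.Size([1, 512, 14, 14])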

RuntimeError with ResNetUNet

Hello!

When I get to the line

model = ResNetUNet(6)

I receive the following error message:

Traceback (most recent call last):
  File "", line 1, in
  File "", line 4, in __init__
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py", line 237, in resnet18
    **kwargs)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py", line 220, in _resnet
    model = ResNet(block, layers, **kwargs)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torchvision/models/resnet.py", line 142, in __init__
    bias=False)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 332, in __init__
    False, _pair(0), groups, bias, padding_mode)
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 46, in __init__
    self.reset_parameters()
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 49, in reset_parameters
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  File "/home/paulr/Programming/anaconda3/lib/python3.7/site-packages/torch/nn/init.py", line 315, in kaiming_uniform_
    return tensor.uniform_(-bound, bound)
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

This looks like it might be a bug in PyTorch (or maybe just in how I'm using it), but if you have any pointers I'd be grateful.

Thanks,
P

Can dice and bce loss work on a multi-class task?

Thanks for the great implementation code.

I am confused about the loss function. As far as I can see, Dice and BCE are both used for binary-class tasks. Can they work well on a multi-class task? From your code I can see the losses work OK, but what about a bigger dataset?

I tried F.cross_entropy(), but it gives me this: RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [36, 4, 224, 224]. Could you please tell me what's wrong? Thanks.

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)
    target_long = target.type(torch.LongTensor)
    ce = F.cross_entropy(pred, target_long.cuda())

    # pred = F.sigmoid(pred)
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss
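
Not from the repository, but a likely explanation of the F.cross_entropy error above: unlike binary_cross_entropy_with_logits, cross_entropy expects class-index targets of shape (N, H, W) rather than one-hot masks of shape (N, C, H, W). A hedged sketch of the conversion, with dummy shapes for illustration:

import torch
import torch.nn.functional as F

pred = torch.randn(2, 6, 192, 192)           # (N, C, H, W) logits
target = torch.zeros(2, 6, 192, 192)         # one-hot masks as used in this repo
target[:, 0] = 1.0                           # every pixel assigned to class 0

target_idx = target.argmax(dim=1)            # (N, H, W) integer class labels
ce = F.cross_entropy(pred, target_idx)       # spatial targets must be 3D index tensors
print(ce.item())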

Dice value is high at the beginning

Hi @usuyama,

Thanks for sharing your good work. I tried to run the training using the simulation data, and the Dice value is quite high at the beginning (0.98).

Have you ever experienced this? I assume something is wrong; any suggestions? Thanks.

making label

Your implementation is very good. Why don't you use the background as a segmentation category? FCN with one-hot labels often uses the background as a category.

CUDA out of memory

Hello!
Thanks for your excellent work! I'm just learning PyTorch and the U-Net network. When I ran your scripts I got the following error:
OutOfMemoryError: CUDA out of memory. Tried to allocate 226.00 MiB (GPU 0; 4.00 GiB total capacity; 3.24 GiB already allocated; 0 bytes free; 3.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
The error occurred during the training step, and no other processes were using the GPU.
Thanks for your reply.

License?

I hope your license is MIT or GPL3

network output

I'm using this implementation for a 2-class problem (a 1-channel target image with binary (0, 1) values), but my output is not binary (0, 1). Did I miss something? Should I add an activation function at the last layer?
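
As noted in the README above, the sigmoid is applied inside the loss, so the network itself returns raw logits rather than 0/1 values. A hedged sketch of binarizing predictions at inference time, reusing the model and inputs names from the "Use the trained model" section (the 0.5 threshold is an assumption):

import torch

model.eval()
with torch.no_grad():
    logits = model(inputs)                  # (N, n_class, H, W) raw scores
    probs = torch.sigmoid(logits)           # map to [0, 1]
    binary_mask = (probs > 0.5).float()     # hard 0/1 prediction per channel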

RGB Target mask

First of all, thanks for your implementation; it is really good.

Now I want to try a grayscale input image [512, 512] with an RGB target mask [512, 512, 3] (same figures). After SimDataset and DataLoader the shapes are torch.Size([20, 512, 512]) and torch.Size([20, 512, 512, 3]), and training throws: RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[1, 20, 512, 512] to have 3 channels, but got 20 channels instead.

How can I fix this issue?

Resnet Model Validity

Hi,

I can't get your pytorch_resnet18_unet.ipynb working.

At the line layer4 = self.layer4_1x1(layer4), it throws the following error:

RuntimeError: Given groups=1, weight of size [1024, 2048, 1, 1], expected input[2, 512, 7, 7] to have 2048 channels, but got 512 channels instead

I haven't modified any of your code.
