GithubHelp home page GithubHelp logo

weight initialization about pytorch-semseg HOT 6 CLOSED

meetps avatar meetps commented on May 29, 2024
weight initialization

from pytorch-semseg.

Comments (6)

varghesealex90 avatar varghesealex90 commented on May 29, 2024 1

from __future__ import print_function, division
import matplotlib
matplotlib.use("TKAgg")
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import torch.nn.functional as F
import pyt_unet_dataloader
import pyt_unet
import pyt_unet_loss

import nibabel as nib
import torchnet as tnt

import torch.nn.init as init
import numpy as np

# define weight initialization

###########################################################
def init_weights(m):
    """Initialize a 2-D convolution layer in place.

    Meant to be handed to ``nn.Module.apply``, which calls it once for every
    sub-module of a network. Modules that are not ``nn.Conv2d`` are left
    untouched.

    Args:
        m: The module currently being visited.
    """
    if not isinstance(m, nn.Conv2d):
        return

    # Xavier/Glorot uniform scaled by sqrt(2), the same gain that
    # ``init.calculate_gain('relu')`` returns.
    init.xavier_uniform_(m.weight, gain=np.sqrt(2))

    # A layer built with ``bias=False`` has ``m.bias is None``.
    if m.bias is not None:
        init.constant_(m.bias, 0.0)

##########################################################

# Interactive mode: figures are drawn without blocking the script.
plt.ion()

use_gpu = torch.cuda.is_available()

# Both phases only convert the image to a tensor. The augmentation and
# normalisation steps listed in the comments are disabled.
data_transforms = dict(
    train=transforms.Compose([
        # transforms.RandomSizedCrop(240),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    val=transforms.Compose([
        # transforms.Scale(256),
        # transforms.CenterCrop(240),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
)

nclasses = 4
confusion_meter = tnt.meter.ConfusionMeter(nclasses, normalized=True)

data_dir1 = '/media/brats/Varghese/pyt/data/slices'

# One dataset, one loader and one size entry per phase sub-directory.
image_datasets = {
    x: pyt_unet_dataloader.Semantic_ImageFolder(
        os.path.join(data_dir1, x),
        data_transforms[x],
    )
    for x in ['train', 'val']
}
dataloders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x],
        batch_size=10,
        shuffle=True,
        num_workers=3,
    )
    for x in ['train', 'val']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Pull a single training batch.
inputs, targets = next(iter(dataloders['train']))

def showcase(data, gt, number):
    """Display two input channels and the ground truth of one sample.

    Channel 0 is drawn on the current figure; channel 3 and the ground truth
    each get a new figure.

    Args:
        data: Batch of images, indexed as ``data[sample, channel, :, :]``.
        gt: Batch of label maps, indexed as ``gt[sample, :, :]``.
        number: Index of the sample within the batch.
    """
    plt.imshow(data[number, 0, :, :], cmap='gray')
    plt.figure()
    plt.imshow(data[number, 3, :, :], cmap='gray')
    plt.figure()
    plt.imshow(gt[number, :, :], cmap='gray')
    plt.show()

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a 'train' pass followed by a 'val' pass over the
    module-level ``dataloders``. For each phase the mean loss, the pixel
    accuracy and the matrix held by the module-level ``confusion_meter`` are
    printed.

    Args:
        model: Network whose output holds per-class scores along dim 1.
        criterion: Loss called as ``criterion(log_softmax(outputs), labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.

    Returns:
        ``model``, loaded with the weights from the epoch that reached the
        highest validation accuracy.
    """
    import copy  # local import keeps this block self-contained

    since = time.time()

    # state_dict() returns references to the live parameter tensors, so a
    # deep copy is required to really freeze a snapshot of the weights.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0
            running_pixels = 0
            running_samples = 0

            confusion_meter.reset()
            for inputs, labels in dataloders[phase]:
                labels = labels.long()
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                optimizer.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(F.log_softmax(outputs, dim=1), labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                _, preds = torch.max(outputs.detach(), 1)
                confusion_meter.add(preds.view(-1), labels.view(-1))

                # The criterion returns a batch mean, so weight it by the
                # batch size before averaging over the whole epoch.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_samples += batch_size
                running_corrects += (preds == labels).sum().item()
                # Count label elements instead of assuming 240x240 images.
                running_pixels += labels.numel()

            if is_train:
                # Step the schedule after the optimizer updates of the epoch.
                scheduler.step()

            epoch_loss = running_loss / max(running_samples, 1)
            epoch_acc = float(running_corrects) / max(running_pixels, 1)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Keep a snapshot of the best-performing weights.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            print(confusion_meter.conf)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model

# Build the U-Net: 4 input channels, 4 output classes.
model_ft = pyt_unet.unet(feature_scale=2, n_classes=4, is_deconv=True, in_channels=4, is_batchnorm=True)

# Apply the weight initialization to every sub-module.
model_ft.apply(init_weights)

# One loss weight per class.
weights = torch.from_numpy(np.asarray([0.5044884, 1.21449727, 0.84989637, 2.91036189])).float()

# Only move tensors to the GPU when one is available, so the script also
# runs on CPU-only machines.
if use_gpu:
    weights = weights.cuda()
    model_ft = model_ft.cuda()

# nn.NLLLoss accepts (N, C, H, W) inputs; nn.NLLLoss2d is deprecated.
criterion = nn.NLLLoss(weight=weights)

# Adam is the optimizer that is used. Alternatives tried here before:
# optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) and
# optim.RMSprop(model_ft.parameters(), lr=0.001).
learning_rate = 1e-3
optimizer_ft = torch.optim.Adam(model_ft.parameters(), lr=learning_rate)

# Multiply the LR by 0.01 every 100 scheduler steps. NOTE(review): with
# num_epochs=50 and one step per epoch the decay never triggers -- confirm
# that this is intended.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.01)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=50)

# Saves the whole model object, not just its state_dict.
torch.save(model_ft, 'unet_model_feature_scale_2_zscore_complete_network_2.pth')

from pytorch-semseg.

varghesealex90 avatar varghesealex90 commented on May 29, 2024

I think I figured it out:
from torch import nn
import torch.nn.init as init
import numpy as np

def init_weights(m):
    """Initialize a 2-D convolution layer in place.

    Meant to be handed to ``nn.Module.apply``, which calls it once for every
    sub-module of a network. Modules that are not ``nn.Conv2d`` are left
    untouched.

    Args:
        m: The module currently being visited.
    """
    if not isinstance(m, nn.Conv2d):
        return

    # Xavier/Glorot uniform scaled by sqrt(2), the same gain that
    # ``init.calculate_gain('relu')`` returns.
    init.xavier_uniform_(m.weight, gain=np.sqrt(2))

    # A layer built with ``bias=False`` has ``m.bias is None``.
    if m.bias is not None:
        init.constant_(m.bias, 0.0)

model= unet()  # build the U-Net; `unet` is presumably the pytorch-semseg model class -- confirm the import
model.apply(init_weights)  # run init_weights on every sub-module of the model

from pytorch-semseg.

4F2E4A2E avatar 4F2E4A2E commented on May 29, 2024

@varghesealex90 could you share the code? Including the class name and using markup formatting would be nice and helpful for others.

from pytorch-semseg.

varghesealex90 avatar varghesealex90 commented on May 29, 2024

Hope this helps

from pytorch-semseg.

4F2E4A2E avatar 4F2E4A2E commented on May 29, 2024

This is amazing, thank you!

from pytorch-semseg.

4F2E4A2E avatar 4F2E4A2E commented on May 29, 2024

@meetshah1995 could we get some permissions to post stuff like that beautifully formatted in a wiki page or something?

from pytorch-semseg.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.