Comments (6)
from __future__ import print_function, division
import matplotlib
matplotlib.use("TKAgg")  # backend must be selected before pyplot is imported
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import copy
import time
import os
import torch.nn.functional as F
import pyt_unet_dataloader
import pyt_unet
import pyt_unet_loss
import nibabel as nib
import torchnet as tnt
import torch.nn.init as init
import numpy as np
# define weight initialization
###########################################################
def init_weights(m):
    """Initialize a Conv2d layer in place; leave every other module untouched.

    Meant to be passed to ``model.apply(init_weights)``, which calls it on
    every submodule. Conv2d weights get Xavier/Glorot-uniform initialization
    with gain sqrt(2) (the value ``torch.nn.init.calculate_gain('relu')``
    returns). Biases are zeroed when the layer has one.

    Args:
        m: a ``torch.nn.Module``.
    """
    # isinstance also matches Conv2d subclasses; an exact type() check skipped them.
    if isinstance(m, nn.Conv2d):
        # Trailing-underscore in-place variants; the unsuffixed names are deprecated.
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        # Conv2d(..., bias=False) stores bias=None, which cannot be filled.
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
##########################################################
plt.ion()  # interactive mode: figures are drawn without blocking the script
use_gpu = torch.cuda.is_available()

# Per-split preprocessing. Both splits currently only convert images to
# tensors; cropping, flipping or normalization would be appended here.
data_transforms = {
    'train': transforms.Compose([transforms.ToTensor()]),
    'val': transforms.Compose([transforms.ToTensor()]),
}

# Number of segmentation classes and a normalized confusion meter over them.
nclasses = 4
confusion_meter = tnt.meter.ConfusionMeter(nclasses, normalized=True)

data_dir1 = '/media/brats/Varghese/pyt/data/slices'
# One dataset per split, read from data_dir1/<split> by the project loader.
image_datasets = {
    split: pyt_unet_dataloader.Semantic_ImageFolder(
        os.path.join(data_dir1, split), data_transforms[split])
    for split in ('train', 'val')
}
# NOTE: the misspelled name `dataloders` is kept because train_model reads it.
dataloders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split], batch_size=10, shuffle=True, num_workers=3)
    for split in ('train', 'val')
}
dataset_sizes = {split: len(ds) for split, ds in image_datasets.items()}

# Pull one shuffled training batch when the script runs, e.g. for inspection
# with showcase().
inputs, targets = next(iter(dataloders['train']))
def showcase(data, gt, number, channels=(0, 3)):
    """Display selected input channels and the ground-truth mask of one sample.

    Every image is drawn in its own new figure. Previously the first image
    went into whatever figure was current, so a second call painted over the
    previous call's ground-truth figure.

    Args:
        data: image batch indexed as ``data[number, channel, :, :]``.
        gt: label batch indexed as ``gt[number, :, :]``.
        number: index of the sample within the batch.
        channels: input channels to display; defaults to (0, 3), the channels
            this function always showed before.
    """
    for channel in channels:
        plt.figure()
        plt.imshow(data[number, channel, :, :], cmap='gray')
    plt.figure()
    plt.imshow(gt[number, :, :], cmap='gray')
    plt.show()
def _run_phase(model, criterion, optimizer, phase):
    """Run one pass over ``dataloders[phase]`` and return (epoch_loss, epoch_acc).

    Reads the module-level ``dataloders``, ``dataset_sizes``, ``use_gpu`` and
    ``confusion_meter``. The confusion meter is reset, then fed every batch's
    predictions. Gradients are tracked, and the optimizer stepped, only when
    ``phase == 'train'``.

    Returns:
        epoch_loss: sum of (batch loss * batch size) divided by
            ``dataset_sizes[phase]``. This is a per-sample mean, assuming the
            criterion averages over the batch (NLLLoss's default).
        epoch_acc: fraction of label pixels whose argmax prediction matches.
    """
    is_train = phase == 'train'
    model.train(is_train)  # training mode vs. evaluation mode
    running_loss = 0.0
    running_corrects = 0
    running_pixels = 0
    confusion_meter.reset()
    for inputs, labels in dataloders[phase]:
        labels = labels.long()
        if use_gpu:
            inputs = inputs.cuda()
            labels = labels.cuda()
        optimizer.zero_grad()
        # Only build the autograd graph when we will backpropagate through it.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            # Explicit dim=1 (the class axis); the implicit-dim form is deprecated.
            loss = criterion(F.log_softmax(outputs, dim=1), labels)
        _, preds = torch.max(outputs.detach(), 1)
        confusion_meter.add(preds.view(-1), labels.view(-1))
        if is_train:
            loss.backward()
            optimizer.step()
        # loss.item() replaces loss.data[0], which fails on 0-dim tensors in
        # current PyTorch. Weighting by batch size keeps the division by the
        # dataset size below from understating the loss ~batch_size times.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()
        # Count the actual label pixels instead of assuming 240x240 images.
        running_pixels += labels.numel()
    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = float(running_corrects) / max(running_pixels, 1)
    return epoch_loss, epoch_acc


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` with a train and a val pass per epoch; keep the best weights.

    Args:
        model: segmentation network whose outputs carry class scores on dim 1
            (presumably shaped (N, C, H, W); confirm against the model).
        criterion: loss applied to ``log_softmax(outputs, dim=1)`` and the
            integer label maps.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the training pass.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` with the weights from the epoch of highest validation pixel
        accuracy loaded. If that accuracy never exceeded 0, the weights it
        started with are loaded instead.
    """
    since = time.time()
    # state_dict() returns references to the live parameter tensors. Without
    # a deep copy, the "best" snapshot would silently follow later updates.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer, phase)
            if phase == 'train':
                # Since PyTorch 1.1 the scheduler steps after the epoch's
                # optimizer steps; stepping it first skips the initial LR.
                scheduler.step()
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            print(confusion_meter.conf)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # '{:.4f}' means 4 decimals; the former '{:4f}' was a field width.
    print('Best val Acc: {:.4f}'.format(best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
model_ft = pyt_unet.unet(feature_scale=2, n_classes=4, is_deconv=True, in_channels=4, is_batchnorm=True)
# apply weight initialization ###############
model_ft.apply(init_weights)
#################################################
# Per-class weights for the loss, one entry per output class (n_classes=4).
weights = torch.from_numpy(np.asarray([0.5044884, 1.21449727, 0.84989637, 2.91036189])).float()
if use_gpu:
    # The loss weights and the model must sit on the same device as the
    # inputs. Moving the weights unconditionally crashed on CPU-only machines.
    weights = weights.cuda()
    model_ft = model_ft.cuda()
# NLLLoss accepts (N, C, H, W) log-probabilities directly; NLLLoss2d is a
# deprecated alias for the same computation.
criterion = nn.NLLLoss(weight=weights)
# Only Adam is used. The SGD and RMSprop optimizers that used to be built
# here were overwritten immediately and never trained anything.
learning_rate = 1e-3
optimizer_ft = torch.optim.Adam(model_ft.parameters(), lr=learning_rate)
# Multiply the LR by 0.01 every 100 scheduler steps. train_model steps it once
# per epoch, so with num_epochs=50 this decay never triggers.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.01)
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=50)
# Saves the whole pickled module (not only its state_dict), so loading the
# file later requires pyt_unet to be importable.
torch.save(model_ft, 'unet_model_feature_scale_2_zscore_complete_network_2.pth')
from pytorch-semseg.
I think I figured it out:
from torch import nn
import torch.nn.init as init
import numpy as np
def init_weights(m):
    """Initialize a Conv2d layer in place; leave every other module untouched.

    Meant to be passed to ``model.apply(init_weights)``, which calls it on
    every submodule. Conv2d weights get Xavier/Glorot-uniform initialization
    with gain sqrt(2) (the value ``torch.nn.init.calculate_gain('relu')``
    returns). Biases are zeroed when the layer has one.

    Args:
        m: a ``torch.nn.Module``.
    """
    # isinstance also matches Conv2d subclasses; an exact type() check skipped them.
    if isinstance(m, nn.Conv2d):
        # Trailing-underscore in-place variants; the unsuffixed names are deprecated.
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        # Conv2d(..., bias=False) stores bias=None, which cannot be filled.
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
# Build the U-Net and run init_weights over every submodule in one expression.
# Module.apply returns the module itself, so the result can be bound directly.
# NOTE(review): `unet` is not imported in this snippet; it presumably comes
# from the pytorch-semseg model code. Confirm before running.
model = unet().apply(init_weights)
from pytorch-semseg.
@varghesealex90, could you share the code? Including the class name and formatting it as a code block would make it much easier for others to follow.
from pytorch-semseg.
Hope this helps
from pytorch-semseg.
This is amazing, thank you!
from pytorch-semseg.
@meetshah1995, could we get permission to post snippets like this, nicely formatted, on a wiki page or something similar?
from pytorch-semseg.
Related Issues (20)
- Where I run 'python train.py [-h] [--config [CONFIG]]' HOT 1
- pspnet training HOT 2
- About the speed results of ICNet
- SegNet on Pascal:: TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; HOT 1
- Any good results from SegNet? HOT 9
- ValueError: Segmentation map contained invalid class values HOT 2
- RecursionError: maximum recursion depth exceeded in __instancecheck__ HOT 1
- test.py error HOT 1
- Image shape changed to 352 from 360 in FRRN camvid HOT 3
- mscoco pre-trained model
- Any tip to train models from scratch using cityscape dataset? HOT 1
- Semantic Segmentation Tool
- benchmark_RELEASE
- Pretrained Models
- Problem while trying to train HardNet on CamVid dataset
- KeyError: 'name' HOT 2
- Where is model being saved after training?
- python-cdo HOT 1
- Poly learning rate scheduler not doing anything HOT 1
- Error in fcn8s_pascal.yml HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: use data as art.
-
Game
Something interesting about games: make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from pytorch-semseg.