
Machine Learning for Computer Vision

This workbook is designed for automated pathology detection in medical images in a realistic setup, i.e. each image may have multiple pathologies/disorders.

The goal is to design models and methods to predictively detect pathological images and explain the pathology sites in the image data.

Environment/Requirements

This repo contains an environment.yml file that you can use with your package manager of choice (e.g. conda) to reproduce the environment and install the necessary packages.

Data

Data for this assignment is taken from a Kaggle contest: https://www.kaggle.com/c/vietai-advance-course-retinal-disease-detection/overview

The training data set contains 3435 retinal images representing multiple pathological disorders. The pathology classes and corresponding labels are included in the 'train.csv' file, and each image can have more than one class (multiple pathologies). The labels for each image are:

- opacity (0)
- diabetic retinopathy (1)
- glaucoma (2)
- macular edema (3)
- macular degeneration (4)
- retinal vascular occlusion (5)
- normal (6)

The test dataset contains 350 unlabeled images, which we'll be using to augment our dataset and improve model performance.

For this particular assignment, we are working with specialists for Diabetic Retinopathy and Glaucoma only, and the client is interested in a predictive learning model along with feature explainability and self-learning for Diabetic Retinopathy and Glaucoma vs. Normal images.

Model

I chose an AlexNet architecture for this task and implemented it as a custom Keras class inheriting from Sequential.

# AlexNet
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Flatten, Lambda, MaxPooling2D)


class AlexNet(Sequential):
    def __init__(self, input_shape, num_classes, **kwargs):
        super().__init__(**kwargs)

        # Only take the green channel
        self.add(Lambda(lambda x: x[:, :, :, 1:2], input_shape=input_shape))

        self.add(Conv2D(96, kernel_size=(11, 11), strides=4,
                        padding='valid', activation='relu',
                        kernel_initializer='he_normal'))
        self.add(BatchNormalization())
        self.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='valid'))

        self.add(Conv2D(256, kernel_size=(5, 5), strides=1,
                        padding='same', activation='relu',
                        kernel_initializer='he_normal'))
        self.add(BatchNormalization())
        self.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='valid'))

        self.add(Conv2D(384, kernel_size=(3, 3), strides=1,
                        padding='same', activation='relu',
                        kernel_initializer='he_normal'))
        self.add(BatchNormalization())

        self.add(Conv2D(384, kernel_size=(3, 3), strides=1,
                        padding='same', activation='relu',
                        kernel_initializer='he_normal'))
        self.add(BatchNormalization())

        self.add(Conv2D(256, kernel_size=(3, 3), strides=1,
                        padding='same', activation='relu',
                        kernel_initializer='he_normal', name='last_conv'))
        self.add(BatchNormalization())

        self.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2),
                              padding='valid', name='max_pool'))

        self.add(Flatten(name='flatten'))

        self.add(Dense(num_classes, name='dense'))
        self.add(Activation('softmax'))

        self.compile(optimizer=tf.keras.optimizers.Adam(),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])
I chose this model because I've used it before on a similar challenge, and because I wanted a lightweight model for this application.

In retrospect, or with more time, I would have done a few things differently:

  1. Using a custom class made it more difficult to research certain aspects of Tasks 2 and 3 (heatmaps and data augmentation techniques), so those tasks took quite a bit longer.
  2. A deeper model might have achieved better classification performance with less grid searching and tweaking, though I believe that:
  3. Implementing one-vs-all would have led to better performance than what I achieved, no matter which model I used. Specifically, one-vs-Normal (pitting each pathology against our Normal category of images) would have worked better, since for Normal the model was looking at global features rather than the more local features of each specific pathology (sketched below).
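
The sketch below only illustrates that one-vs-Normal idea and is not code from this repo; it assumes the labels are loaded as a multi-hot matrix Y with one column per class, in the column order listed under Data above.

import numpy as np

def one_vs_normal(Y, pathology_col, normal_col=6):
    """Binary one-vs-Normal labels from a multi-hot label matrix Y (n_samples, n_classes).

    Returns binary labels (1 = pathology, 0 = normal) and a boolean mask selecting
    only the images tagged with either the pathology or the Normal class.
    """
    mask = (Y[:, pathology_col] == 1) | (Y[:, normal_col] == 1)
    return Y[mask, pathology_col].astype(int), mask

# e.g. a diabetic-retinopathy-vs-Normal subset (column indices follow the label list above)
# y_dr, dr_mask = one_vs_normal(Y_train, pathology_col=1)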

Heatmaps

The basic idea here is to compute the gradients of each class score with respect to the activations of the network's final convolutional layer, then use those to produce a "heat map" of which parts of the image most "excite" the model when looking at that particular class.

Once that heatmap is made, it can be overlaid onto the original image to give a very clear graphical representation of what the model is "looking for" in that particular class.
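
A minimal Grad-CAM-style sketch of that computation is shown below; it assumes the trained AlexNet above (with its 'last_conv' layer name) and is not a verbatim copy of the notebook code.

import numpy as np
import tensorflow as tf

def grad_cam(model, image, class_idx, last_conv_name='last_conv'):
    """Return a [0, 1] heatmap of the final conv layer's contribution to one class."""
    # Model that exposes both the last conv activations and the class scores.
    grad_model = tf.keras.Model(inputs=model.inputs,
                                outputs=[model.get_layer(last_conv_name).output,
                                         model.output])
    img = tf.convert_to_tensor(image[np.newaxis, ...], dtype=tf.float32)
    with tf.GradientTape() as tape:
        conv_out, preds = grad_model(img)
        class_score = preds[:, class_idx]

    grads = tape.gradient(class_score, conv_out)    # d(class score) / d(conv activations)
    weights = tf.reduce_mean(grads, axis=(1, 2))    # average-pool the gradients per channel
    cam = tf.reduce_sum(weights[:, tf.newaxis, tf.newaxis, :] * conv_out, axis=-1)
    cam = tf.nn.relu(cam)[0].numpy()
    return cam / (cam.max() + 1e-8)                 # normalise to [0, 1]

# The returned heatmap can then be resized to the image size and alpha-blended over it.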

(Example heatmap overlays: a normal image and a glaucoma image.)

Data Augmentation Technique

To augment our initial labeled dataset, we were given an unlabeled set of 350 images, which we were told had once been the test dataset (and thus could be assumed to evince the same set of pathologies).

To use these images to augment our data, I built a new version of the trained model without the softmax classification head, then used the output of the final dense layer as a feature representation of the data's underlying distribution.

I passed both my labeled and unlabeled Xs through this truncated model, then passed the result, along with the existing labels, through sklearn's LabelSpreading model. This propagated labels to the new examples based on the relationship between the underlying distribution of all the Xs and the labels we already had.
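
A condensed sketch of this step is below; the variable names (X_labeled, y_labeled, X_unlabeled) and the LabelSpreading hyperparameters are illustrative assumptions rather than the exact notebook code.

import numpy as np
import tensorflow as tf
from sklearn.semi_supervised import LabelSpreading

# Truncate the trained model just before the softmax head to get feature vectors.
feature_model = tf.keras.Model(inputs=model.inputs,
                               outputs=model.get_layer('dense').output)

feat_labeled = feature_model.predict(X_labeled)      # X_labeled / y_labeled / X_unlabeled
feat_unlabeled = feature_model.predict(X_unlabeled)  # are placeholders for the prepared data

X_all = np.concatenate([feat_labeled, feat_unlabeled])
# scikit-learn marks unlabeled samples with -1; y_labeled holds integer class indices here.
y_all = np.concatenate([y_labeled, -np.ones(len(feat_unlabeled), dtype=int)])

spreader = LabelSpreading(kernel='knn', n_neighbors=7)  # hyperparameters are illustrative
spreader.fit(X_all, y_all)

pseudo_labels = spreader.transduction_[len(feat_labeled):]  # propagated labels for the new images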

Then, I retrained my model on this new dataset, to see if the additional data improved performance.

Performance

(see notebook)

Overall, performance was okay. I was very surprised at how much the default hyperparameters tended to outperform everything I tried to tweak.

As I mentioned above, if I started this again I'd probably try a "one-vs-Normal" approach. AlexNet would still be a good choice, since it's lightweight; building three separate ResNet-101s or similar seems like it would be impractical for many purposes.

The most exciting result is how much performance was boosted by the data augmentation, both when fine-tuning the original model and when retraining from scratch on the augmented data. This technique seems like it could be very helpful in many real-world applications, since labeled data is always at a premium relative to unlabeled data.

This Repo

Everything necessary to reproduce this work is contained in the notebook ML_for ComputerVisionvFINAL.ipynb. The other directories contain contextual and explanatory material, showing the development process for the final notebook.

archive:

contains previous versions of this work in various states of completion, mostly showing the dead ends common to developing and iterating on these types of models.

logs:

contains Tensorboard scalars showing the training history of the various versions of the model in numerical order (1 being the earliest models, 10 the latest).

models:

contains model weights for the various versions of the model that were tested. It contains weights rather than saved models because we used a custom AlexNet class, and the tf/Keras save_model() method doesn't work for that type of custom class.
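
To reload a checkpoint, the class is instantiated first and the weights are then loaded into it; the input shape, class count, and filename below are hypothetical placeholders, not values documented in this README.

# Hypothetical example of restoring a saved checkpoint from models/.
model = AlexNet(input_shape=(227, 227, 3), num_classes=3)  # shapes/classes are placeholders
model.load_weights('models/example_weights.h5')            # filename is a placeholder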

results:

contains heatmap images for many of the training images showing our different pathologies - Diabetic Retinopathy and Glaucoma - vs Normal images.
