This project forked from sayakpaul/simclr-in-tensorflow-2

Home Page: https://app.wandb.ai/sayakpaul/simclr/reports/Towards-self-supervised-image-understanding-with-SimCLR--VmlldzoxMDI5NDM

License: MIT License

SimCLR-in-TensorFlow-2

(Minimally) implements SimCLR (A Simple Framework for Contrastive Learning of Visual Representations by Chen et al.) in TensorFlow 2. Uses many delicious pieces of tf.keras and TensorFlow's core APIs. A report is available here.
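At the heart of SimCLR is the NT-Xent (normalized temperature-scaled cross-entropy) loss, which pulls two augmented views of the same image together and pushes all other images in the batch apart. A minimal sketch in TensorFlow 2 (the function name and exact masking strategy are my own; the paper's Algorithm 1 is the reference):

```python
import tensorflow as tf

def nt_xent_loss(z_i, z_j, temperature=0.1):
    """NT-Xent loss for a batch of positive pairs (z_i, z_j).

    z_i, z_j: (batch_size, dim) projections of two augmented views.
    Minimal sketch; follows Algorithm 1 of Chen et al. (2020).
    """
    batch_size = tf.shape(z_i)[0]
    # L2-normalize so the dot product is cosine similarity.
    z = tf.math.l2_normalize(tf.concat([z_i, z_j], axis=0), axis=1)
    # Pairwise similarities for all 2N samples, scaled by temperature.
    sim = tf.matmul(z, z, transpose_b=True) / temperature
    # Mask out self-similarity on the diagonal with a large negative.
    sim = sim - tf.eye(2 * batch_size) * 1e9
    # For row k, the positive sits at k + N (mod 2N).
    labels = tf.concat([tf.range(batch_size, 2 * batch_size),
                        tf.range(0, batch_size)], axis=0)
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, sim, from_logits=True)
    return tf.reduce_mean(loss)
```

With a low temperature (0.1 is used later in this README), well-aligned positive pairs drive the loss close to zero, while mismatched pairs keep it near log(2N − 1).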

Acknowledgements

I did not code everything from scratch. This paper was a joy to read and felt natural to understand, which is why I wanted to try it out myself and come up with a minimal implementation. I reused the works of the following for different purposes -

Besides the paper, the following are the articles I studied to understand SimCLR:

Thanks a ton to the ML-GDE program for providing the GCP credits, which let me run the experiments and store intermediate results on GCS buckets as necessary. All the notebooks can be run on Colab, though.

Dataset

Architecture

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
resnet50 (Model)             (None, 7, 7, 2048)        23587712
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0
_________________________________________________________________
dense (Dense)                (None, 256)               524544
_________________________________________________________________
activation (Activation)      (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896
_________________________________________________________________
activation_1 (Activation)    (None, 128)               0
_________________________________________________________________
dense_2 (Dense)              (None, 50)                6450
=================================================================
Total params: 24,151,602
Trainable params: 24,098,482
Non-trainable params: 53,120
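The summary above can be rebuilt with `tf.keras`: a ResNet50 backbone (trained from scratch, hence `weights=None`), global average pooling, and a non-linear projection head of three Dense layers. A sketch; the builder name and argument names are mine, but the layer sizes come straight from the summary:

```python
import tensorflow as tf

def build_simclr_model(hidden_1=256, hidden_2=128, projection_dim=50):
    """ResNet50 backbone + non-linear projection head, matching the
    model summary above (24,151,602 total params). Sketch only."""
    base = tf.keras.applications.ResNet50(
        include_top=False, weights=None, input_shape=(224, 224, 3))
    inputs = tf.keras.Input((224, 224, 3))
    h = base(inputs)                                   # (None, 7, 7, 2048)
    h = tf.keras.layers.GlobalAveragePooling2D()(h)    # (None, 2048)
    # Projection head: Dense -> ReLU -> Dense -> ReLU -> Dense.
    x = tf.keras.layers.Dense(hidden_1)(h)
    x = tf.keras.layers.Activation("relu")(x)
    x = tf.keras.layers.Dense(hidden_2)(x)
    x = tf.keras.layers.Activation("relu")(x)
    x = tf.keras.layers.Dense(projection_dim)(x)
    return tf.keras.Model(inputs, x)
```

Only the output of the projection head feeds the contrastive loss; for downstream tasks the head is discarded and the 2048-d pooled features are used.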

Contrastive learning progress

Training with 10% training data using the learned representations (linear evaluation)

loss: 1.1009 - accuracy: 0.5840 - val_loss: 1.1486 - val_accuracy: 0.5280

This is when I only took the base encoder network, i.e. without any non-linear projections. I presented results with different projection heads as well (available here), but this one turned out to be the best.
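The linear evaluation protocol described above can be sketched as follows: freeze the pre-trained encoder (backbone plus pooling, projection head discarded) and train only a classifier on top. The function name and the softmax classifier are my assumptions; the protocol itself is standard:

```python
import tensorflow as tf

def build_linear_eval_model(encoder, num_classes=5):
    """Freeze a pre-trained encoder and attach a trainable linear
    classifier (sketch; assumes `encoder` maps 224x224x3 images to a
    flat feature vector)."""
    encoder.trainable = False  # only the classifier head is trained
    inputs = tf.keras.Input((224, 224, 3))
    features = encoder(inputs, training=False)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(features)
    return tf.keras.Model(inputs, outputs)
```

This model would then be compiled and fit on the 10% labeled subset; because the encoder is frozen, the only trainable parameters are the classifier's weights and biases.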

Learned representations with TSNE


Supervised training with the full training dataset

Here's the architecture that was used:


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_4 (InputLayer)         [(None, 224, 224, 3)]     0
_________________________________________________________________
resnet50 (Model)             (None, 7, 7, 2048)        23587712
_________________________________________________________________
global_average_pooling2d_1 ( (None, 2048)              0
_________________________________________________________________
dense_1 (Dense)              (None, 256)               524544
_________________________________________________________________
activation (Activation)      (None, 256)               0
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 1285
=================================================================
Total params: 24,113,541
Trainable params: 24,060,421
Non-trainable params: 53,120

loss: 0.6623 - accuracy: 0.7528 - val_loss: 1.0171 - val_accuracy: 0.6440

We see a ~12% increase in validation accuracy here. The accuracy with the SimCLR framework could be increased further with better pre-training in the following respects:

  • More unsupervised data. A larger corpus of images for the pre-training task (think of ImageNet) would definitely have helped.
  • I trained with the SimCLR framework for only 200 epochs. Longer training would definitely have helped.
  • Architectural considerations and hyperparameter tuning:
    • Temperature (tau) (I used 0.1)
    • Mixing and matching the different augmentation policies shown in the paper, and the strength of the color distortion.
    • Different projection heads.

SimCLR benefits from larger amounts of data. Ting Chen (the first author of the paper) suggested choosing an augmentation policy (when using custom datasets) that is neither too easy nor too hard for the contrastive task, i.e. the contrastive accuracy should be high (e.g. > 80%).
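The color distortion whose strength is mentioned above can be sketched with `tf.image` ops, following Appendix A of the paper (the `strength` scaling is from the paper; the function name is mine, and the grayscale/crop branches are omitted for brevity):

```python
import tensorflow as tf

def color_distortion(image, strength=0.5):
    """Color jitter from the SimCLR paper (Appendix A). `image` is a
    float tensor in [0, 1]; larger `strength` makes the contrastive
    task harder. Sketch only."""
    image = tf.image.random_brightness(image, max_delta=0.8 * strength)
    image = tf.image.random_contrast(
        image, lower=1 - 0.8 * strength, upper=1 + 0.8 * strength)
    image = tf.image.random_saturation(
        image, lower=1 - 0.8 * strength, upper=1 + 0.8 * strength)
    image = tf.image.random_hue(image, max_delta=0.2 * strength)
    # Jitter can push values outside the valid range; clip back.
    return tf.clip_by_value(image, 0.0, 1.0)
```

Calling this (together with random crops and flips) twice per image yields the two correlated views that form a positive pair; tuning `strength` is one way to hit the "not too easy, not too hard" target above.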

Pre-trained weights

Available here - Pretrained_Weights.
