bayer-science-for-a-better-life / contrastive-reconstruction

Tensorflow-keras implementation for Contrastive Reconstruction (ConRec): a self-supervised learning algorithm that obtains image representations by jointly optimizing a contrastive and a self-reconstruction loss.

License: GNU General Public License v3.0

keras tensorflow2 self-supervised-learning simclr conrec attention-pooling finegrained representation-learning toy-dataset

contrastive-reconstruction's Introduction

Contrastive Reconstruction (ConRec)

Tensorflow-keras implementation for Contrastive Reconstruction: a self-supervised learning algorithm that obtains image representations by jointly optimizing a contrastive and a self-reconstruction loss presented at the ICML 2021 Workshop: Self-Supervised Learning for Reasoning and Perception [Paper, Poster].

[Figure: ConRec model architecture]
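As a rough illustration (not the repository code), the training objective combines the two losses as a weighted sum, matching the --lambda-con and --lambda-rec flags used in the training commands below:

def conrec_loss(contrastive_loss, reconstruction_loss,
                lambda_con=1.0, lambda_rec=100.0):
    # Weighted sum of the contrastive term and the pixel-wise reconstruction term.
    return lambda_con * contrastive_loss + lambda_rec * reconstruction_loss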

Install Dependencies

We used Python 3.7 in our experiments.

pip install -r requirements.txt

Data Preparation

For the Oxford Flowers and Stanford Dogs datasets, the data is downloaded automatically when invoking the training script. For the Aptos2019 dataset, the data has to be downloaded manually.

Aptos2019

Register for the Aptos 2019 Kaggle Competition and download the train_images folder (train_images.zip). Unzip the images into one folder (e.g. train_images), then resize them and place them in the resources folder with the following script:

python scripts/aptos2019/resize_images.py --image-dir train_images --output_dir resources/aptos2019/images
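For reference, a minimal sketch of what the resize step does (scripts/aptos2019/resize_images.py is the authoritative implementation; the target size here is an assumption):

import os
from PIL import Image

def resize_images(image_dir, output_dir, size=(256, 256)):
    # Read every jpg/png from image_dir, resize it and write it to output_dir.
    os.makedirs(output_dir, exist_ok=True)
    for name in os.listdir(image_dir):
        if name.lower().endswith((".png", ".jpg", ".jpeg")):
            img = Image.open(os.path.join(image_dir, name)).convert("RGB")
            img.resize(size, Image.BILINEAR).save(os.path.join(output_dir, name))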

Synthetic dataset

Samples from the synthetic dataset can be generated by invoking the following script:

python scripts/create_synthetic_ds.py --type <rectangle-triangle|circle-square> \
--output dataset.npz --num-train <n> --num-test <m>

The generated datasets that we used in our paper are included in this repository as numpy arrays under resources/rectangle-triangle.npz and resources/circle-square.npz. The respective datasets are also available as png files in resources/rectangle-triangle.zip and resources/circle-square.zip.
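A quick way to inspect the bundled .npz files (the array keys are not documented here, so print them first):

import numpy as np

data = np.load("resources/rectangle-triangle.npz")
print(data.files)  # names of the arrays stored in the archive
for key in data.files:
    print(key, data[key].shape, data[key].dtype)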

Train the model

To pretrain the models and reproduce our results in the paper, invoke the training script in the following way:

python train.py \
-o lars -lr 0.075 --lr-scaling sqrt -t 0.5 -wd 1e-4 \
--color-jitter-strength=1.0 --use-blur \
-bs 16 \
-m unet --depth 4 --filters 64 \
--logdir <logdir> \
-e <epochs> \

# For oxford_flowers102, stanford_dogs or any other tf dataset
--dataset <dataset> \
--linear-type categorical \
--eval-center-crop \

# For oxford_flowers102 additionally
--train-split train+validation \

# For aptos2019
--dataset aptos2019 \
--train-split all.csv \
--test-split all.csv \
--data-path resources/aptos2019 \
--linear-type diabetic \

# For the synthetic datasets
--dataset numpy \
--data-path <resources/rectangle-triangle.npz|resources/circle-square.npz> \
--height 128 --width 128 --channels 1 \
--linear-type categorical \

# For SimCLR
--lambda-con=1.0 \
--encoder-reduction ga_pooling \
--aug-impl simclr --simclr \

# For SimCLR + Attention
--lambda-con=1.0 \
--encoder-reduction ga_attention \
--aug-impl simclr --simclr \

# For Conrec
--lambda-rec=100.0 --lambda-con=1.0 \
--encoder-reduction ga_attention \
--aug-impl conrec \

# Optional Parameters
--validation-freq 20 \
--log-images \
--image-log-interval 20 \
--linear-interval 20 \
--save-epochs 100, 200
# Shuffle Buffer size is batch_size x shuffle-buffer-multiplier
--shuffle-buffer-multiplier 10
# Performs sklearn linear evaluation in another thread
--async-eval

where epochs should be at least 1200 for the stanford_dogs and aptos2019 datasets, 2700 for the oxford_flowers102 dataset, and 1000 for the synthetic datasets. A ConRec example for the Oxford Flowers dataset would be:

python train.py \
-o lars -lr 0.075 --lr-scaling sqrt -t 0.5 -wd 1e-4 \
--color-jitter-strength=1.0 --use-blur -bs 16 -e 2700 \
-m unet --depth 4 --filters 64 \
--dataset oxford_flowers102 \
--linear-type categorical \
--eval-center-crop \
--lambda-rec=100.0 --lambda-con=1.0 \
--encoder-reduction ga_attention \
--aug-impl conrec \
--train-split train+validation \
--logdir <logdir>

Train the Model from an image folder

It is also possible to train on images from a folder of jpg/png images. The image folder should have the following structure:

path/to/image_dir/
  split_name/  # Ex: 'train'
    label1/  # Ex: 'airplane' or '0015'
      xxx.png
      xxy.png
      xxz.png
    label2/
      xxx.png
      xxy.png
      xxz.png
  split_name/  # Ex: 'test' 

If you do not have labels, simply put all images under <path>/train/0 and specify --data-path <path>.
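For example, a small helper (hypothetical, not part of the repository) that copies a flat folder of unlabeled images into the expected layout could look like this:

import os
import shutil

src = "my_images"                                       # flat folder with jpg/png files
dst = os.path.join("path/to/image_dir", "train", "0")   # single dummy class folder
os.makedirs(dst, exist_ok=True)
for name in os.listdir(src):
    if name.lower().endswith((".jpg", ".png")):
        shutil.copy(os.path.join(src, name), os.path.join(dst, name))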

python train.py \
-o lars -lr 0.075 --lr-scaling sqrt -t 0.5 -wd 1e-4 \
--color-jitter-strength=1.0 --use-blur \
-e 1000 -bs 16 \
-m backbone --backbone densenet121 \
--lambda-rec=100.0 --lambda-con=1.0 \
--encoder-reduction ga_attention \
--aug-impl conrec \
--logdir <logdir> \
--dataset image-folder \
--data-path <path-to-image-folder> \

# Specify folder with train images
--train-split train  # by default \

# Specify folder with test images
--test-split test  # by default \

# or deactivate test data; no validation or evaluation will be performed
--no-test-data \
--linear-type none \

# It is also possible to supply a different evaluation dataset
--eval-dataset <..> \
--eval-data-path <..> \

# Center-crop data for evaluation if the images do not all have the same dimensions
--eval-center-crop

Evaluation

After pretraining the model, it is possible to evaluate it with logistic regression on various subsets of the data. This is done by generating a JSON file (e.g. models.json) that includes an entry for every model to be evaluated, in the following format:

[
  {
    "name": "unet-conrec",
    "path": "<path-to-model.hdf5>",
    "preprocess": null,
    "output_layer": "encoder_output"
  },
  ...
]

Then we can compute the embeddings for each model, write them into a directory and finally perform linear evaluation on them.
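Conceptually, the compute_embeddings scripts load each model from models.json, cut it at the configured output layer and run the evaluation images through the resulting encoder. A hedged sketch (paths and preprocessing are illustrative, not the scripts' exact logic):

import json
import tensorflow as tf

with open("models.json") as f:
    entries = json.load(f)

for entry in entries:
    model = tf.keras.models.load_model(entry["path"], compile=False)
    encoder = tf.keras.Model(model.input,
                             model.get_layer(entry["output_layer"]).output)
    # images = ...  # evaluation images, preprocessed as configured for this model
    # embeddings = encoder.predict(images, batch_size=64)
    # write the embeddings to the output directory for the evaluation step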

Aptos2019

python scripts/aptos2019/compute_embeddings.py --models models.json --out-dir resources/aptos2019/embeddings --data-path resources/aptos2019

python scripts/aptos2019/evaluate_embeddings.py --embeddings-dir resources/aptos2019/embeddings \
--label-percentages 0.1 0.25 0.5 1.0 --repetitions 5 --output resources/aptos2019/results.csv

Oxford Flowers, Stanford Dogs

python scripts/tf_dataset/compute_embeddings.py --dataset <oxford_flowers102|stanford_dogs> \
--models models.json --out-dir <dir>

python scripts/tf_dataset/evaluate_embeddings.py --embeddings-dir <dir> \
--label-percentages 0.1 0.25 0.5 1.0  --output results.csv

To plot the results, use the output file generated by the evaluation script:

python scripts/plot_results.py --input results.csv --metric <accuracy|kappa_kaggle>
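If you prefer to inspect the results directly, the CSV can also be read with pandas; the column names below ("name", "label_percentage", "accuracy") are assumptions about the file layout, so adjust them to whatever the evaluation script actually writes:

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("results.csv")
for name, grp in df.groupby("name"):
    plt.plot(grp["label_percentage"], grp["accuracy"], marker="o", label=name)
plt.xlabel("fraction of labeled data")
plt.ylabel("accuracy")
plt.legend()
plt.show()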

Furthermore, instead of using logistic regression, adding a dense layer on top of the frozen encoder and using augmentations while fine-tuning yielded better results for the Oxford Flowers and Stanford Dogs datasets. This can be reproduced in the following way:

python finetune.py -d oxford_flowers102 \
--classes 102 -lr 0.1 -o sgd \
-e 400 -bs 64 -wd 0 \
--train-split "train+validation" \
--freeze-until encoder_output \
--validation-freq 20 --preprocess simclr \
--gpu 0 --save-model \
--logdir <logdir> \
-p <path-to-model.hdf5>
python finetune.py -d stanford_dogs \
--classes 120 -lr 0.01 -o sgd \
-e 500 -bs 64 -wd 0 \
--freeze-until encoder_output \
--validation-freq 20 --preprocess simclr \
--gpu 0 --save-model \
--logdir <logdir> \
-p <path-to-model.hdf5>
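Conceptually, --freeze-until encoder_output corresponds to freezing everything up to the encoder output and training only a new dense classification head on top. A minimal Keras sketch of that idea (the layer name is taken from the commands above; the class count and optimizer are illustrative):

import tensorflow as tf

pretrained = tf.keras.models.load_model("model.hdf5", compile=False)
encoder = tf.keras.Model(pretrained.input,
                         pretrained.get_layer("encoder_output").output)
encoder.trainable = False  # keep the pretrained weights fixed

inputs = tf.keras.Input(shape=encoder.input_shape[1:])
features = encoder(inputs, training=False)
outputs = tf.keras.layers.Dense(102, activation="softmax")(features)
classifier = tf.keras.Model(inputs, outputs)
classifier.compile(optimizer="sgd",
                   loss="sparse_categorical_crossentropy",
                   metrics=["accuracy"])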

With the same script, it is also possible to train the reported baselines:

python finetune.py -d oxford_flowers102 \
--classes 102 -lr 0.3 -o lars \
-e 1000 -bs 16 -wd 5e-4 \
--train-split "train+validation" \
--validation-freq 20 --preprocess simclr \
--gpu 0 --save-model \
-m unet --depth 4 --filters 64 \
--logdir <logdir>
python finetune.py -d stanford_dogs \
--classes 120 -lr 0.3 -o lars \
-e 500 -bs 16 -wd 1e-4 \
--validation-freq 20 --preprocess simclr \
--gpu 0 --save-model \
-m unet --depth 4 --filters 64 \
--logdir <logdir>

For the Aptos2019 dataset, we used the following script and configuration:

python scripts/aptos2019/finetune.py \
--output results.csv \
--models models.json \
-o adam \
-bs 32 \
-e 25 \
-lr 5e-5 \
--preprocess diabetic \
-wd 0 \
--gpu 0 \
--folds 5 \
--repetitions 5

where models.json has the same structure as for the logistic regression and, in this case, contains paths to randomly initialized models.

Cite

ConRec arXiv paper

@article{dippel2021finegrained,
      title={Towards Fine-grained Visual Representations by Combining Contrastive Learning with Image Reconstruction and Attention-weighted Pooling}, 
      author={Jonas Dippel and Steffen Vogler and Johannes H\"ohne},
      year={2021},
      journal={arXiv preprint arXiv:2104.04323}
}

contrastive-reconstruction's People

Contributors

johanneshoehne, jonasd4


contrastive-reconstruction's Issues

How to resume training from a specific epoch?

I was training the model and stopped at the 13th epoch. When I tried to resume training from the 13th epoch, it started again from the 1st epoch.
P.S. I don't want to fine-tune the model, I just want to resume training from the 13th epoch.
Thank you

training ends after some time

I'm trying to train the model on an image directory using the following parameters:

python train.py
-o lars -lr 0.075 --lr-scaling sqrt -t 0.5 -wd 1e-4
--color-jitter-strength=1.0 --use-blur
-e 1000 -bs 1
-m unet
--lambda-rec=100.0 --lambda-con=1.0
--encoder-reduction ga_attention
--aug-impl conrec
--logdir /content
--dataset image-folder
--data-path /content/data
--no-test-data
--linear-type none

which stops after a few iterations with the following output:

2/97554 [..............................] - ETA: 3:35:03 - loss: 13.4387 - reconst_output_loss: 0.1344 - con_output_loss: 0.0000e+00 - reconst_output_mse: 0.1344 - reconst_output_mae: 0.2977 - con_output_contrastive_acc: 1.0000 - con_output_contrastive_entropy: 0.0000e+00WARNING:tensorflow:Callbacks method on_train_batch_end is slow compared to the batch time (batch time: 0.0361s vs on_train_batch_end time: 0.2279s). Check your callbacks.
WARNING:tensorflow:Callbacks method on_train_batch_end is slow compared to the batch time (batch time: 0.0361s vs on_train_batch_end time: 0.2279s). Check your callbacks.
2126/97554 [..............................] - ETA: 2:00:26 - loss: 4.2061 - reconst_output_loss: 0.0421 - con_output_loss: 0.0000e+00 - reconst_output_mse: 0.0421 - reconst_output_mae: 0.1484 - con_output_contrastive_acc: 1.0000 - con_output_contrastive_entropy: 0.0000e+00^C
