
Official code repository for "Contrastive Training of Complex-Valued Autoencoders for Object Discovery"

License: MIT License



Contrastive Training of Complex-Valued Autoencoders for Object Discovery

Current state-of-the-art object-centric models use slots and attention-based routing for binding. However, this class of models has several conceptual limitations: the number of slots is hardwired; all slots have equal capacity; training has high computational cost; there are no object-level relational factors within slots. Synchrony-based models in principle can address these limitations by using complex-valued activations which store binding information in their phase components. However, working examples of such synchrony-based models have been developed only very recently, and are still limited to toy grayscale datasets and simultaneous storage of less than three objects in practice. Here we introduce architectural modifications and a novel contrastive learning method that greatly improve the state-of-the-art synchrony-based model. For the first time, we obtain a class of synchrony-based models capable of discovering objects in an unsupervised manner in multi-object color datasets and simultaneously representing more than three objects.

This repo provides a reference implementation for the CtCAE as introduced in our paper "Contrastive Training of Complex-Valued Autoencoders for Object Discovery" (https://arxiv.org/abs/2305.15001).


Model figure


Main plot


Setup

To download the data, use the link provided on the EMORL GitHub repository: https://github.com/pemami4911/EfficientMORL, or download it directly from https://zenodo.org/records/4895643. Store the *.h5 files in the data directory of this repository.
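Before training, it can help to confirm that the downloaded files ended up where the scripts expect them. The following is a minimal sketch, not part of the repository: the helper name find_h5_files is hypothetical, and it only assumes that the *.h5 files sit directly inside a directory called data at the repository root, as described above.

```python
from pathlib import Path


def find_h5_files(data_dir="data"):
    """Return the sorted *.h5 files found directly inside data_dir.

    Hypothetical helper for illustration; returns an empty list if the
    directory does not exist yet.
    """
    root = Path(data_dir)
    if not root.is_dir():
        return []
    return sorted(root.glob("*.h5"))


if __name__ == "__main__":
    files = find_h5_files("data")
    if not files:
        print("No *.h5 files found in ./data; download them first.")
    for path in files:
        # Report the size so truncated downloads are easy to spot.
        size_mb = path.stat().st_size / 1e6
        print(f"{path.name}: {size_mb:.1f} MB")
```

Run it from the repository root. An empty listing means the files are missing or were stored in a different directory.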

Use the requirements.txt to install the necessary packages, e.g. run pip3 install -r requirements.txt.

Run Experiments

To train and test the CtCAE on 32x32 resolution, run one of the following commands, depending on the dataset you want to use:

python3 train.py --profile=ctcae_tetrominoes

python3 train.py --profile=ctcae_dsprites

python3 train.py --profile=ctcae_clevr

To train and test at the original resolution (64x64 for multi_dsprites and 96x96 for CLEVR), run one of the following commands:

python3 train.py --profile=ctcae_dsprites_64x64

python3 train.py --profile=ctcae_clevr_96x96

To train and test the CAE++ and CAE baselines, replace ctcae in the above commands with caepp or cae respectively (e.g. python3 train.py --profile=caepp_clevr to run the CAE++ model on the CLEVR dataset). See the train.py script for details on all run profiles.
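The profile names above follow a model_dataset[_resolution] pattern, which can be sketched as a small helper for scripting several runs. This is an illustration only: profile_name, train_command, and run are hypothetical names, and the set of valid combinations is inferred from the commands listed in this README. Whether every combination (for example a baseline at the original resolution) is actually defined should be checked in train.py.

```python
import subprocess
import sys

# Names taken from the commands in this README.
MODELS = ("ctcae", "caepp", "cae")
DATASETS = ("tetrominoes", "dsprites", "clevr")
# Original-resolution suffixes mentioned in this README (assumed complete).
NATIVE_SUFFIX = {"dsprites": "64x64", "clevr": "96x96"}


def profile_name(model, dataset, native_resolution=False):
    """Compose a --profile value such as 'ctcae_clevr_96x96'."""
    if model not in MODELS:
        raise ValueError(f"unknown model: {model}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    name = f"{model}_{dataset}"
    if native_resolution:
        if dataset not in NATIVE_SUFFIX:
            raise ValueError(f"no original-resolution profile listed for {dataset}")
        name += "_" + NATIVE_SUFFIX[dataset]
    return name


def train_command(profile):
    """Build the argument list for launching train.py with one profile."""
    return [sys.executable, "train.py", f"--profile={profile}"]


def run(profile):
    """Launch training and raise if train.py exits with an error."""
    subprocess.run(train_command(profile), check=True)


if __name__ == "__main__":
    # Only print the command; call run(...) to actually start training.
    print(" ".join(train_command(profile_name("caepp", "clevr"))))
```

Building the command as a list and passing it to subprocess.run avoids shell quoting issues, and check=True stops a batch of runs as soon as one of them fails.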

Citation

When using this code, please cite our paper:

@inproceedings{stanic2023contrastive,
  title={Contrastive Training of Complex-Valued Autoencoders for Object Discovery},
  author={Stani{\'c}, Aleksandar and Gopalakrishnan, Anand and Irie, Kazuki and Schmidhuber, J{\"u}rgen},
  booktitle={Proc. Advances in Neural Information Processing Systems (NeurIPS), 2023},
  year={2023}
}

Contact

For questions and suggestions, feel free to open an issue on GitHub or send an email to [email protected] or [email protected].
