Learning Words by Drawing Images

This is the official PyTorch implementation of the paper Learning Words by Drawing Images by:

Dídac Surís (*), Adrià Recasens (*), David Bau, David Harwath, James Glass, Antonio Torralba

(*) Equal contribution

In this paper, we propose a framework for learning through drawing. Our goal is to learn the correspondence between spoken words and abstract visual attributes from a dataset of spoken descriptions of images. We manipulate the learned representations of GANs to edit semantic concepts in the generated outputs, and use these GAN-generated images to train a model with a triplet loss.
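As a rough illustration of that objective (a minimal sketch, not the paper's exact formulation; the margin value and the Euclidean distance are assumptions), a triplet loss over audio and image embeddings can be written in PyTorch as:

import torch
import torch.nn.functional as F

def triplet_loss(audio_emb, pos_image_emb, neg_image_emb, margin=1.0):
    # Pull the spoken caption towards its matching GAN-generated image,
    # and push it away from a mismatched one.
    d_pos = F.pairwise_distance(audio_emb, pos_image_emb)
    d_neg = F.pairwise_distance(audio_emb, neg_image_emb)
    return F.relu(d_pos - d_neg + margin).mean()

# Example with a batch of 8 random 512-dimensional embeddings:
loss = triplet_loss(torch.randn(8, 512), torch.randn(8, 512), torch.randn(8, 512))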

Installation and setup

Code and dependencies

The code is written for Python 3.6; we do not guarantee it will work with other Python versions. We use the PyTorch framework, version 0.4. It should also work with PyTorch 1.0, but this has not been tested.

To install everything needed, make sure conda is available and create a new virtual environment:

conda create -n env_drawing python=3.6
conda activate env_drawing

Then install the libraries listed in requirements.txt.
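With the environment activated, this can be done with pip:

pip install -r requirements.txt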

After that, we still need to install the netdissect module, which is provided as part of the code. To do so, go to the root folder of the project and run the following commands:

gandissect/script/download_data_drawing.sh  # Download support models
pip install -v -e ./gandissect/             # Link the local netdissect package into the env

Data and pretrained models

The data can be obtained by downloading the files at this link.

The default folder structure is:

  • /path/to/project/
    • data/
      • audio_clevrgan_natural/
        • name_list_{train, val, test}.txt <-- Download from here (train), here (val), here (test)
        • audio/ <-- Extract the files in this file here
        • images/ <-- Extract the files in this file here
        • text/ <-- Extract the files in this file here
      • audio_clevrgan_synth/
        • name_list_{train, val, test}.txt <-- Download from here (train), here (val), here (test)
        • audio/ <-- Extract the files in this file here
        • images/ <-- Extract the files in this file here, or (if audio_clevrgan_natural is already downloaded) create a symlink to this file

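For instance, one way to set up that symlink is to point to the already extracted images of audio_clevrgan_natural instead of extracting them a second time (the relative paths are assumptions based on the structure above):

cd data/audio_clevrgan_synth
ln -s ../audio_clevrgan_natural/images images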
Please note that there is NO NEED to download the images, as the default training setting generates them with the GAN. They are only needed if GAN generation is not desired (in which case the loading_image flag has to be activated) or for evaluation purposes. The text transcriptions of the audio files are also NOT used, but we make them available in case they are useful.

In order to download the basic data (audio and name_list files) and arrange it in the correct folder structure, execute:

./download_data.sh

Modify DATA_FOLDER in download_data.sh to choose your preferred data folder, and set the folder_dataset flag accordingly when running the code.
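For example, if DATA_FOLDER points to /path/to/project/data, the corresponding entry in the configuration would be (the path itself is illustrative):

folder_dataset: /path/to/project/data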

The name (ID) of each file corresponds to the ID of the noise vector used to generate the image of that image/caption pair. Do NOT modify the seed of the random noise generation.
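To illustrate why the seed matters, here is a hypothetical sketch (not the repository's actual generation code) of how a fixed seed makes the ID-to-noise-vector mapping reproducible:

import torch

torch.manual_seed(0)             # fixed seed: changing it breaks the ID <-> image correspondence
z_all = torch.randn(10000, 512)  # one noise vector per sample; count and dimension are assumptions
sample_id = 42
z = z_all[sample_id]             # the noise vector behind the image/caption pair whose files are named "42"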

Pretrained models are downloaded automatically during the execution of the scripts, if they are not found; no manual action is required. In case there is any problem, they can be found here. By default these models are stored in ./downloaded_files; to change this, set the path_negatives_test, path_model_gan and path_model_segmenter flags accordingly.
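To store them elsewhere, the corresponding configuration entries would look like this (the flag names come from this README; file names and paths are illustrative):

path_negatives_test: /my/models/negatives_test.pth
path_model_gan: /my/models/gan.pth
path_model_segmenter: /my/models/segmenter.pth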

Code structure and running

The code is structured as follows:

  • run.py: the script to execute. Given a configuration file (in the config_files folder), it calls the script named in the file attribute of that configuration file, with all the parameters the configuration contains (see the configuration sketch after this list). Example of execution:
CUDA_VISIBLE_DEVICES=0 python run.py -f train_example.yaml
  • main.py: main script called from run.py. It creates all the actors (trainer, optimizer, dataset) and calls their methods in order to train, evaluate or test the system.
  • trainer.py: contains the main training class (Trainer), including the training loop.
  • dataset.py: implements a Dataset class inheriting from torch.utils.data.Dataset, which loads images, audio and text.
  • models.py: networks used for this project. Classes inheriting from torch.nn.Module.
  • clusterer.py: class that performs the clustering of features, as well as some auxiliary methods.
  • segmenter.py: class that performs segmentations, both using ground-truth labels and using the cluster classes.
  • losses.py: methods used to compute the loss. Loss definitions.
  • experiments.py: methods implementing experiments for testing. See the example in config_files/.
  • utils.py: general useful methods.
  • active_learning.py: methods to generate and use active learning samples.
  • README.md: this file.
  • requirements.txt: list of required Python libraries.
  • download_data.sh: script used to download data. See instructions above.
  • config_files: folder with .yaml configuration files. Take a look at the examples to understand the format. All checkpoints and results include the checkpoint name from the .yaml file, so it is easy to trace which parameters were used.
  • ablate/: contains auxiliary scripts to help ablate units of the GAN.
  • gandissect/: folder containing the netdissect module.
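As referenced in the run.py item above, a minimal configuration sketch of this mechanism might look as follows. Only the file attribute is documented above; the other keys are hypothetical placeholders, so check the examples in config_files/ for the real format:

file: main.py                    # script that run.py will call
checkpoint_name: my_experiment   # hypothetical key: name used for checkpoints and results
folder_dataset: ./data           # hypothetical placement of the data folder flag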

Training can be visualized in a web browser (localhost) using TensorBoard, with the following command (run from the project root folder):

tensorboard --logdir=./results/runs/ --port=6006

Citation

If you want to cite our research, please use:

@InProceedings{Suris_2019_CVPR,
    author = {Suris, Didac and Recasens, Adria and Bau, David and Harwath, David and Glass, James and Torralba, Antonio},
    title = {Learning Words by Drawing Images},
    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2019}
}
