
irakorshunova / bruno


a deep recurrent model for exchangeable data

Home Page: https://arxiv.org/abs/1802.07535

License: MIT License

Python 100.00%
real-nvp gaussian-processes deep-neural-networks exchangeable-structures recurrent-neural-networks omniglot few-shot-learning

bruno's Introduction

BRUNO: A Deep Recurrent Model for Exchangeable Data

This is the official code for reproducing the main results from our NIPS'18 paper:

I. Korshunova, J. Degrave, F. Huszár, Y. Gal, A. Gretton, J. Dambre
BRUNO: A Deep Recurrent Model for Exchangeable Data
arxiv.org/abs/1802.07535

and from our NIPS'18 Bayesian Deep Learning workshop paper:

I. Korshunova, Y. Gal, J. Dambre, A. Gretton
Conditional BRUNO: A Deep Recurrent Process for Exchangeable Labelled Data
bayesiandeeplearning.org/2018/papers/40.pdf

Requirements

The code was used with the following settings (an example install command follows the list):

  • python3
  • tensorflow-gpu==1.7.0
  • scikit-image==0.13.1
  • numpy==1.14.2
  • scipy==1.0.0
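
These pinned versions can be installed with pip, for example (assuming a Python 3 environment; the exact command is only a sketch):

pip3 install tensorflow-gpu==1.7.0 scikit-image==0.13.1 numpy==1.14.2 scipy==1.0.0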

Datasets

Below we list, for every dataset, the files that should be stored in a data/ directory inside the project folder.

MNIST

Download from yann.lecun.com/exdb/mnist/

 data/train-images-idx3-ubyte.gz
 data/train-labels-idx1-ubyte.gz
 data/t10k-images-idx3-ubyte.gz
 data/t10k-labels-idx1-ubyte.gz

Fashion MNIST

Download from github.com/zalandoresearch/fashion-mnist

data/fashion_mnist/train-images-idx3-ubyte.gz
data/fashion_mnist/train-labels-idx1-ubyte.gz
data/fashion_mnist/t10k-images-idx3-ubyte.gz
data/fashion_mnist/t10k-labels-idx1-ubyte.gz

Omniglot

Download and unzip files from github.com/brendenlake/omniglot/tree/master/python

data/images_background
data/images_evaluation

Download the .pkl files from github.com/renmengye/few-shot-ssl-public#omniglot. These are used to make the train-test-validation split.

data/train_vinyals_aug90.pkl
data/test_vinyals_aug90.pkl
data/val_vinyals_aug90.pkl

Run utils.py to preprocess the Omniglot images; this produces the following files (a quick sanity check follows the list):

data/omniglot_x_train.npy
data/omniglot_y_train.npy
data/omniglot_x_test.npy
data/omniglot_y_test.npy
data/omniglot_valid_classes.npy
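
As a quick sanity check, the produced arrays can be loaded and inspected, for example (a minimal sketch, assuming numpy and that utils.py was run from the project root):

import numpy as np

x_train = np.load('data/omniglot_x_train.npy')
y_train = np.load('data/omniglot_y_train.npy')
x_test = np.load('data/omniglot_x_test.npy')
y_test = np.load('data/omniglot_y_test.npy')

# print array shapes and the number of distinct labels per split
print(x_train.shape, y_train.shape, len(np.unique(y_train)))
print(x_test.shape, y_test.shape, len(np.unique(y_test)))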

CIFAR-10

This dataset is downloaded automatically the first time a CIFAR-10 model is run.

data/cifar/cifar-10-batches-py

Training and testing

The config_rnn directory contains a configuration file for every model used in the paper, along with a number of testing scripts. Below are examples of how to train and test Omniglot models.

Training (supports multiple GPUs)

CUDA_VISIBLE_DEVICES=0,1 python3 -m config_rnn.train  --config_name bn2_omniglot_tp --nr_gpu 2

Fine-tuning (single GPU only)

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.train_finetune  --config_name bn2_omniglot_tp_ft_1s_20w

Generating samples

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_samples  --config_name bn2_omniglot_tp_ft_1s_20w

Few-shot classification

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_few_shot_omniglot  --config_name bn2_omniglot_tp --seq_len 2 --batch_size 20
CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_few_shot_omniglot  --config_name bn2_omniglot_tp_ft_1s_20w --seq_len 2 --batch_size 20

Here, set batch_size = k and seq_len = n + 1 to test the model in a k-way, n-shot setting.
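
For example, a 20-way, 5-shot evaluation of the non-fine-tuned model would use batch_size = 20 and seq_len = 6 (the values below are only illustrative):

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_few_shot_omniglot  --config_name bn2_omniglot_tp --seq_len 6 --batch_size 20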

Citation

Please cite our paper when using this code for your research. If you have any questions, please send me an email at [email protected]

@incollection{bruno2018,
    title = {BRUNO: A Deep Recurrent Model for Exchangeable Data},
    author = {Korshunova, Iryna and Degrave, Jonas and Huszar, Ferenc and Gal, Yarin and Gretton, Arthur and Dambre, Joni},
    booktitle = {Advances in Neural Information Processing Systems 31},
    year = {2018}
}

bruno's People

Contributors

christabella, irakorshunova


bruno's Issues

Purpose of inverse softplus and square root of GP/TP variance?

Why is the trainable variable of the variance of the GP/TP kernel (v in Eq. 9 of the BRUNO paper, diagonal values on the kernel matrix) expressed as the inverse softplus of the square root of the actual variance?

tf.constant_initializer(inv_softplus(np.sqrt(var_init)))

And then the actual variance is recovered like this:

self.var = tf.square(tf.nn.softplus(self.var_vbl))

Is it simply to ensure v is non-negative? But then taking the square of self.var_vbl would suffice for that. Is the softplus some kind of trick for stability or convergence (and if so, is it documented anywhere)? Thank you!
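
For reference, here is a minimal standalone numpy sketch of the parameterization quoted above (this is not the repo's code; var_init and inv_softplus are taken from the snippets, and the value 0.5 is made up):

import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))               # softplus(x) = log(1 + exp(x)), always positive

def inv_softplus(y):
    return np.log(np.expm1(y))               # inverse of softplus, defined for y > 0

var_init = 0.5                               # hypothetical initial variance
var_vbl = inv_softplus(np.sqrt(var_init))    # unconstrained value stored in the trainable variable
var = softplus(var_vbl) ** 2                 # squared softplus recovers the variance
print(np.isclose(var, var_init))             # True: the round trip is exact at initialization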

P.S. What does the m1 in m1_shapenet stand for? 🤔

Drawbacks of conditional BRUNO compared to RNN BRUNO (not a bug)

Hello, thank you for open-sourcing the code! I have a few high-level questions about the models:

1. Why is validation only done for RNN and not conditional?

In the original RNN version, validation is done during training:

bruno/config_rnn/train.py, lines 201 to 208 in c631d3d:

if hasattr(config, 'validate_every') and (iteration + 1) % config.validate_every == 0:
    print('\n Validating ...')
    losses = []
    rng = np.random.RandomState(42)
    for _, x_valid_batch in zip(range(0, config.n_valid_batches),
                                config.valid_data_iter.generate(rng)):
        feed_dict = {x_in_eval: x_valid_batch}
        l = sess.run([eval_loss], feed_dict)

Whereas in the conditional version, eval_loss is never used:

# evaluation in case we want to validate
x_in_eval = tf.placeholder(tf.float32, shape=(config.batch_size,) + config.obs_shape)
y_in_eval = tf.placeholder(tf.float32, shape=(config.batch_size,) + config.label_shape)
log_probs = model(x_in_eval, y_in_eval)[0]
eval_loss = config.eval_loss(log_probs) if hasattr(config, 'eval_loss') else config.loss(log_probs)
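
For comparison, a validation step analogous to the RNN snippet might look roughly like this in the conditional trainer (a sketch only: it assumes valid_data_iter yields (x, y) batch pairs and that sess, x_in_eval and y_in_eval are in scope as in the quoted code):

if hasattr(config, 'validate_every') and (iteration + 1) % config.validate_every == 0:
    print('\n Validating ...')
    losses = []
    rng = np.random.RandomState(42)
    for _, (x_valid_batch, y_valid_batch) in zip(range(0, config.n_valid_batches),
                                                 config.valid_data_iter.generate(rng)):
        feed_dict = {x_in_eval: x_valid_batch, y_in_eval: y_valid_batch}
        losses.append(sess.run(eval_loss, feed_dict))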

2. Is conditional BRUNO not maximizing joint (conditional) log likelihood?

BRUNO is clearly maximizing the joint log likelihood:
[image: the BRUNO training objective from the paper]

However, conditional BRUNO does not seem to be maximizing the joint conditional log likelihood... or is it?
[image: the Conditional BRUNO training objective from the workshop paper]

3. "Conditional de Finetti" is not guaranteed

Do you think this is a problem, or not really, since in practice it works nonetheless?

Thank you very much!
