Bayesian Model Selection, the Marginal Likelihood, and Generalization

This repository contains experiments for the paper Bayesian Model Selection, the Marginal Likelihood, and Generalization by Sanae Lotfi, Pavel Izmailov, Gregory Benton, Micah Goldblum, and Andrew Gordon Wilson.

Introduction

In this paper, we discuss the marginal likelihood as a model comparison tool and fundamentally re-evaluate whether it is the right metric for predicting the generalization of trained models and for learning parameters.

  • We discuss the strengths and weaknesses of the marginal likelihood for model selection, hypothesis testing, architecture search and hyperparameter tuning.
  • We show that the marginal likelihood answers a fundamentally different question than "how well will my model generalize on unseen data?"; this distinction is the difference between hypothesis testing and predicting generalization.
  • We show that optimizing the marginal likelihood can lead to overfitting and underfitting in the function space.
  • We revisit the connection between the marginal likelihood and training efficiency, and show that models that train faster do not necessarily generalize better or have higher marginal likelihood.
  • We demonstrate how the Laplace approximation of the marginal likelihood can fail in architecture search and hyperparameter tuning of deep neural networks.
  • We study the conditional marginal likelihood and show that it provides a compelling alternative to the marginal likelihood for neural architecture comparison, deep kernel hyperparameter learning, and transfer learning; the sketch below illustrates how it differs from the full marginal likelihood.
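
As a minimal illustration of the last point (our own sketch, not code from this repository): the log marginal likelihood decomposes into a sum of one-step-ahead log predictive probabilities, and the conditional marginal likelihood simply drops the first m terms, which are dominated by the prior.

import numpy as np

def log_marginal_likelihood(log_preds):
    # log p(D) = sum_i log p(d_i | d_<i): the sum of one-step-ahead
    # log predictive probabilities over the whole dataset.
    return np.sum(log_preds)

def conditional_log_marginal_likelihood(log_preds, m):
    # log p(d_{m+1:n} | d_{1:m}): drop the first m terms, which can
    # penalize flexible models even when they generalize well once
    # they have seen some data.
    return np.sum(log_preds[m:])

# Hypothetical one-step-ahead predictive probabilities of a model
# that starts poorly but learns quickly.
log_preds = np.log([0.1, 0.3, 0.6, 0.8, 0.9, 0.9])
print(log_marginal_likelihood(log_preds))                   # penalized by early terms
print(conditional_log_marginal_likelihood(log_preds, m=2))  # reflects later predictions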

[Figure: Pitfalls of the marginal likelihood]

In this repository we provide code for reproducing results in the paper.

Please cite our paper if you find it helpful in your work:

@article{lotfi2022bayesian,
  title={Bayesian Model Selection, the Marginal Likelihood, and Generalization},
  author={Lotfi, Sanae and Izmailov, Pavel and Benton, Gregory and Goldblum, Micah and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:2202.11678},
  year={2022}
}

Requirements

We use the Laplace package (laplace-torch) for the Laplace experiments, which requires Python 3.8. It can be installed using pip as follows:

pip install laplace-torch
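
A minimal sketch of how the package is typically used to estimate the Laplace log marginal likelihood of a network. The toy model, data, and hyperparameter choices below are illustrative assumptions, not the repository's exact settings:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

# Tiny synthetic classification setup so the sketch is self-contained.
X = torch.randn(128, 10)
y = torch.randint(0, 2, (128,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=32)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

# Fit a Laplace approximation around the current (MAP) weights.
la = Laplace(model, 'classification',
             subset_of_weights='all',    # Laplace over all weights
             hessian_structure='kron')   # Kronecker-factored Hessian
la.fit(train_loader)

# Tune the prior precision by maximizing the Laplace evidence, then
# read off the resulting log marginal likelihood estimate.
la.optimize_prior_precision(method='marglik')
print(la.log_marginal_likelihood())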

Experiments

You can reproduce the GP experiments by running the Jupyter notebooks in ./GP_experiments/.

CIFAR-10 and CIFAR-100

To train ResNet and CNN models and compute their Laplace marginal likelihood on CIFAR-10 and CIFAR-100, as in Section 6 of the paper, navigate to ./Laplace_experiments/ and run the following:

python logml_<dataset>_<models>.py --decay=<weight decay parameter> \
    --prior_structure=<structure of the prior: scalar or layerwise> \
    --hessian_structure=<structure of the Hessian approximation: full, kron, diag> \
    --base_lr=<optimization learning rate> \
    --use_sgdr=<use cosine learning rate scheduler> \
    --optimizehypers=<optimize hyperparameters using the Laplace approximation> \
    --hypers_lr=<learning rate for hyperparameter learning> \
    --batchnorm=<use batch norm instead of fixup> \
    --chk_path=<path to save the checkpoints> \
    --result_folder=<path to save the results>
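
For example, a hypothetical invocation (the script name and flag values are illustrative, not prescribed by the repository; see ./Laplace_experiments/ for the exact file names):

python logml_cifar10_resnet.py --decay=1e-4 \
    --prior_structure=layerwise \
    --hessian_structure=kron \
    --base_lr=0.1 \
    --use_sgdr=True \
    --optimizehypers=False \
    --hypers_lr=0.01 \
    --batchnorm=True \
    --chk_path=./checkpoints \
    --result_folder=./results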

The same code can be used to train the models on 80% of the data and then compute the conditional marginal likelihood, as follows:

python logcml_<dataset>_<models>.py --prior_prec_init=<weight decay parameter> \
    --prior_structure=<structure of the prior: scalar or layerwise> \
    --hessian_structure=<structure of the Hessian approximation: full, kron, diag> \
    --base_lr=<optimization learning rate> \
    --bma_nsamples=<number of posterior samples to average over> \
    --data_ratio=<ratio of the data to condition on> \
    --max_iters=<number of iterations to optimize the rescaling parameter of the Hessian> \
    --partialtrain_chk_path=<path to checkpoints of models trained on a fraction of the data> \
    --fulltrain_chk_path=<path to checkpoints of models trained on the full data> \
    --result_folder=<path to save the results>
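
The conditional marginal likelihood of the held-out fraction is estimated with a Bayesian model average over posterior samples (hence --bma_nsamples). A minimal sketch of that estimator, with hypothetical predictive probabilities standing in for real model outputs:

import numpy as np

def conditional_lml_bma(probs):
    # probs[s, i] = p(y_i | x_i, theta_s) for posterior sample s and
    # held-out point i, with theta_s drawn from the posterior given
    # the conditioning fraction of the data. Average over samples
    # inside the log:
    #   log p(D_heldout | D_cond) ~= sum_i log mean_s probs[s, i]
    return np.sum(np.log(np.mean(probs, axis=0)))

# Hypothetical probabilities: 20 posterior samples, 100 held-out points.
rng = np.random.default_rng(0)
probs = rng.uniform(0.4, 0.95, size=(20, 100))
print(conditional_lml_bma(probs))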

[Figure: Neural hyperparameter optimization for CIFAR-100]

Deep kernel learning

To reproduce results for the deep kernel learning experiments, navigate to ./DKL_experiments/ and run the following:

python exact_runner.py --m=<number of data points to condition on> \
    --losstype=<type of the loss> \
    --dataset=<choice of the dataset> \
    --ntrain=<number of training points> \
    --ntrial=<number of trials>
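
In the (deep) kernel setting, conditioning on the first m points has a clean closed form: the conditional log marginal likelihood is the difference of two joint log marginal likelihoods. A self-contained numpy sketch with an illustrative RBF kernel (the repository's deep kernels and settings will differ):

import numpy as np

def gp_log_ml(X, y, lengthscale=1.0, noise=0.1):
    # Exact GP log marginal likelihood with an RBF kernel:
    # log p(y | X) = -0.5 * (y^T K^-1 y + log|K| + n*log(2*pi)).
    sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    K = np.exp(-0.5 * sq / lengthscale**2) + noise**2 * np.eye(len(X))
    _, logdet = np.linalg.slogdet(K)
    return -0.5 * (y @ np.linalg.solve(K, y) + logdet + len(X) * np.log(2 * np.pi))

def gp_conditional_log_ml(X, y, m, **kw):
    # log p(y_{m+1:n} | y_{1:m}) = log p(y_{1:n}) - log p(y_{1:m}):
    # the joint evidence minus the evidence of the conditioning block.
    return gp_log_ml(X, y, **kw) - gp_log_ml(X[:m], y[:m], **kw)

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=60)
print(gp_conditional_log_ml(X, y, m=20))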

[Figure: Deep kernel learning experiments]
