Frank-Wolfe Optimization method for NNs with minima sharpness analysis.
Mariam Hakobyan, Sergei Volodin. Swiss Federal Institute of Technology in Lausanne (EPFL)
We train small fully-connected networks on MNIST using Frank-Wolfe, Adam and SGD, and measure the sharpness of the resulting minima via Hessian eigenvalues.
Mini-project for the course Optimization for Machine Learning (CS-439) at EPFL, 2019.
We consider SGD, Adam and Frank-Wolfe, with and without averaging. See our report for more details.
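Frank-Wolfe replaces the projection step of projected SGD with a linear minimization over the feasible set, so every iterate stays a convex combination of feasible points. The following is a minimal NumPy sketch of a stochastic Frank-Wolfe loop on a toy problem, not the repository's TensorFlow StochasticFrankWolfe(); the L1-ball constraint, its radius, the 2/(t+2) step size and all function names are assumptions made for illustration.

```python
import numpy as np


def lmo_l1_ball(grad, radius):
    """Linear minimization oracle for an L1 ball (assumed feasible set).

    Returns argmin_{||s||_1 <= radius} <grad, s>, i.e. the signed vertex
    -radius * sign(grad[i]) * e_i with i = argmax |grad[i]|.
    """
    s = np.zeros_like(grad)
    i = np.argmax(np.abs(grad))
    s[i] = -radius * np.sign(grad[i])
    return s


def stochastic_frank_wolfe(grad_fn, x0, radius, n_steps, rng):
    """Stochastic Frank-Wolfe with the classic 2 / (t + 2) step size."""
    x = x0.copy()
    for t in range(n_steps):
        g = grad_fn(x, rng)                  # minibatch gradient estimate
        s = lmo_l1_ball(g, radius)           # best vertex for the linearized loss
        gamma = 2.0 / (t + 2.0)
        x = (1.0 - gamma) * x + gamma * s    # convex combination: stays feasible
    return x


# Hypothetical toy problem: sparse least squares with minibatch gradients.
rng = np.random.default_rng(0)
A = rng.normal(size=(200, 10))
x_true = np.zeros(10)
x_true[:2] = [1.0, -0.5]
b = A @ x_true


def minibatch_grad(x, rng, batch=32):
    idx = rng.integers(0, A.shape[0], size=batch)
    A_b = A[idx]
    return A_b.T @ (A_b @ x - b[idx]) / batch


def loss(x):
    return 0.5 * np.mean((A @ x - b) ** 2)


x = stochastic_frank_wolfe(minibatch_grad, np.zeros(10), radius=2.0, n_steps=500, rng=rng)
print("loss at start:", loss(np.zeros(10)), "loss after SFW:", loss(x))
```

Because each update mixes the current iterate with a single vertex, no projection is needed; with an L1 ball the vertices are 1-sparse, so after t steps from zero the iterate has at most t nonzero entries.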
Tested on Ubuntu 16.04.5 LTS with 12 CPUs, 60 GB of RAM and 2x NVIDIA GeForce 1080 GPUs.
- Install Anaconda (Python 3.7 option)
- Create and activate an environment
- Clone/download:
git clone https://github.com/sergeivolodin/OptMLProject.git; cd OptMLProject
- Install requirements:
pip install -r requirements.txt
- Install tensorflow-gpu:
conda install -c anaconda tensorflow-gpu
- Run all settings by calling
run_all.sh
- It will produce output/*.output files and output/figures/*.pdf files, and will write run information to run_*.txt
- experiment.py is the main file, containing one experiment (loading the optimizer, training, computing the Hessian, computing metrics)
- helpers.py contains a definition of a fully-connected network, FCModelConcat(), with its variables stored as a single tensor (needed to compute the Hessian). In addition, it contains our own implementation of the Stochastic Frank-Wolfe method, StochasticFrankWolfe(). This file also contains helper functions required in the experiments, such as training code, dataset loaders and Hessian calculation
- create_run.py creates the .sh script from config.py
- analyze_run.py analyzes the output produced by training (the .sh script) and writes output to run_*.txt and figures to output/figures
- create_analyze_runs_helpers.py is the helper file for the scripts above, containing code to present the results nicely
- output/*.sh files consist of many lines of the form python ../experiment.py --param1 v1 --param2 v2 ..., running at most 4 processes in total (2 per GPU)
- output/*.output files contain the outputs of experiment.py (one file per run)
- output/figures contains the generated figures
- run_setting.sh runs a particular setting (create the .sh script, run it, analyze the output) and writes data to a file
- run_all.sh runs all settings
- Other files are not used
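Storing all weights as one flat tensor, as FCModelConcat() does, makes the Hessian a single square matrix over that vector, whose eigenvalues (for example, the largest one) can serve as a sharpness measure. The sketch below illustrates this in NumPy with a hypothetical two-layer tanh network; it approximates the Hessian by central finite differences of an analytic gradient, a swapped-in technique rather than the TensorFlow Hessian computation in helpers.py. Network sizes, data and function names are assumptions.

```python
import numpy as np

D_IN, HIDDEN = 3, 4  # hypothetical toy sizes


def unpack(theta):
    """View the single flat parameter vector as the two weight matrices."""
    w1 = theta[: D_IN * HIDDEN].reshape(D_IN, HIDDEN)
    w2 = theta[D_IN * HIDDEN:]
    return w1, w2


def loss_and_grad(theta, X, y):
    """Mean squared error of a 2-layer tanh network and its gradient w.r.t. theta."""
    w1, w2 = unpack(theta)
    a = np.tanh(X @ w1)                     # hidden activations, shape (n, HIDDEN)
    res = a @ w2 - y                        # residuals, shape (n,)
    loss = 0.5 * np.mean(res ** 2)
    r = res / len(y)                        # d loss / d prediction
    g_w2 = a.T @ r
    dz = np.outer(r, w2) * (1.0 - a ** 2)   # backprop through tanh
    g_w1 = X.T @ dz
    return loss, np.concatenate([g_w1.ravel(), g_w2])


def hessian_fd(grad_fn, theta, eps=1e-5):
    """Hessian over the flat vector via central differences of the gradient."""
    n = theta.size
    H = np.empty((n, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = eps
        H[:, j] = (grad_fn(theta + e) - grad_fn(theta - e)) / (2.0 * eps)
    return 0.5 * (H + H.T)                  # symmetrize away round-off


def sharpness(H):
    """Largest Hessian eigenvalue (eigvalsh returns them in ascending order)."""
    return np.linalg.eigvalsh(H)[-1]


rng = np.random.default_rng(0)
X = rng.normal(size=(50, D_IN))
y = np.sin(X[:, 0])
theta = 0.5 * rng.normal(size=D_IN * HIDDEN + HIDDEN)

H = hessian_fd(lambda th: loss_and_grad(th, X, y)[1], theta)
print("Hessian shape:", H.shape, "sharpness:", sharpness(H))
```

The full Hessian has one row per parameter, which is one reason to keep the networks small; for larger models, Hessian-vector products combined with power iteration are the usual alternative.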