
swa's Introduction

Stochastic Weight Averaging (SWA)

This repository contains a PyTorch implementation of the Stochastic Weight Averaging (SWA) training method for DNNs from the paper

Averaging Weights Leads to Wider Optima and Better Generalization

by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson.

Note: as of August 2020, SWA is a core optimizer in the PyTorch library. Anyone with PyTorch can use it immediately, without an external repo, as easily as SGD or Adam. Please see this blog post introducing the native PyTorch implementation with examples.
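Since the note above points to the native PyTorch implementation, here is a minimal sketch using `torch.optim.swa_utils` (`AveragedModel`, `SWALR`, `update_bn`; to our knowledge available since PyTorch 1.6). The toy model, data, epoch counts and learning rates are illustrative assumptions and are not the settings used in this repo.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

torch.manual_seed(0)

# Toy regression data split into four mini-batches (illustrative only).
X = torch.randn(64, 4)
y = X.sum(dim=1, keepdim=True)
loader = [(X[i:i + 16], y[i:i + 16]) for i in range(0, 64, 16)]

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
swa_model = AveragedModel(model)               # holds the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # anneals the LR to a constant SWA value
loss_fn = nn.MSELoss()

epochs, swa_start = 20, 10
for epoch in range(epochs):
    for xb, yb in loader:
        optimizer.zero_grad()
        loss_fn(model(xb), yb).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # fold the current weights into the average
        swa_scheduler.step()

# The averaged weights need fresh BatchNorm statistics: one pass over the training data.
update_bn(loader, swa_model)
```

`update_bn` takes the first element of each `(input, target)` batch, so the same loader works for training and for the BatchNorm pass.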

Introduction

SWA is a simple DNN training method that can be used as a drop-in replacement for SGD with improved generalization, faster convergence, and essentially no overhead. The key idea of SWA is to average multiple samples produced by SGD with a modified learning rate schedule. We use a constant or cyclical learning rate schedule that causes SGD to explore the set of points in the weight space corresponding to high-performing networks. We observe that SWA converges more quickly than SGD, and to wider optima that provide higher test accuracy.
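The cyclical schedule mentioned above can be sketched as a function of the iteration index. The formula is our reading of the paper: within each cycle of c iterations, the rate decays linearly from alpha1 to alpha2, and the average is updated when the rate reaches alpha2. The function names and values here are illustrative assumptions. This repo itself implements the constant schedule.

```python
def cyclic_lr(i, alpha1, alpha2, c):
    """Learning rate at iteration i (1-indexed) for a cycle length of c iterations.

    Within each cycle the rate decreases linearly from about alpha1 to exactly alpha2;
    c = 1 degenerates to the constant rate alpha2.
    """
    t = ((i - 1) % c + 1) / c  # fraction of the current cycle completed, in (0, 1]
    return (1 - t) * alpha1 + t * alpha2


def is_averaging_step(i, c):
    """SWA folds the current weights into the average at the end of each cycle."""
    return i % c == 0


# Two cycles of length 4 with illustrative rates.
lrs = [cyclic_lr(i, alpha1=0.05, alpha2=0.0005, c=4) for i in range(1, 9)]
```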

In this repo we implement the constant learning rate schedule that we found to be most practical on CIFAR datasets.

Please cite our work if you find this approach useful in your research:

@article{izmailov2018averaging,
  title={Averaging Weights Leads to Wider Optima and Better Generalization},
  author={Izmailov, Pavel and Podoprikhin, Dmitrii and Garipov, Timur and Vetrov, Dmitry and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:1803.05407},
  year={2018}
}

Dependencies

Usage

The code in this repository implements both SWA and conventional SGD training, with examples on the CIFAR-10 and CIFAR-100 datasets.

To run SWA use the following command:

python3 train.py --dir=<DIR> \
                 --dataset=<DATASET> \
                 --data_path=<PATH> \
                 --model=<MODEL> \
                 --epochs=<EPOCHS> \
                 --lr_init=<LR_INIT> \
                 --wd=<WD> \
                 --swa \
                 --swa_start=<SWA_START> \
                 --swa_lr=<SWA_LR>

Parameters:

  • DIR — path to training directory where checkpoints will be stored
  • DATASET — dataset name [CIFAR10/CIFAR100] (default: CIFAR10)
  • PATH — path to the data directory
  • MODEL — DNN model name:
    • VGG16/VGG16BN/VGG19/VGG19BN
    • PreResNet110/PreResNet164
    • WideResNet28x10
  • EPOCHS — number of training epochs (default: 200)
  • LR_INIT — initial learning rate (default: 0.1)
  • WD — weight decay (default: 1e-4)
  • SWA_START — the epoch after which SWA starts averaging models (default: 161)
  • SWA_LR — SWA learning rate (default: 0.05)
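To make the interplay of LR_INIT, SWA_START and SWA_LR concrete, here is one plausible per-epoch schedule consistent with the parameters above. The rate is held at LR_INIT, then decays linearly, then stays at SWA_LR once averaging begins. We have not verified this against train.py, so treat the following as assumptions: the function name `learning_rate`, the 50%/90% breakpoints, and the 1% final ratio for plain SGD.

```python
def learning_rate(epoch, lr_init=0.1, epochs=200, swa=True, swa_start=161, swa_lr=0.05):
    """Hypothetical schedule: hold at lr_init, decay linearly, then stay flat."""
    # Progress is measured relative to SWA_START for SWA runs, and to EPOCHS otherwise.
    t = epoch / (swa_start if swa else epochs)
    # SWA ends at SWA_LR; plain SGD is assumed to end at 1% of LR_INIT.
    lr_ratio = swa_lr / lr_init if swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio
    return lr_init * factor
```

Under this sketch, every epoch from about 0.9 * SWA_START onward runs at SWA_LR. SWA then averages the models produced during that constant-rate phase.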

To run conventional SGD training use the following command:

python3 train.py --dir=<DIR> \
                 --dataset=<DATASET> \
                 --data_path=<PATH> \
                 --model=<MODEL> \
                 --epochs=<EPOCHS> \
                 --lr_init=<LR_INIT> \
                 --wd=<WD> 

Examples

To reproduce the results from the paper, run the following commands (we use the same parameters for CIFAR-10 and CIFAR-100, except for PreResNet):

#VGG16
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=VGG16 --epochs=200 --lr_init=0.05 --wd=5e-4 # SGD
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=VGG16 --epochs=300 --lr_init=0.05 --wd=5e-4 --swa --swa_start=161 --swa_lr=0.01 # SWA 1.5 Budgets

#PreResNet
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH>  --model=[PreResNet110 or PreResNet164] --epochs=150  --lr_init=0.1 --wd=3e-4 # SGD
#CIFAR100
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH>  --model=[PreResNet110 or PreResNet164] --epochs=225 --lr_init=0.1 --wd=3e-4 --swa --swa_start=126 --swa_lr=0.05 # SWA 1.5 Budgets
#CIFAR10
python3 train.py --dir=<DIR> --dataset=CIFAR10 --data_path=<PATH>  --model=[PreResNet110 or PreResNet164] --epochs=225 --lr_init=0.1 --wd=3e-4 --swa --swa_start=126 --swa_lr=0.01 # SWA 1.5 Budgets

#WideResNet28x10 
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=WideResNet28x10 --epochs=200 --lr_init=0.1 --wd=5e-4 # SGD
python3 train.py --dir=<DIR> --dataset=CIFAR100 --data_path=<PATH> --model=WideResNet28x10 --epochs=300 --lr_init=0.1 --wd=5e-4 --swa --swa_start=161 --swa_lr=0.05 # SWA 1.5 Budgets

Results

CIFAR-100

Test accuracy (%) of SGD and SWA on CIFAR-100 for different training budgets. For each model the Budget is defined as the number of epochs required to train the model with the conventional SGD procedure.

DNN (Budget)          | SGD          | SWA 1 Budget | SWA 1.25 Budgets | SWA 1.5 Budgets
VGG16 (200)           | 72.55 ± 0.10 | 73.91 ± 0.12 | 74.17 ± 0.15     | 74.27 ± 0.25
PreResNet110 (150)    | 76.77 ± 0.38 | 78.75 ± 0.16 | 78.91 ± 0.29     | 79.10 ± 0.21
PreResNet164 (150)    | 78.49 ± 0.36 | 79.77 ± 0.17 | 80.18 ± 0.23     | 80.35 ± 0.16
WideResNet28x10 (200) | 80.82 ± 0.23 | 81.46 ± 0.23 | 81.91 ± 0.27     | 82.15 ± 0.27
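The size of the improvement is easier to read as a difference. The sketch below computes the gain of SWA with 1.5 Budgets over SGD, using the CIFAR-100 mean accuracies from the table above. It compares means only and ignores the reported standard deviations.

```python
# (SGD, SWA 1.5 Budgets) mean test accuracy (%) on CIFAR-100, copied from the table above.
cifar100 = {
    "VGG16": (72.55, 74.27),
    "PreResNet110": (76.77, 79.10),
    "PreResNet164": (78.49, 80.35),
    "WideResNet28x10": (80.82, 82.15),
}


def swa_gains(results):
    """Absolute accuracy gain of SWA over SGD, in percentage points."""
    return {name: round(swa - sgd, 2) for name, (sgd, swa) in results.items()}


gains = swa_gains(cifar100)
```

By this measure the gains range from 1.33 points (WideResNet28x10) to 2.33 points (PreResNet110).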

Below we show the convergence plot for SWA and SGD with PreResNet164 on CIFAR-100 and the corresponding learning rates. The dashed line illustrates the accuracy of individual models averaged by SWA.

CIFAR-10

Test accuracy (%) of SGD and SWA on CIFAR-10 for different training budgets.

DNN (Budget)          | SGD          | SWA 1 Budget | SWA 1.25 Budgets | SWA 1.5 Budgets
VGG16 (200)           | 93.25 ± 0.16 | 93.59 ± 0.16 | 93.70 ± 0.22     | 93.64 ± 0.18
PreResNet110 (150)    | 95.03 ± 0.05 | 95.51 ± 0.10 | 95.65 ± 0.03     | 95.82 ± 0.03
PreResNet164 (150)    | 95.28 ± 0.10 | 95.56 ± 0.11 | 95.77 ± 0.04     | 95.83 ± 0.03
WideResNet28x10 (200) | 96.18 ± 0.11 | 96.45 ± 0.11 | 96.64 ± 0.08     | 96.79 ± 0.05

Other Implementations

References

Provided model implementations were adapted from

swa's People

Contributors

andrewgordonwilson, izmailovpavel, timgaripov


swa's Issues

Figure 3 plot

Hi, can you update the code for plotting Figure 3 in the paper?

CUDA out of memory

When training reaches the start_epoch of SWA, I always get the error: RuntimeError: CUDA out of memory. Can anyone help me?

release code for training imagenet

Hi, I used your code to try to reproduce the ResNet-50 SWA result on ImageNet.
Before trying ImageNet, I reproduced the CIFAR-100 result (from 81.54% to 83.3%).
Then I trained on ImageNet, but found it difficult to match the 76.97% of the 10-epoch fine-tuning stage; I only get 71.4580% with SWA in 10 epochs.
In your paper you use a cyclical learning rate schedule. Is that the main reason you get 76.97%?
By the way, with 10 epochs of fine-tuning at lr=0.0001, the model without SWA also reaches 76.78%.
It seems SWA does not give a surprising improvement compared to fine-tuning alone.
Could you release the code for fine-tuning on ImageNet?

About preresnet results

Hi @timgaripov, @izmailovpavel,

Thanks for your nice work and paper.

According to the issues
bearpaw/pytorch-classification#6
and bearpaw/pytorch-classification#9,

preresnet-110 does not really have 110 layers,
because it uses bottleneck blocks: (110-2)//6 = 18 blocks per stage, and 18*9+2 = 164 layers.

        n = (depth - 2) // 6
        block = Bottleneck if depth >= 44 else BasicBlock

the correct way is:

        if depth >= 44:
            assert (depth - 2) % 9 == 0
            n = (depth - 2) // 9
            block = Bottleneck

        else:
            assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
            n = (depth - 2) // 6
            block = BasicBlock
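The depth arithmetic in this issue can be checked with a small helper. It mirrors the two rules quoted above, with block types represented as strings rather than the actual `Bottleneck`/`BasicBlock` classes. The function name `preresnet_layout` is hypothetical.

```python
def preresnet_layout(depth, bottleneck_fix=True):
    """Return (blocks per stage, block type, actual number of weight layers)."""
    if not bottleneck_fix:
        # The rule quoted first in the issue: always divide by 6.
        n = (depth - 2) // 6
        block = "Bottleneck" if depth >= 44 else "BasicBlock"
    elif depth >= 44:
        assert (depth - 2) % 9 == 0, "bottleneck depth should be 9n+2"
        n = (depth - 2) // 9
        block = "Bottleneck"
    else:
        assert (depth - 2) % 6 == 0, "depth should be 6n+2"
        n = (depth - 2) // 6
        block = "BasicBlock"
    # 3 stages of n blocks; a Bottleneck block has 3 conv layers and a BasicBlock 2,
    # plus the first conv layer and the final classifier.
    convs_per_block = 3 if block == "Bottleneck" else 2
    return n, block, 3 * n * convs_per_block + 2


print(preresnet_layout(110, bottleneck_fix=False))  # (18, 'Bottleneck', 164)
print(preresnet_layout(110))                        # (12, 'Bottleneck', 110)
```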

so the PreResNet results do not seem correct.
Looking forward to new experiments with SWA!

thanks a lot!

How to reproduce result of Fig.1 in the paper?

How can I reproduce the result of Fig. 1, which illustrates the loss (test error) as a function of the network weights in a two-dimensional subspace? Does it use PCA to get the two main directions of the weights?
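One common way to build such a two-dimensional slice, which we assume here rather than take from this thread, is to pass a plane through three weight vectors and orthonormalize the two directions (Gram-Schmidt). In this sketch, flat NumPy arrays stand in for the flattened network weights. The quadratic `toy_loss` is a hypothetical stand-in for loading a point into the network and computing its test error.

```python
import numpy as np


def plane_basis(w1, w2, w3):
    """Orthonormal basis (u_hat, v_hat) of the plane through w1, w2 and w3."""
    u = w2 - w1
    v = (w3 - w1) - np.dot(w3 - w1, u) / np.dot(u, u) * u  # remove the component along u
    return u / np.linalg.norm(u), v / np.linalg.norm(v)


def plane_coords(w, w1, u_hat, v_hat):
    """2-D coordinates of w (assumed to lie in the plane) relative to the origin w1."""
    d = w - w1
    return float(np.dot(d, u_hat)), float(np.dot(d, v_hat))


def toy_loss(w):
    # Hypothetical stand-in for "load w into the network and compute test error".
    return float(np.sum((w - 1.0) ** 2))


w1 = np.array([0.0, 0.0, 0.0])
w2 = np.array([2.0, 0.0, 0.0])
w3 = np.array([1.0, 3.0, 0.0])
u_hat, v_hat = plane_basis(w1, w2, w3)

# Evaluate the loss on an 11 x 11 grid of points w1 + x * u_hat + y * v_hat.
xs = ys = np.linspace(-1.0, 4.0, 11)
grid = np.array([[toy_loss(w1 + x * u_hat + y * v_hat) for x in xs] for y in ys])
```

The grid can then be drawn as a contour plot, with the three weight vectors marked at their `plane_coords`.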

Unable to re-produce SGD numbers

Hi,
Thanks for SWA, neat idea. I just checked out the code and wanted to replicate the SWA experiment as mentioned here. In Table 1 of the paper, the SGD numbers you report are much higher than what I get when I run the code. In fact, the SWA number I get is better. I also tried a minor change, running the "bn_update" operation for SGD as well before evaluation, and it improves the performance; I am not sure why. I am attaching logs from training the VGG16BN model on the CIFAR-10 dataset. I would be thankful for your insights.
VGG16BN Model: https://github.com/timgaripov/swa/files/4008439/cifar10_VGG16BNModel_swaLogs.txt

VGG16BN Model with BN update for SGD

What's the best way to get numbers similar or identical to those in the paper? I used the same command as in the Git repo:
python3 train.py --dir=/home/swa/swa_cifar10_VGG16BNModel/ --dataset=CIFAR10 --data_path=/home/swa_gaussian/cifar10_data/ --model=VGG16BN --epochs=300 --lr_init=0.05 --wd=5e-4 --swa --swa_start=161 --swa_lr=0.01 --save_freq=50 --eval_freq=10 > cifar10_VGG16BNModel_swaLogs

Thank you once again.

SWA with Torchbearer

Hi! I am trying to implement SWA with Torchbearer but I'm having issues with the state etc. Could anybody point me in the direction of some resources to better understand how to use SWA with Torchbearer, particularly the Trial class?

Thank you in advance!
Anabetsy

Incorrect model parameter update?

In the code you update SWA weights via

utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))

Then in the moving_average function your math is implemented

def moving_average(net1, net2, alpha=1):
    for param1, param2 in zip(net1.parameters(), net2.parameters()):
        param1.data *= (1.0 - alpha)
        param1.data += param2.data * alpha

Is this updating correctly? I assume this part of the code performs the "Update average" step of Algorithm 1 in the paper, i.e. updating the stochastic weight average. But the math feels a bit off. Wouldn't the equation look something like

param1.data = (param1.data * ((1.0 / alpha) - 1) + param2.data) * alpha

to be more in line with the stochastic weight average update equation?
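The two forms are algebraically identical, since (p * (1/alpha - 1) + q) * alpha = p * (1 - alpha) + q * alpha. The pure-Python check below uses plain lists as stand-ins for parameter tensors (an assumption made for illustration). It also shows that calling the update with alpha = 1/(swa_n + 1) for swa_n = 0, 1, 2, ... gives the arithmetic mean of the snapshots, up to floating-point rounding. The snapshot values are hypothetical.

```python
def moving_average(avg, new, alpha):
    """The update quoted above, on plain lists: avg <- (1 - alpha) * avg + alpha * new."""
    return [(1.0 - alpha) * a + alpha * w for a, w in zip(avg, new)]


def proposed_update(avg, new, alpha):
    """The form suggested in this issue: (avg * (1/alpha - 1) + new) * alpha."""
    return [(a * ((1.0 / alpha) - 1) + w) * alpha for a, w in zip(avg, new)]


# Hypothetical weight snapshots (two "parameters" each), one per averaging step.
snapshots = [[1.0, -2.0], [3.0, 0.0], [5.0, 4.0], [-1.0, 2.0]]

avg_repo = [0.0, 0.0]   # initial contents are irrelevant: the first step has alpha = 1
avg_issue = [0.0, 0.0]
for swa_n, w in enumerate(snapshots):
    alpha = 1.0 / (swa_n + 1)
    avg_repo = moving_average(avg_repo, w, alpha)
    avg_issue = proposed_update(avg_issue, w, alpha)
```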

Aren't the weights learned from each minibatch supposed to be averaged?

Hello, thank you for the great work.

I am trying to use SWA in my work, so I read the code to check that my understanding of the paper is right.

However, I think the weights are averaged differently.

In the paper (and also in other_example), it seems that the weights learned from each minibatch are ensembled at the end of an epoch.

However, in this code, it seems that the weights learned at each epoch are ensembled.

So I am confused.

BrokenPipeError: [Errno 32] Broken pipe

Preparing directory '/'
Using model PreResNet110
Loading dataset CIFAR10 from '/'
Files already downloaded and verified
Files already downloaded and verified
Preparing model
SGD training
Preparing directory '/'
Using model PreResNet110
Loading dataset CIFAR10 from '/'
Files already downloaded and verified
Files already downloaded and verified
Preparing model
SGD training
Traceback (most recent call last):
File "", line 1, in
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
exitcode = _main(fd)
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\spawn.py", line 114, in _main
prepare(preparation_data)
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\spawn.py", line 225, in prepare
_fixup_main_from_path(data['init_main_from_path'])
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\spawn.py", line 277, in _fixup_main_from_path
run_name="mp_main")
File "C:\Users\Videet\Anaconda3\lib\runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "C:\Users\Videet\Anaconda3\lib\runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "C:\Users\Videet\Anaconda3\lib\runpy.py", line 85, in _run_code
Traceback (most recent call last):
exec(code, run_globals)
File "train.py", line 148, in
File "C:\Users\Videet\SWA\swa-master\train.py", line 148, in
train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
File "C:\Users\Videet\SWA\swa-master\utils.py", line 26, in train_epoch
File "C:\Users\Videet\SWA\swa-master\utils.py", line 26, in train_epoch
for i, (input, target) in enumerate(loader):
for i, (input, target) in enumerate(loader):
File "C:\Users\Videet\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 451, in iter
File "C:\Users\Videet\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 451, in iter
return _DataLoaderIter(self)
File "C:\Users\Videet\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 239, in init
return _DataLoaderIter(self)
w.start()
File "C:\Users\Videet\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 239, in init
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\process.py", line 105, in start
w.start()
self._popen = self._Popen(self)
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\process.py", line 105, in start
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
self._popen = self._Popen(self)
return _default_context.get_context().Process._Popen(process_obj)
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
return Popen(process_obj)
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 65, in init
return Popen(process_obj)
reduction.dump(process_obj, to_child)
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 33, in init
ForkingPickler(file, protocol).dump(obj)
prep_data = spawn.get_preparation_data(process_obj._name)
BrokenPipeError: [Errno 32] Broken pipe
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\spawn.py", line 143, in get_preparation_data
_check_not_importing_main()
File "C:\Users\Videet\Anaconda3\lib\multiprocessing\spawn.py", line 136, in _check_not_importing_main
is not going to be frozen to produce an executable.''')
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.
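The error message itself names the fix. On Windows, multiprocessing uses the spawn start method, and so does a DataLoader with worker processes. Spawn re-imports the main script in each worker, so the top-level training code must sit behind an `if __name__ == '__main__':` guard. The minimal stdlib-only sketch below shows the idiom; `square` and `main` are placeholders for the training script's own code.

```python
import multiprocessing as mp


def square(x):
    # Work done in child processes must be defined at module level so it can be pickled.
    return x * x


def main():
    # Everything that starts processes (a Pool here, DataLoader workers in a training
    # script) runs only when the file is executed directly, not when a child re-imports it.
    with mp.Pool(processes=2) as pool:
        return pool.map(square, range(5))


if __name__ == "__main__":
    print(main())  # [0, 1, 4, 9, 16]
```

Alternatively, creating the DataLoader with num_workers=0 avoids worker processes altogether, at the cost of slower data loading.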

About calculating parameters of BN

Hi, thanks for your great work. In your paper, you say that after training the model we should run one additional pass over the training data to compute the running mean and standard deviation. I don't understand why we should do that. What about the gamma and beta needed in BN?

performance drop due to batch norm params recalculation

Thanks for the great work!
I have some cases where performance drops when using SWA compared to a single model.
In this case the loss goes from 0.67 for the model to 0.72 for the exact SWA copy.

To debug the problem, I ran SWA for only one epoch and compared the model with the SWA copy.
All the parameters are the same except the batch norm running_mean and running_var, and it seems that the deeper you go in the network, the bigger the divergence.

Do you have any tips on how to recalculate the batch norm statistics more accurately? Or should I just pass the training set through the SWA version multiple times so they converge to the original model's values?

This is the code I use to compare m (the model's state dict) with swa (the SWA copy's state dict):

for key in m.keys():
...     print(key,(swa[key]-m[key]).sum())
...
in_c.0.weight tensor(0., device='cuda:0')
in_c.1.weight tensor(0., device='cuda:0')
in_c.1.bias tensor(0., device='cuda:0')
in_c.1.running_mean tensor(-0.0847, device='cuda:0')
in_c.1.running_var tensor(4.5671, device='cuda:0')
in_c.1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block1.conv1.weight tensor(0., device='cuda:0')
stage1.block1.bn1.weight tensor(0., device='cuda:0')
stage1.block1.bn1.bias tensor(0., device='cuda:0')
stage1.block1.bn1.running_mean tensor(-0.0932, device='cuda:0')
stage1.block1.bn1.running_var tensor(-1.3953, device='cuda:0')
stage1.block1.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block1.conv2.weight tensor(0., device='cuda:0')
stage1.block1.bn2.weight tensor(0., device='cuda:0')
stage1.block1.bn2.bias tensor(0., device='cuda:0')
stage1.block1.bn2.running_mean tensor(0.0153, device='cuda:0')
stage1.block1.bn2.running_var tensor(0.4095, device='cuda:0')
stage1.block1.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage1.block2.conv1.weight tensor(0., device='cuda:0')
stage1.block2.bn1.weight tensor(0., device='cuda:0')
stage1.block2.bn1.bias tensor(0., device='cuda:0')
stage1.block2.bn1.running_mean tensor(-0.1347, device='cuda:0')
stage1.block2.bn1.running_var tensor(0.1461, device='cuda:0')
stage1.block2.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block2.conv2.weight tensor(0., device='cuda:0')
stage1.block2.bn2.weight tensor(0., device='cuda:0')
stage1.block2.bn2.bias tensor(0., device='cuda:0')
stage1.block2.bn2.running_mean tensor(0.0590, device='cuda:0')
stage1.block2.bn2.running_var tensor(-0.0815, device='cuda:0')
stage1.block2.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage1.block3.conv1.weight tensor(0., device='cuda:0')
stage1.block3.bn1.weight tensor(0., device='cuda:0')
stage1.block3.bn1.bias tensor(0., device='cuda:0')
stage1.block3.bn1.running_mean tensor(-0.3279, device='cuda:0')
stage1.block3.bn1.running_var tensor(1.1645, device='cuda:0')
stage1.block3.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block3.conv2.weight tensor(0., device='cuda:0')
stage1.block3.bn2.weight tensor(0., device='cuda:0')
stage1.block3.bn2.bias tensor(0., device='cuda:0')
stage1.block3.bn2.running_mean tensor(0.0153, device='cuda:0')
stage1.block3.bn2.running_var tensor(-0.0028, device='cuda:0')
stage1.block3.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage1.block4.conv1.weight tensor(0., device='cuda:0')
stage1.block4.bn1.weight tensor(0., device='cuda:0')
stage1.block4.bn1.bias tensor(0., device='cuda:0')
stage1.block4.bn1.running_mean tensor(-0.0740, device='cuda:0')
stage1.block4.bn1.running_var tensor(1.9955, device='cuda:0')
stage1.block4.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block4.conv2.weight tensor(0., device='cuda:0')
stage1.block4.bn2.weight tensor(0., device='cuda:0')
stage1.block4.bn2.bias tensor(0., device='cuda:0')
stage1.block4.bn2.running_mean tensor(-0.0084, device='cuda:0')
stage1.block4.bn2.running_var tensor(-0.1089, device='cuda:0')
stage1.block4.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block1.conv1.weight tensor(0., device='cuda:0')
stage2.block1.bn1.weight tensor(0., device='cuda:0')
stage2.block1.bn1.bias tensor(0., device='cuda:0')
stage2.block1.bn1.running_mean tensor(-0.5207, device='cuda:0')
stage2.block1.bn1.running_var tensor(13.1180, device='cuda:0')
stage2.block1.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block1.conv2.weight tensor(0., device='cuda:0')
stage2.block1.bn2.weight tensor(0., device='cuda:0')
stage2.block1.bn2.bias tensor(0., device='cuda:0')
stage2.block1.bn2.running_mean tensor(0.0416, device='cuda:0')
stage2.block1.bn2.running_var tensor(0.0256, device='cuda:0')
stage2.block1.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block1.shortcut.conv.weight tensor(0., device='cuda:0')
stage2.block1.shortcut.bn.weight tensor(0., device='cuda:0')
stage2.block1.shortcut.bn.bias tensor(0., device='cuda:0')
stage2.block1.shortcut.bn.running_mean tensor(-0.0403, device='cuda:0')
stage2.block1.shortcut.bn.running_var tensor(9.1353, device='cuda:0')
stage2.block1.shortcut.bn.num_batches_tracked tensor(0, device='cuda:0')
stage2.block2.conv1.weight tensor(0., device='cuda:0')
stage2.block2.bn1.weight tensor(0., device='cuda:0')
stage2.block2.bn1.bias tensor(0., device='cuda:0')
stage2.block2.bn1.running_mean tensor(0.5277, device='cuda:0')
stage2.block2.bn1.running_var tensor(-3.7371, device='cuda:0')
stage2.block2.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block2.conv2.weight tensor(0., device='cuda:0')
stage2.block2.bn2.weight tensor(0., device='cuda:0')
stage2.block2.bn2.bias tensor(0., device='cuda:0')
stage2.block2.bn2.running_mean tensor(-0.0346, device='cuda:0')
stage2.block2.bn2.running_var tensor(-0.5359, device='cuda:0')
stage2.block2.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block3.conv1.weight tensor(0., device='cuda:0')
stage2.block3.bn1.weight tensor(0., device='cuda:0')
stage2.block3.bn1.bias tensor(0., device='cuda:0')
stage2.block3.bn1.running_mean tensor(-0.1441, device='cuda:0')
stage2.block3.bn1.running_var tensor(-9.7162, device='cuda:0')
stage2.block3.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block3.conv2.weight tensor(0., device='cuda:0')
stage2.block3.bn2.weight tensor(0., device='cuda:0')
stage2.block3.bn2.bias tensor(0., device='cuda:0')
stage2.block3.bn2.running_mean tensor(0.0451, device='cuda:0')
stage2.block3.bn2.running_var tensor(-0.1082, device='cuda:0')
stage2.block3.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block4.conv1.weight tensor(0., device='cuda:0')
stage2.block4.bn1.weight tensor(0., device='cuda:0')
stage2.block4.bn1.bias tensor(0., device='cuda:0')
stage2.block4.bn1.running_mean tensor(0.5103, device='cuda:0')
stage2.block4.bn1.running_var tensor(-31.8605, device='cuda:0')
stage2.block4.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block4.conv2.weight tensor(0., device='cuda:0')
stage2.block4.bn2.weight tensor(0., device='cuda:0')
stage2.block4.bn2.bias tensor(0., device='cuda:0')
stage2.block4.bn2.running_mean tensor(-0.0271, device='cuda:0')
stage2.block4.bn2.running_var tensor(-0.5281, device='cuda:0')
stage2.block4.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block1.conv1.weight tensor(0., device='cuda:0')
stage3.block1.bn1.weight tensor(0., device='cuda:0')
stage3.block1.bn1.bias tensor(0., device='cuda:0')
stage3.block1.bn1.running_mean tensor(2.6584, device='cuda:0')
stage3.block1.bn1.running_var tensor(-154.8269, device='cuda:0')
stage3.block1.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block1.conv2.weight tensor(0., device='cuda:0')
stage3.block1.bn2.weight tensor(0., device='cuda:0')
stage3.block1.bn2.bias tensor(0., device='cuda:0')
stage3.block1.bn2.running_mean tensor(-0.0399, device='cuda:0')
stage3.block1.bn2.running_var tensor(-2.9489, device='cuda:0')
stage3.block1.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block2.conv1.weight tensor(0., device='cuda:0')
stage3.block2.bn1.weight tensor(0., device='cuda:0')
stage3.block2.bn1.bias tensor(0., device='cuda:0')
stage3.block2.bn1.running_mean tensor(-0.0263, device='cuda:0')
stage3.block2.bn1.running_var tensor(-6.7252, device='cuda:0')
stage3.block2.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block2.conv2.weight tensor(0., device='cuda:0')
stage3.block2.bn2.weight tensor(0., device='cuda:0')
stage3.block2.bn2.bias tensor(0., device='cuda:0')
stage3.block2.bn2.running_mean tensor(0.2284, device='cuda:0')
stage3.block2.bn2.running_var tensor(0.6274, device='cuda:0')
stage3.block2.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block3.conv1.weight tensor(0., device='cuda:0')
stage3.block3.bn1.weight tensor(0., device='cuda:0')
stage3.block3.bn1.bias tensor(0., device='cuda:0')
stage3.block3.bn1.running_mean tensor(-0.1151, device='cuda:0')
stage3.block3.bn1.running_var tensor(15.2176, device='cuda:0')
stage3.block3.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block3.conv2.weight tensor(0., device='cuda:0')
stage3.block3.bn2.weight tensor(0., device='cuda:0')
stage3.block3.bn2.bias tensor(0., device='cuda:0')
stage3.block3.bn2.running_mean tensor(0.3864, device='cuda:0')
stage3.block3.bn2.running_var tensor(-1.1801, device='cuda:0')
stage3.block3.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block4.conv1.weight tensor(0., device='cuda:0')
stage3.block4.bn1.weight tensor(0., device='cuda:0')
stage3.block4.bn1.bias tensor(0., device='cuda:0')
stage3.block4.bn1.running_mean tensor(-0.2459, device='cuda:0')
stage3.block4.bn1.running_var tensor(29.2794, device='cuda:0')
stage3.block4.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block4.conv2.weight tensor(0., device='cuda:0')
stage3.block4.bn2.weight tensor(0., device='cuda:0')
stage3.block4.bn2.bias tensor(0., device='cuda:0')
stage3.block4.bn2.running_mean tensor(-0.0546, device='cuda:0')
stage3.block4.bn2.running_var tensor(9.7019, device='cuda:0')
stage3.block4.bn2.num_batches_tracked tensor(0, device='cuda:0')
feed_forward.0.weight tensor(0., device='cuda:0')
feed_forward.1.weight tensor(0., device='cuda:0')
feed_forward.1.bias tensor(0., device='cuda:0')
feed_forward.1.running_mean tensor(0.2525, device='cuda:0')
feed_forward.1.running_var tensor(-2.0196, device='cuda:0')
feed_forward.1.num_batches_tracked tensor(0, device='cuda:0')
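One caveat on the comparison above: summing signed differences lets positive and negative deviations cancel. A sum of 0 therefore does not by itself prove two tensors are equal, and a small sum can hide large elementwise gaps. Below is a sketch of a stricter check that reports the largest absolute difference per entry. The helper name `max_abs_diffs` is hypothetical, and the code assumes both state dicts have the same keys.

```python
import torch


def max_abs_diffs(sd_a, sd_b):
    """Largest absolute elementwise difference for each entry of two state dicts."""
    diffs = {}
    for key, a in sd_a.items():
        b = sd_b[key]
        # Cast so integer buffers such as num_batches_tracked can be compared too.
        diffs[key] = (a.float() - b.float()).abs().max().item() if a.numel() else 0.0
    return diffs


# Illustrative tensors: the signed sum of the difference is 0, but the entries differ.
a = {"bn.running_mean": torch.tensor([1.0, -1.0]), "bn.num_batches_tracked": torch.tensor(3)}
b = {"bn.running_mean": torch.zeros(2), "bn.num_batches_tracked": torch.tensor(3)}
d = max_abs_diffs(a, b)
```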

Codes for figure 1,3,4, and 5

Hello,
Could you please share the code to plot the figures of preactivations (Fig. 1, 3) and L2-R CE train loss (Fig. 4, 5)? A link to the referenced repo would be fine as well.

About conv_init(m)

I see def conv_init(m): in models/wide_resnet.py, but it is only defined and never applied. Should it be used as net.apply(conv_init)?

How about sgdr?

CLR is a special case of SGDR (with restart_period=1 and restart_mul=1). But what about performance if I choose restart_period=10 and restart_mul=2? Since you always take the weights at the minimum learning rate of each cycle, only a few snapshots would be averaged with SGDR. For example, in 100 epochs there are only 3 cycles with SGDR, so I can use only 3 snapshots for SWA.

Since CLR is not better than SGDR in my experiments, it would be much better if SGDR could work with SWA.
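The snapshot count in this example can be checked with a small helper. It assumes, as the issue does, that SWA collects one snapshot at the end of each complete cycle. The first cycle lasts restart_period epochs, and each later cycle is restart_mul times longer than the one before. The function name is hypothetical.

```python
def sgdr_cycle_ends(total_epochs, restart_period, restart_mul):
    """Epochs at which complete SGDR cycles end (one SWA snapshot per cycle end)."""
    ends, epoch, period = [], 0, restart_period
    while epoch + period <= total_epochs:
        epoch += period
        ends.append(epoch)
        period *= restart_mul
    return ends


print(sgdr_cycle_ends(100, 10, 2))      # [10, 30, 70]
print(len(sgdr_cycle_ends(100, 1, 1)))  # 100
```

The 80-epoch cycle that would follow epoch 70 does not finish within 100 epochs. This is why the growing cycles leave only three snapshots, while the CLR special case yields one snapshot per epoch.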

SWA with distributed training

In distributed training, e.g. DDP, each GPU processes only its own minibatch, so the BN statistics computed on each GPU are different.
When SWA is adopted, we need to run one more epoch for bn_update. In this epoch, should we use sync BN to average the BN statistics across all GPUs?
And are there any other modifications we need to make for DDP training?

About bn_update

Hi, @timgaripov
I know we should run one additional pass over the training data to get the running mean and running var. But why change momentum?

        momentum = b / (n + b)
        for module in momenta.keys():
            module.momentum = momentum
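The momentum choice can be understood as follows. PyTorch's BatchNorm updates running_mean as (1 - momentum) * running_mean + momentum * batch_mean. Suppose n samples have been seen before the current batch of b samples. Setting momentum = b / (n + b) then makes running_mean the exact sample-weighted mean over all batches so far, rather than an exponential moving average dominated by recent batches. The pure-Python check below uses scalar batch means, an illustrative simplification since the real buffers are per-channel tensors.

```python
def cumulative_running_mean(batches):
    """Apply the BN-style update with momentum = b / (n + b) over a list of batches."""
    running, n = 0.0, 0
    for batch in batches:
        b = len(batch)
        momentum = b / (n + b)  # equals 1 for the first batch, so the initial value is discarded
        batch_mean = sum(batch) / b
        running = (1 - momentum) * running + momentum * batch_mean
        n += b
    return running


batches = [[1.0, 2.0, 3.0], [10.0], [4.0, 6.0]]
```

The result matches the mean of all six samples, 26 / 6, even though the batches have different sizes.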

[ Question ] Isn't the mid-training evaluation of the SWA performance corrupting the batchnorm running averages ?

Hello,

Thank you for your code.

I believe some mechanism is needed to reset the batchnorm running averages to their original values after a mid-training evaluation of the SWA performance (i.e. use of the bn_update method).
Otherwise, evaluating the SWA performance tampers with the training procedure.

Maybe a buffer could be created when calling bn_update, and then used when the user calls a new method along the lines of bn_reset before resuming training?

Best,
Alex

Can we use Adam or other optimizer instead of SGD to train the network?

Hi, I recently used SWA to train my network for a Re-ID task, but I do not see an obvious improvement when training the network with Adam (the results are almost the same with and without SWA).

So, can we use Adam or other optimizers instead of SGD to train the network if we want to improve it with SWA?
