
defangchen / simkd


[CVPR-2022] Official implementation for "Knowledge Distillation with the Reused Teacher Classifier".

Python 97.80% Shell 2.20%
deep-learning knowledge-distillation

simkd's Introduction

SimKD

Knowledge Distillation with the Reused Teacher Classifier (CVPR-2022) https://arxiv.org/abs/2203.14001

Toolbox for KD research

This repository aims to provide a compact and easy-to-use implementation of several representative knowledge distillation approaches on standard image classification tasks (e.g., CIFAR100, ImageNet).

  • Generally, these KD approaches combine a classification loss, a logit-level distillation loss, and an additional feature distillation loss. For fair comparison and ease of tuning, we fix the weights of the first two loss terms to 1 in all experiments (--cls 1 --div 1).

  • The following approaches are currently supported by this toolbox, covering vanilla KD, feature-map and feature-embedding distillation, and instance-level and pairwise-level distillation:

  • This toolbox is built on an open-source benchmark and our previous repository, where implementations of more KD approaches can be found.

  • Computing infrastructure:

    • We use one NVIDIA GeForce RTX 2080Ti GPU for CIFAR-100 experiments (PyTorch 1.0) and four NVIDIA A40 GPUs for ImageNet experiments (PyTorch 1.10).
    • For ImageNet, we use DALI for data loading and pre-processing.
  • The current code has been reorganized and has not been tested thoroughly. If you have any questions, please feel free to contact us.

  • Please put the CIFAR-100 and ImageNet datasets in ../data/.
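The three-term objective described above can be sketched as a weighted sum. This is a minimal illustration, not the code in helper/loops.py: the function names, the temperature `T=4.0`, and the use of MSE for the feature term are assumptions made for the example. The weights `cls`, `div`, and `beta` play the role of the --cls, --div, and --beta options.

```python
import torch
import torch.nn.functional as F

def kd_div_loss(logit_s, logit_t, T=4.0):
    """Vanilla KD term: KL divergence between temperature-softened outputs, scaled by T^2."""
    log_p_s = F.log_softmax(logit_s / T, dim=1)
    p_t = F.softmax(logit_t / T, dim=1)
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * T * T

def total_loss(logit_s, logit_t, target, feat_s, feat_t,
               cls=1.0, div=1.0, beta=0.0, T=4.0):
    """Weighted sum of classification, logit distillation and feature distillation terms."""
    loss_cls = F.cross_entropy(logit_s, target)      # classification loss
    loss_div = kd_div_loss(logit_s, logit_t, T)      # logit-level distillation
    loss_kd = F.mse_loss(feat_s, feat_t)             # feature distillation (MSE assumed here)
    return cls * loss_cls + div * loss_div + beta * loss_kd

if __name__ == "__main__":
    torch.manual_seed(0)
    logit_s, logit_t = torch.randn(8, 100), torch.randn(8, 100)
    target = torch.randint(0, 100, (8,))
    feat_s, feat_t = torch.randn(8, 64), torch.randn(8, 64)
    # Toolbox default for the first two terms (--cls 1 --div 1) plus a method-specific beta
    print(total_loss(logit_s, logit_t, target, feat_s, feat_t, cls=1, div=1, beta=1))
```

Setting `cls=0, div=0, beta=1` leaves only the feature term, which corresponds to the `-c 0 -d 0 -b 1` flags in the SimKD commands below, assuming those are the short forms of --cls, --div, and --beta.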

Get the pretrained teacher models

# CIFAR-100
python train_teacher.py --batch_size 64 --epochs 240 --dataset cifar100 --model resnet32x4 --learning_rate 0.05 --lr_decay_epochs 150,180,210 --weight_decay 5e-4 --trial 0 --gpu_id 0

# ImageNet
python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet18 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0 

The pretrained teacher models used in our paper are available at this link [GoogleDrive].

Train the student models with various KD approaches

# CIFAR-100
python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0

# ImageNet
python train_student.py --path-t './save/teachers/models/ResNet50_vanilla/ResNet50_best.pth' --batch_size 256 --epochs 120 --dataset imagenet --model_s ResNet18 --distill simkd -c 0 -d 0 -b 1 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23444 --multiprocessing-distributed --dali gpu --trial 0 

More scripts are provided in ./scripts.

Some results on CIFAR-100

Top-1 test accuracy (%) of each student model; the teacher is ResNet-32x4 in every column.

Method                  ResNet-8x4   VGG-8   ShuffleNetV2x1.5
Student                 73.09        70.46   74.15
KD                      74.42        72.73   76.82
FitNet                  74.32        72.91   77.12
AT                      75.07        71.90   77.51
SP                      74.29        73.12   77.18
VID                     74.55        73.19   77.11
CRD                     75.59        73.54   77.66
SRRL                    75.39        73.23   77.55
SemCKD                  76.23        75.27   79.13
SimKD (f=8)             76.73        74.74   78.96
SimKD (f=4)             77.88        75.62   79.48
SimKD (f=2)             78.08        75.76   79.54
Teacher (ResNet-32x4)   79.42        79.42   79.42

Figure (image not shown):
(Left) The cross-entropy loss between model predictions and test labels.
(Right) The top-1 test accuracy (%) (Student: ResNet-8x4, Teacher: ResNet-32x4).

Citation

If you find this repository useful, please consider citing the following paper:

@inproceedings{chen2022simkd,
  title={Knowledge Distillation with the Reused Teacher Classifier},
  author={Chen, Defang and Mei, Jian-Ping and Zhang, Hailin and Wang, Can and Feng, Yan and Chen, Chun},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={11933--11942},
  year={2022}
}
@inproceedings{chen2021cross,
  author    = {Defang Chen and Jian{-}Ping Mei and Yuan Zhang and Can Wang and Zhe Wang and Yan Feng and Chun Chen},
  title     = {Cross-Layer Distillation with Semantic Calibration},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  pages     = {7028--7036},
  year      = {2021},
}

simkd's People

Contributors

defangchen

simkd's Issues

Issue with Integrating a New Loss Function into Knowledge Distillation Framework

Hi,

We're having difficulty adding a new loss function to the main KD losses implemented in your repository. Although we added it to the existing loss computation in "/helper/loops.py", we see no discernible change in the output.

We suspect our approach is missing something essential for a correct implementation, either in PyTorch or in the repository code. Any insights or suggestions on how to proceed would be very helpful.

criterion = nn.MSELoss()

loss_custom = criterion(torch.tensor(map1, requires_grad=True),
                        torch.tensor(map2, requires_grad=True)).cuda()

and then,

loss = opt.cls * loss_cls + opt.div * loss_div + opt.beta * loss_kd + opt.custom_weight * loss_custom

map1 and map2 are numpy.ndarray objects, while custom_weight is a float.

Thanks in advance!

Edit: I found the error. The function that produces the maps was not designed to be differentiable, and a quick fix introduced a conceptual error. Thanks anyway.
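The problem described in this edit comes down to how autograd works: `torch.tensor(ndarray, requires_grad=True)` creates a brand-new leaf tensor with no history, so a loss computed from it never sends a gradient back to the network, and adding it to the total loss changes nothing in training. The sketch below contrasts the two cases on a toy model; `grad_reaches_net` and the single linear layer standing in for the student are hypothetical and only illustrate the mechanism.

```python
import torch
import torch.nn as nn

def grad_reaches_net(via_numpy: bool) -> bool:
    """Return True if the MSE loss sends a gradient back to the network's weights."""
    torch.manual_seed(0)
    net = nn.Linear(4, 4)                      # stands in for the student network
    x, target = torch.randn(2, 4), torch.randn(2, 4)
    feature_map = net(x)
    if via_numpy:
        # Round trip through NumPy: the new tensor is a fresh leaf with no link to net
        feature_map = torch.tensor(feature_map.detach().numpy(), requires_grad=True)
    nn.MSELoss()(feature_map, target).backward()
    return net.weight.grad is not None

print(grad_reaches_net(via_numpy=True))    # False
print(grad_reaches_net(via_numpy=False))   # True
```

Keeping the map computation in torch operations on the network's outputs is what lets the extra term influence training; NumPy is only safe for constant targets.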

Request for a table of the relevant hyper-parameter settings

Hello, thank you very much for open-sourcing implementations of several representative knowledge distillation methods. To make this repository usable as a benchmark in knowledge distillation, so that researchers can quickly start comparative studies of Vanilla KD, FitNet, AT, SP, VID, CRD, SRRL, SemCKD, KR, SimKD, etc., could you list in a table the usual settings of --cls, --div, --beta, --factor, and --soft for each of these methods?

Questions about the cross entropy loss

Greetings,

I hope this message finds you well.

I was wondering if you could kindly clarify something for me.

Based on the command "python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0", it seems that the student is trained solely with the L2 loss between teacher and student features, without any classification loss on the student's output logits.

I wanted to confirm whether my understanding is correct.

Thank you for your time and assistance.

2.5 hours to train only 1 epoch with four V100s

It takes 5 hours to train only one epoch on ImageNet with 4 V100s. Is this normal? Is there a problem with the DALI module?
My script is listed below:

#!/bin/bash
#SBATCH -N 1
#SBATCH -n 32
#SBATCH -M swarm
#SBATCH -p gpu
#SBATCH --gres=gpu:4
#SBATCH --no-requeue
export MODULEPATH=/dat01/paraai_test/software/modulefiles:$MODULEPATH
module load nvidia/cuda/11.6 anaconda/3.7
source activate KD
export PYTHONUNBUFFERED=1
python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet50 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:10000 --multiprocessing-distributed --trial 0 

And the corresponding output is listed below:


Use GPU: 3 for training
Use GPU: 0 for training
Use GPU: 2 for training
Use GPU: 1 for training
==> training...
==> training...
==> training...
==> training...
Epoch: [1][0/5005]	GPU 3	Time: 15.908	Loss 6.9851	Acc@1 0.000	Acc@5 0.000
Epoch: [1][0/5005]	GPU 2	Time: 16.105	Loss 7.1599	Acc@1 0.000	Acc@5 3.125
Epoch: [1][0/5005]	GPU 1	Time: 15.316	Loss 7.1880	Acc@1 0.000	Acc@5 0.000
Epoch: [1][0/5005]	GPU 0	Time: 14.870	Loss 7.0144	Acc@1 0.000	Acc@5 0.000
Epoch: [1][200/5005]	GPU 0	Time: 291.743	Loss 7.0587	Acc@1 0.101	Acc@5 0.552
Epoch: [1][200/5005]	GPU 3	Time: 296.768	Loss 7.0906	Acc@1 0.148	Acc@5 0.591
Epoch: [1][200/5005]	GPU 2	Time: 296.956	Loss 7.0746	Acc@1 0.109	Acc@5 0.583
Epoch: [1][200/5005]	GPU 1	Time: 296.166	Loss 7.0819	Acc@1 0.078	Acc@5 0.505
Epoch: [1][400/5005]	GPU 0	Time: 557.993	Loss 6.9764	Acc@1 0.160	Acc@5 0.725
Epoch: [1][400/5005]	GPU 3	Time: 563.076	Loss 6.9914	Acc@1 0.152	Acc@5 0.690
Epoch: [1][400/5005]	GPU 1	Time: 562.478	Loss 6.9888	Acc@1 0.140	Acc@5 0.647
Epoch: [1][400/5005]	GPU 2	Time: 563.269	Loss 6.9830	Acc@1 0.125	Acc@5 0.733
Epoch: [1][600/5005]	GPU 0	Time: 814.599	Loss 6.9050	Acc@1 0.211	Acc@5 0.915
Epoch: [1][600/5005]	GPU 3	Time: 819.187	Loss 6.9171	Acc@1 0.226	Acc@5 0.970
Epoch: [1][600/5005]	GPU 1	Time: 818.591	Loss 6.9163	Acc@1 0.211	Acc@5 0.920
Epoch: [1][600/5005]	GPU 2	Time: 819.381	Loss 6.9118	Acc@1 0.195	Acc@5 0.941
Epoch: [1][800/5005]	GPU 0	Time: 1095.555	Loss 6.8389	Acc@1 0.304	Acc@5 1.248
Epoch: [1][800/5005]	GPU 3	Time: 1100.501	Loss 6.8444	Acc@1 0.318	Acc@5 1.280
Epoch: [1][800/5005]	GPU 2	Time: 1100.695	Loss 6.8433	Acc@1 0.300	Acc@5 1.266
Epoch: [1][800/5005]	GPU 1	Time: 1099.905	Loss 6.8481	Acc@1 0.283	Acc@5 1.198
Epoch: [1][1000/5005]	GPU 0	Time: 1423.992	Loss 6.7630	Acc@1 0.395	Acc@5 1.666
Epoch: [1][1000/5005]	GPU 2	Time: 1429.297	Loss 6.7694	Acc@1 0.407	Acc@5 1.648
Epoch: [1][1000/5005]	GPU 3	Time: 1429.103	Loss 6.7699	Acc@1 0.473	Acc@5 1.751
Epoch: [1][1000/5005]	GPU 1	Time: 1428.507	Loss 6.7722	Acc@1 0.379	Acc@5 1.617
Epoch: [1][1200/5005]	GPU 0	Time: 1779.207	Loss 6.6860	Acc@1 0.514	Acc@5 2.156
Epoch: [1][1200/5005]	GPU 3	Time: 1784.156	Loss 6.6886	Acc@1 0.609	Acc@5 2.300
Epoch: [1][1200/5005]	GPU 2	Time: 1784.351	Loss 6.6886	Acc@1 0.595	Acc@5 2.227
Epoch: [1][1200/5005]	GPU 1	Time: 1783.559	Loss 6.6898	Acc@1 0.546	Acc@5 2.180
Epoch: [1][1400/5005]	GPU 0	Time: 2156.362	Loss 6.6043	Acc@1 0.712	Acc@5 2.876
Epoch: [1][1400/5005]	GPU 3	Time: 2161.578	Loss 6.6066	Acc@1 0.797	Acc@5 2.940
Epoch: [1][1400/5005]	GPU 2	Time: 2161.773	Loss 6.6032	Acc@1 0.756	Acc@5 2.875
Epoch: [1][1400/5005]	GPU 1	Time: 2160.982	Loss 6.6077	Acc@1 0.713	Acc@5 2.799
Epoch: [1][1600/5005]	GPU 0	Time: 2535.887	Loss 6.5222	Acc@1 0.948	Acc@5 3.618
Epoch: [1][1600/5005]	GPU 2	Time: 2541.294	Loss 6.5257	Acc@1 0.963	Acc@5 3.537
Epoch: [1][1600/5005]	GPU 3	Time: 2541.099	Loss 6.5265	Acc@1 1.023	Acc@5 3.639
Epoch: [1][1600/5005]	GPU 1	Time: 2540.504	Loss 6.5275	Acc@1 0.933	Acc@5 3.518
Epoch: [1][1800/5005]	GPU 0	Time: 2907.476	Loss 6.4424	Acc@1 1.222	Acc@5 4.403
Epoch: [1][1800/5005]	GPU 3	Time: 2912.685	Loss 6.4441	Acc@1 1.305	Acc@5 4.422
Epoch: [1][1800/5005]	GPU 2	Time: 2912.880	Loss 6.4436	Acc@1 1.231	Acc@5 4.335
Epoch: [1][1800/5005]	GPU 1	Time: 2912.090	Loss 6.4461	Acc@1 1.173	Acc@5 4.349
Epoch: [1][2000/5005]	GPU 0	Time: 3269.215	Loss 6.3630	Acc@1 1.473	Acc@5 5.247
Epoch: [1][2000/5005]	GPU 3	Time: 3274.484	Loss 6.3632	Acc@1 1.573	Acc@5 5.291
Epoch: [1][2000/5005]	GPU 2	Time: 3274.680	Loss 6.3660	Acc@1 1.545	Acc@5 5.195
Epoch: [1][2000/5005]	GPU 1	Time: 3273.889	Loss 6.3661	Acc@1 1.489	Acc@5 5.262
Epoch: [1][2200/5005]	GPU 0	Time: 3623.786	Loss 6.2894	Acc@1 1.772	Acc@5 6.114
Epoch: [1][2200/5005]	GPU 2	Time: 3629.261	Loss 6.2902	Acc@1 1.864	Acc@5 6.095
Epoch: [1][2200/5005]	GPU 1	Time: 3628.470	Loss 6.2897	Acc@1 1.795	Acc@5 6.139
Epoch: [1][2200/5005]	GPU 3	Time: 3629.065	Loss 6.2892	Acc@1 1.846	Acc@5 6.124
Epoch: [1][2400/5005]	GPU 0	Time: 3974.772	Loss 6.2160	Acc@1 2.103	Acc@5 7.004
Epoch: [1][2400/5005]	GPU 1	Time: 3979.475	Loss 6.2178	Acc@1 2.105	Acc@5 6.986
Epoch: [1][2400/5005]	GPU 3	Time: 3980.070	Loss 6.2147	Acc@1 2.181	Acc@5 7.022
Epoch: [1][2400/5005]	GPU 2	Time: 3980.266	Loss 6.2174	Acc@1 2.212	Acc@5 6.961
Epoch: [1][2600/5005]	GPU 0	Time: 4325.969	Loss 6.1438	Acc@1 2.468	Acc@5 7.927
Epoch: [1][2600/5005]	GPU 3	Time: 4331.301	Loss 6.1456	Acc@1 2.512	Acc@5 7.901
Epoch: [1][2600/5005]	GPU 2	Time: 4331.497	Loss 6.1456	Acc@1 2.553	Acc@5 7.878
Epoch: [1][2600/5005]	GPU 1	Time: 4330.707	Loss 6.1489	Acc@1 2.418	Acc@5 7.832
Epoch: [1][2800/5005]	GPU 0	Time: 4677.760	Loss 6.0758	Acc@1 2.825	Acc@5 8.852
Epoch: [1][2800/5005]	GPU 3	Time: 4683.093	Loss 6.0793	Acc@1 2.846	Acc@5 8.775
Epoch: [1][2800/5005]	GPU 1	Time: 4682.499	Loss 6.0815	Acc@1 2.762	Acc@5 8.692
Epoch: [1][2800/5005]	GPU 2	Time: 4683.289	Loss 6.0775	Acc@1 2.907	Acc@5 8.770
Epoch: [1][3000/5005]	GPU 0	Time: 5032.306	Loss 6.0100	Acc@1 3.205	Acc@5 9.768
Epoch: [1][3000/5005]	GPU 3	Time: 5037.679	Loss 6.0131	Acc@1 3.234	Acc@5 9.678
Epoch: [1][3000/5005]	GPU 2	Time: 5037.874	Loss 6.0127	Acc@1 3.258	Acc@5 9.667
Epoch: [1][3000/5005]	GPU 1	Time: 5037.085	Loss 6.0185	Acc@1 3.110	Acc@5 9.559
Epoch: [1][3200/5005]	GPU 0	Time: 5388.797	Loss 5.9487	Acc@1 3.564	Acc@5 10.615
Epoch: [1][3200/5005]	GPU 3	Time: 5394.206	Loss 5.9486	Acc@1 3.599	Acc@5 10.609
Epoch: [1][3200/5005]	GPU 2	Time: 5394.402	Loss 5.9502	Acc@1 3.605	Acc@5 10.523
Epoch: [1][3200/5005]	GPU 1	Time: 5393.612	Loss 5.9578	Acc@1 3.477	Acc@5 10.385
Epoch: [1][3400/5005]	GPU 3	Time: 5749.492	Loss 5.8877	Acc@1 3.966	Acc@5 11.486
Epoch: [1][3400/5005]	GPU 1	Time: 5748.898	Loss 5.8954	Acc@1 3.884	Acc@5 11.295
Epoch: [1][3400/5005]	GPU 0	Time: 5744.161	Loss 5.8889	Acc@1 3.934	Acc@5 11.495
Epoch: [1][3400/5005]	GPU 2	Time: 5749.687	Loss 5.8888	Acc@1 3.969	Acc@5 11.402
Epoch: [1][3600/5005]	GPU 0	Time: 6101.266	Loss 5.8312	Acc@1 4.309	Acc@5 12.364
Epoch: [1][3600/5005]	GPU 2	Time: 6106.812	Loss 5.8314	Acc@1 4.354	Acc@5 12.249
Epoch: [1][3600/5005]	GPU 1	Time: 6106.022	Loss 5.8368	Acc@1 4.254	Acc@5 12.168
Epoch: [1][3600/5005]	GPU 3	Time: 6106.617	Loss 5.8300	Acc@1 4.333	Acc@5 12.333
Epoch: [1][3800/5005]	GPU 0	Time: 6459.961	Loss 5.7734	Acc@1 4.730	Acc@5 13.255
Epoch: [1][3800/5005]	GPU 3	Time: 6465.350	Loss 5.7715	Acc@1 4.733	Acc@5 13.205
Epoch: [1][3800/5005]	GPU 2	Time: 6465.546	Loss 5.7728	Acc@1 4.755	Acc@5 13.147
Epoch: [1][3800/5005]	GPU 1	Time: 6464.756	Loss 5.7784	Acc@1 4.666	Acc@5 13.045
Epoch: [1][4000/5005]	GPU 0	Time: 6821.551	Loss 5.7189	Acc@1 5.116	Acc@5 14.062
Epoch: [1][4000/5005]	GPU 2	Time: 6827.156	Loss 5.7164	Acc@1 5.136	Acc@5 14.011
Epoch: [1][4000/5005]	GPU 3	Time: 6826.961	Loss 5.7156	Acc@1 5.153	Acc@5 14.075
Epoch: [1][4000/5005]	GPU 1	Time: 6826.366	Loss 5.7246	Acc@1 5.035	Acc@5 13.852
Epoch: [1][4200/5005]	GPU 0	Time: 7185.276	Loss 5.6636	Acc@1 5.519	Acc@5 14.919
Epoch: [1][4200/5005]	GPU 2	Time: 7190.915	Loss 5.6622	Acc@1 5.547	Acc@5 14.893
Epoch: [1][4200/5005]	GPU 1	Time: 7190.125	Loss 5.6692	Acc@1 5.461	Acc@5 14.719
Epoch: [1][4200/5005]	GPU 3	Time: 7190.720	Loss 5.6628	Acc@1 5.544	Acc@5 14.890
Epoch: [1][4400/5005]	GPU 0	Time: 7552.365	Loss 5.6117	Acc@1 5.943	Acc@5 15.757
Epoch: [1][4400/5005]	GPU 3	Time: 7557.846	Loss 5.6099	Acc@1 5.964	Acc@5 15.718
Epoch: [1][4400/5005]	GPU 1	Time: 7557.251	Loss 5.6171	Acc@1 5.864	Acc@5 15.536
Epoch: [1][4400/5005]	GPU 2	Time: 7558.041	Loss 5.6100	Acc@1 5.965	Acc@5 15.720
Epoch: [1][4600/5005]	GPU 1	Time: 7927.832	Loss 5.5662	Acc@1 6.268	Acc@5 16.358
Epoch: [1][4600/5005]	GPU 0	Time: 7922.937	Loss 5.5613	Acc@1 6.315	Acc@5 16.552
Epoch: [1][4600/5005]	GPU 2	Time: 7928.622	Loss 5.5599	Acc@1 6.367	Acc@5 16.525
Epoch: [1][4600/5005]	GPU 3	Time: 7928.427	Loss 5.5596	Acc@1 6.370	Acc@5 16.548
Epoch: [1][4800/5005]	GPU 0	Time: 8296.352	Loss 5.5120	Acc@1 6.723	Acc@5 17.341
Epoch: [1][4800/5005]	GPU 2	Time: 8302.048	Loss 5.5108	Acc@1 6.773	Acc@5 17.319
Epoch: [1][4800/5005]	GPU 1	Time: 8301.258	Loss 5.5178	Acc@1 6.645	Acc@5 17.136
Epoch: [1][4800/5005]	GPU 3	Time: 8301.852	Loss 5.5114	Acc@1 6.764	Acc@5 17.326
Epoch: [1][5000/5005]	GPU 0	Time: 8672.685	Loss 5.4638	Acc@1 7.120	Acc@5 18.129
Epoch: [1][5000/5005]	GPU 2	Time: 8678.350	Loss 5.4637	Acc@1 7.168	Acc@5 18.112
Epoch: [1][5000/5005]	GPU 1	Time: 8677.560	Loss 5.4699	Acc@1 7.058	Acc@5 17.914
Epoch: [1][5000/5005]	GPU 3	Time: 8678.154	Loss 5.4639	Acc@1 7.147	Acc@5 18.098
 * Epoch 1, Acc@1 7.131, Acc@5 18.077, Time 18126.27
Test: [0/196]	GPU: 3	Time: 1.507	Loss 4.8784	Acc@1 7.812	Acc@5 25.000
Test: [0/196]	GPU: 0	Time: 1.023	Loss 3.0211	Acc@1 51.562	Acc@5 71.875
Test: [0/196]	GPU: 1	Time: 1.507	Loss 2.7740	Acc@1 40.625	Acc@5 75.000
Test: [0/196]	GPU: 2	Time: 1.509	Loss 3.6662	Acc@1 34.375	Acc@5 57.812
 ** Acc@1 17.078, Acc@5 38.732
==> training...
==> training...
==> training...
saving the best model!
Epoch: [2][0/5005]	GPU 3	Time: 0.300	Loss 4.2087	Acc@1 10.938	Acc@5 29.688
Epoch: [2][0/5005]	GPU 2	Time: 0.299	Loss 4.6335	Acc@1 4.688	Acc@5 34.375
Epoch: [2][0/5005]	GPU 1	Time: 0.299	Loss 4.6575	Acc@1 10.938	Acc@5 26.562
==> training...
Epoch: [2][0/5005]	GPU 0	Time: 0.062	Loss 4.4344	Acc@1 15.625	Acc@5 32.812
Epoch: [2][200/5005]	GPU 0	Time: 487.983	Loss 4.2647	Acc@1 17.436	Acc@5 37.819
Epoch: [2][200/5005]	GPU 3	Time: 494.182	Loss 4.2411	Acc@1 17.211	Acc@5 37.974
Epoch: [2][200/5005]	GPU 2	Time: 494.177	Loss 4.2715	Acc@1 16.620	Acc@5 37.803
.......
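A quick back-of-the-envelope check using only numbers from the log above: ImageNet-1k has 1,281,167 training images, which at a total batch size of 256 gives the 5005 iterations per epoch shown. The sketch converts the GPU 0 timestamp at iteration 5000 into images per second. It assumes `--batch_size 256` is the global batch split across the four GPUs (64 per GPU), which is consistent with the Acc@5 value of 3.125% (2 of 64) in the first logged batch; the helper names are hypothetical.

```python
import math

TRAIN_IMAGES = 1_281_167   # ImageNet-1k training set size
GLOBAL_BATCH = 256         # --batch_size 256, assumed to be the total over 4 GPUs

def iters_per_epoch(n_images: int, batch: int) -> int:
    """Number of iterations needed to cover the dataset once (last batch may be partial)."""
    return math.ceil(n_images / batch)

def images_per_second(iters: int, batch: int, seconds: float) -> float:
    """Average training throughput over the given number of iterations."""
    return iters * batch / seconds

print(iters_per_epoch(TRAIN_IMAGES, GLOBAL_BATCH))              # 5005
# GPU 0 reached iteration 5000 after 8672.685 s in the log
print(round(images_per_second(5000, GLOBAL_BATCH, 8672.685), 1))  # 147.6
```

That is roughly 1.73 s per iteration, or about 2.4 hours for the iteration loop of one epoch. For comparison, the README's ImageNet commands pass `--dali gpu`, which the script above does not.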

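Why is the model distilled with SimKD slower at inference than the teacher model?

First, thanks to the authors for providing the source code; it is a great contribution to the academic community.
I reproduced the code from the paper, using a pretrained ResNet38x4 model as the teacher network and ShuffleNetv2x1.5 as the student model, trained on the CIFAR100 dataset.
During student training, I added the tqdm package to the code to display inference speed, and found that the original teacher model ran inference at 130 images per second, while the distilled student model ran at 50 images per second, even though the model size was indeed reduced by more than 6 times and the accuracy was unchanged.
Is this slowdown caused by reusing the teacher model's classifier in the student's inference stage?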
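Timing with a tqdm progress bar inside the training script can be misleading, since data loading, augmentation, and asynchronous CUDA execution all leak into the number. A more controlled measurement is sketched below: the model is put in eval mode, gradients are disabled, a few warm-up batches are run, and on GPU `torch.cuda.synchronize()` is called before reading the clock. The function name, batch size, and iteration counts are arbitrary choices, and the toy model is only a stand-in for the teacher or the student.

```python
import time
import torch
import torch.nn as nn

@torch.no_grad()
def images_per_second(model: nn.Module, input_shape=(3, 32, 32), batch_size=64,
                      warmup=5, iters=20, device="cpu") -> float:
    """Forward-pass throughput on random inputs (CIFAR-sized by default)."""
    model = model.to(device).eval()           # eval mode: BatchNorm uses running stats
    x = torch.randn(batch_size, *input_shape, device=device)
    for _ in range(warmup):                   # warm-up: lazy initialization, allocations
        model(x)
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    if device.startswith("cuda"):
        torch.cuda.synchronize()              # wait for queued kernels before stopping the clock
    return batch_size * iters / (time.perf_counter() - start)

if __name__ == "__main__":
    toy = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 100))
    print(f"{images_per_second(toy):.0f} images/s")
```

Measuring teacher and student with the same function, input size, and batch size makes the comparison fair; for a SimKD student, the measured model should include the projector and the reused classifier, since both run at inference.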

DDP for teacher

Hi, I ran your code and encountered the error below when running 'semckd' (I haven't tried SimKD yet). I think this is because we don't directly use the parameters in the teacher. Please give some advice if possible. Thank you!

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by
making sure all forward function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 0: 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31 36 37 38 39 44 45 46 47 52 53 54 55 60 61 62 63
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
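The error lists parameters that received no gradient. Before turning on `find_unused_parameters=True`, which the message itself suggests, it can help to find out which parameters those are. Below is a minimal single-process sketch; `params_without_grad` and `Toy` are hypothetical names, and the check runs without any distributed setup.

```python
import torch
import torch.nn as nn

def params_without_grad(model: nn.Module) -> list:
    """Names of trainable parameters whose .grad is still None after backward()."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)   # registered but never called in forward()

    def forward(self, x):
        return self.used(x)

model = Toy()
model(torch.randn(3, 4)).sum().backward()
print(params_without_grad(model))   # ['unused.weight', 'unused.bias']
```

Run it on a fresh model (or after `optimizer.zero_grad(set_to_none=True)`), otherwise gradients left from earlier steps hide the unused parameters. Once identified, such parameters can be removed, frozen with `requires_grad_(False)`, or tolerated by passing `find_unused_parameters=True` to `DistributedDataParallel`.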

Why does SimKD use feat[-2] here?

elif opt.distill == 'simkd':
    trans_feat_s, trans_feat_t, pred_feat_s = module_list[1](feat_s[-2], feat_t[-2], cls_t)  # why is the index -2?
    logit_s = pred_feat_s
    loss_kd = criterion_kd(trans_feat_s, trans_feat_t)
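A plausible reading of the `[-2]` index is that these models return a feature list whose last entry is the pooled, flattened vector, so `[-2]` is the last spatial feature map, which a convolutional projector needs. Under that assumption, here is a hypothetical sketch of a module with the same call signature: it aligns spatial sizes, projects the student map to the teacher's channel width, and classifies the pooled result with the teacher's classifier. `SimKDSketch`, the bottleneck layout, and reading `factor` as the f in "SimKD (f=...)" are all assumptions; this is not the repository's models/util.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimKDSketch(nn.Module):
    """Hypothetical: project the student map to the teacher width, reuse the teacher classifier."""

    def __init__(self, s_channels: int, t_channels: int, factor: int = 2):
        super().__init__()
        mid = t_channels // factor            # assumed meaning of f: bottleneck reduction
        self.projector = nn.Sequential(
            nn.Conv2d(s_channels, mid, 1, bias=False), nn.BatchNorm2d(mid), nn.ReLU(inplace=True),
            nn.Conv2d(mid, mid, 3, padding=1, bias=False), nn.BatchNorm2d(mid), nn.ReLU(inplace=True),
            nn.Conv2d(mid, t_channels, 1, bias=False), nn.BatchNorm2d(t_channels), nn.ReLU(inplace=True),
        )

    def forward(self, feat_s, feat_t, cls_t):
        # Align (square) spatial sizes by pooling both maps down to the smaller one
        h = min(feat_s.shape[2], feat_t.shape[2])
        feat_s = F.adaptive_avg_pool2d(feat_s, h)
        feat_t = F.adaptive_avg_pool2d(feat_t, h)
        trans_feat_s = self.projector(feat_s)
        pooled = F.adaptive_avg_pool2d(trans_feat_s, 1).flatten(1)
        pred_feat_s = cls_t(pooled)           # teacher classifier produces the student's logits
        return trans_feat_s, feat_t, pred_feat_s

if __name__ == "__main__":
    cls_t = nn.Linear(256, 100)               # stand-in for the teacher's final FC layer
    cls_t.requires_grad_(False)               # the reused classifier stays fixed
    simkd = SimKDSketch(s_channels=256, t_channels=256, factor=2)
    trans_s, trans_t, logits = simkd(torch.randn(4, 256, 8, 8), torch.randn(4, 256, 8, 8), cls_t)
    loss_kd = F.mse_loss(trans_s, trans_t)    # feature-alignment loss
    print(logits.shape)                       # torch.Size([4, 100])
```

The channel counts in the usage lines are illustrative only.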

Other student models

Hi,
Could you please tell me which student models this code supports? I tried, for example, ResNet-20 and VGG-8, but got an error:
UnboundLocalError: local variable 'model_s' referenced before assignment
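In Python, this error typically means that an `if`/`elif` chain which was supposed to assign the variable matched none of its branches, so the name was never bound. The sketch below reproduces that pattern with hypothetical names (`build_student_fragile`, `MODEL_BUILDERS`, `build_student`, with strings standing in for models) and shows a dictionary lookup that fails with an explicit message instead. It is not the repository's model-construction code.

```python
def build_student_fragile(name: str):
    if name == "resnet8x4":
        model_s = "resnet8x4 model"
    elif name == "ShuffleV2_1_5":
        model_s = "ShuffleV2_1_5 model"
    return model_s          # raises UnboundLocalError when no branch matched

MODEL_BUILDERS = {
    "resnet8x4": lambda: "resnet8x4 model",
    "ShuffleV2_1_5": lambda: "ShuffleV2_1_5 model",
}

def build_student(name: str):
    """Look the name up explicitly and report the valid choices on failure."""
    try:
        return MODEL_BUILDERS[name]()
    except KeyError:
        raise ValueError(f"unknown student model {name!r}; "
                         f"choose from {sorted(MODEL_BUILDERS)}") from None

if __name__ == "__main__":
    print(build_student("resnet8x4"))   # resnet8x4 model
    try:
        build_student("resnet20")
    except ValueError as err:
        print(err)
```

With the explicit lookup, a misspelled or unsupported model name surfaces as a readable error instead of a confusing `UnboundLocalError` later on.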

Re-use a distilled student as a teacher

Hi, I have a problem. I used the ShuffleV2_1_5 model distilled from resnet32x4_vanilla on CIFAR-100, and now I want to use that result (ShuffleV2_1_5) as a teacher. But when I run the code, the following error appears. I also noticed that the model's accuracy is extremely low. Why?

Use GPU: 0 for training
==> loading teacher model
==> done
Files already downloaded and verified
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
Files already downloaded and verified
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
Test: [0/313]	GPU: 0	Time: 1.615	Loss 4.6052	Acc@1 0.000	Acc@5 3.125
Test: [200/313]	GPU: 0	Time: 3.354	Loss 4.6052	Acc@1 0.964	Acc@5 5.193
teacher accuracy:  1.0
==> training...
Traceback (most recent call last):
  File "train_student.py", line 428, in <module>
    main()
  File "train_student.py", line 169, in main
    main_worker(None if ngpus_per_node > 1 else opt.gpu_id, ngpus_per_node, opt)
  File "train_student.py", line 365, in main_worker
    train_acc, train_acc_top5, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt)
  File "/kaggle/working/SimKD/helper/loops.py", line 156, in train_distill
    trans_feat_s, trans_feat_t, pred_feat_s = module_list[1](feat_s[-2], feat_t[-2], cls_t)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/kaggle/working/SimKD/models/util.py", line 234, in forward
    pred_feat_s = cls_t(temp_feat)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/kaggle/working/SimKD/models/ShuffleNetv2.py", line 90, in forward
    out1 = self.bn1(self.conv1(x))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 457, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 704]
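One arithmetic detail in the log above is worth spelling out: the reported test loss of 4.6052 is exactly the cross-entropy of a uniform prediction over CIFAR-100's 100 classes, ln(100), and the 1% top-1 accuracy is chance level. A model that gives every class the same score produces both numbers. This is an inference from the numbers, not a confirmed diagnosis of the checkpoint.

```python
import math
import torch
import torch.nn.functional as F

num_classes = 100
print(round(math.log(num_classes), 4))                     # 4.6052
# Equal logits give a uniform softmax, hence the same loss for any label
logits = torch.zeros(8, num_classes)
labels = torch.randint(0, num_classes, (8,))
print(round(F.cross_entropy(logits, labels).item(), 4))    # 4.6052
```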

The role of the projector layer

When the student and teacher feature dimensions are equal, as for resnet32x4 and resnet8x4, no projector should be needed, so we tried removing it. However, the result was terrible. Does this method work mainly because of the projector?

Request for t-SNE

Thank you for sharing. Could you please share the code for the teacher and student t-SNE visualizations in your paper?

Accuracy on ImageNet

In the README, the teacher model trained on ImageNet is ResNet18, but when training the student model, the teacher is ResNet50 and the student is ResNet18. Why is that? In addition, could you tell me the accuracy on ImageNet?

Models implementation: number of channels

Hi, thank you for releasing the code.

I noticed that almost all models have (much) fewer channels than the default implementations, e.g., those in torchvision. Is there a reason for this design choice?
Thank you.
