
knowledgedistillation1's Introduction

  • Entry point: bilstm.py
  • Student models: models
  • Data processing: bilstm.py
  • Distillation: distillation/hintonDistiller.py

SST2:

  • Bug: training issue; accuracy is still stuck around 50%, still debugging

QQP:

  • Trains normally; accuracy is around 63%–70%

Hinton Knowledge Distillation

Knowledge Distillation in PyTorch

Simple PyTorch implementation of (Hinton) Knowledge Distillation, together with a BaseDistiller class that is easy to extend to other distillation procedures. Knowledge distillation in the sense of Hinton et al. (2015) seeks to transfer knowledge from a large pretrained model, the teacher, to a smaller untrained model, the student. Done correctly, this yields performance improvements over student models trained from scratch, and more recent adaptations of the scheme have even produced students that outperform their teacher. More recent work has introduced different distillation losses and explored, among other things, which information to transfer from the teacher and how large the student should be.
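
For intuition, the Hinton loss combines a temperature-softened KL divergence between teacher and student logits with the usual cross-entropy against the hard labels, weighted by alpha. The following is a minimal, illustrative sketch of that loss; the function name, temperature T, and exact weighting convention are assumptions for illustration and not necessarily how this repository implements it.

import torch
import torch.nn.functional as F

# Illustrative sketch of the Hinton et al. (2015) loss, not taken from this repository.
def hinton_loss(studentLogits, teacherLogits, labels, alpha=0.5, T=4.0):
    # KL divergence between temperature-softened distributions, scaled by T^2 as in the paper
    distill = F.kl_div(F.log_softmax(studentLogits / T, dim=1),
                       F.softmax(teacherLogits / T, dim=1),
                       reduction='batchmean') * (T * T)
    # Standard cross-entropy against the hard labels
    hard = F.cross_entropy(studentLogits, labels)
    # Weighted combination of the distillation and hard-label terms
    return alpha * distill + (1 - alpha) * hard

# Example with random tensors (batch of 8, 10 classes)
studentLogits = torch.randn(8, 10)
teacherLogits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = hinton_loss(studentLogits, teacherLogits, labels)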

Install requirements

To install the required packages in a new conda environment (KD), use

conda env create -f environment.yml

Example

Using the HintonDistiller is straightforward. Provide the usual elements (optimizer, objectives, models, etc.) and initialize the distiller with a weighting, alpha, between the distillation and objective functions, as well as the layers used for activation matching between student and teacher.

import torch
import torch.nn as nn
from distillation.hintonDistiller import HintonDistiller
from distillation.utils import MLP, PseudoDataset

# Initialize random models and distiller
student = MLP(100, 10, 256)
teacher = MLP(100, 10, 256)
distiller = HintonDistiller(alpha=0.5, studentLayer=-2, teacherLayer=-2)

# Initialize objectives and optimizer
objective = nn.CrossEntropyLoss()
distillObjective = nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

# Pseudo dataset and dataloader 
trainloader = torch.utils.data.DataLoader(
    PseudoDataset(size=(100)),
    batch_size=512,
    shuffle=True)

# Load state if checkpoint is provided
checkpoint = None
startEpoch = distiller.load_state(checkpoint, student, teacher, optimizer)
epochs = 15

# Construct tensorboard logger
distiller.init_tensorboard_logger()

for epoch in range(startEpoch, epochs+1):
        # Training step for one full epoch
        trainMetrics = distiller.train_step(student=student,
                                            teacher=teacher,
                                            dataloader=trainloader,
                                            optimizer=optimizer,
                                            objective=objective,
                                            distillObjective=distillObjective)
        
        # Validation step for one full epoch
        validMetrics = distiller.validate(student=student,
                                          dataloader=trainloader,
                                          objective=objective)
        metrics = {**trainMetrics, **validMetrics}
        
        # Log to tensorboard
        distiller.log(epoch, metrics)

        # Save model
        distiller.save(epoch, student, teacher, optimizer)
        
        # Print epoch performance
        distiller.print_epoch(epoch, epochs, metrics)

To continue a previous run, add the path to the checkpoint and adjust the epochs to the total training length. If only some elements from a previous run should be loaded, set the remaining arguments to None in the .load_state() call.
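
A minimal sketch of how resuming might look, given the calls in the example above (the checkpoint path is a placeholder):

# Resume from a saved checkpoint (placeholder path) and extend training to 30 epochs in total
checkpoint = 'checkpoints/last.pt'  # hypothetical path to a checkpoint written by distiller.save()
startEpoch = distiller.load_state(checkpoint, student, teacher, optimizer)
epochs = 30

# To restore only the student, pass None for the elements that should not be loaded
startEpoch = distiller.load_state(checkpoint, student, None, None)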

Change type of knowledge distillation

To change the type of knowledge distillation, you merely need to change the type of distiller. The following types of knowledge distillation are currently implemented:

  • Hinton Knowledge Distillation (Hinton et al. (2015))
  • Attention Knowledge Distillation (Zagoruyko and Komodakis (2016))
  • Data Free Knowledge Distillation or Zero-Shot Knowledge Distillation (Micaelli and Storkey (2019))

For Attention Knowledge Distillation on the first and third layers, change to the following.

from distillation.attentionDistiller import AttentionDistiller
distiller = AttentionDistiller(alpha=0.5, studentLayer=[1, 3], teacherLayer=[1, 3])

Using the DataFreeDistiller for Data-Free Adversarial Knowledge Distillation (also known as Zero-Shot Knowledge Distillation) is slightly more involved than e.g. the HintonDistiller or AttentionDistiller. See the example below for usage of the DataFreeDistiller.

import torch
import torch.nn as nn
from distillation.datafreeDistiller import DataFreeDistiller
from distillation.utils import PseudoDataset, CNN, Generator

# Initialize random models and distiller
imgSize = (3, 32, 32)
noiseDim = 100
student = CNN(imgSize, 64)
teacher = CNN(imgSize, 64)
distiller = DataFreeDistiller(generatorIters=3,
                              studentIters=2,
                              generator=Generator(noiseDim, imgSize),
                              generatorLR=1e-3,
                              batchSize=64,
                              noiseDim=noiseDim,
                              resampleRatio=1)

# Initialize objectives and optimizer
objective = nn.KLDivLoss(reduction='batchmean')
validObjective = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

# Pseudo dataset and dataloader 
validloader = torch.utils.data.DataLoader(
    PseudoDataset(size=imgSize),
    batch_size=512,
    shuffle=True)

# Load state if checkpoint is provided
checkpoint = None
startEpoch = distiller.load_state(checkpoint, student, teacher, optimizer)
epochs = 15

# Construct tensorboard logger
distiller.init_tensorboard_logger()

for epoch in range(startEpoch, epochs+1):
        # Training step for one full epoch
        trainMetrics = distiller.train_step(student=student,
                                            teacher=teacher,
                                            dataloader=None,
                                            optimizer=optimizer,
                                            objective=objective,
                                            distillObjective=None)
        
        # Validation step for one full epoch
        validMetrics = distiller.validate(student=student,
                                          dataloader=validloader,
                                          objective=validObjective)
        metrics = {**trainMetrics, **validMetrics}
        
        # Log to tensorboard
        distiller.log(epoch, metrics)

        # Save model
        distiller.save(epoch, student, teacher, optimizer)
        
        # Print epoch performance
        distiller.print_epoch(epoch, epochs, metrics)

The DatasetDistiller uses no generator; instead, a fixed set of pseudo samples is updated through SGD to maximise the loss. It can be used with an optional scaler to rescale the samples to a fixed interval.

from distillation.datasetDistiller import DatasetDistiller
from distillation.utils import SigmoidScaler
distiller = DatasetDistiller(pseudoIters=3,
                             studentIters=2,
                             pseudoLR=1e-1,
                             scaler=SigmoidScaler((0,1)),
                             batchSize=64,
                             pseudoSize=imgSize)

Citation

Remember to cite the original papers:

Hinton et al. (2015)
@misc{hinton2015distilling,
    title = {{Distilling the Knowledge in a Neural Network}},
    author = {Hinton, Geoffrey and Vinyals, Oriol and Dean, Jeff},
    year = {2015},
    eprint = {1503.02531},
    archivePrefix = {arXiv},
    primaryClass = {stat.ML}
}
Zagoruyko and Komodakis (2016)
@misc{zagoruyko2016paying,
    title = {{Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer}},
    author = {Zagoruyko, Sergey and Komodakis, Nikos},
    year = {2016},
    eprint = {1612.03928},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Micaelli and Storkey (2019)
@misc{micaelli2019zeroshot,
    title = {{Zero-shot Knowledge Transfer via Adversarial Belief Matching}},
    author = {Micaelli, Paul and Storkey, Amos},
    year = {2019},
    eprint = {1905.09768},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

