Comments (12)

DefangChen commented on August 15, 2024

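if opt.distill == 'simkd':
    # Only the student backbone is run; the teacher forward pass stays commented out.
    feat_s, _ = model_s(images, is_feat=True)
    #feat_t, _ = model_t(images, is_feat=True)
    #feat_t = [f.detach() for f in feat_t]
    # Reuse the teacher's last module as the classifier (unwrap .module in distributed mode).
    cls_t = model_t.module.get_feat_modules()[-1] if opt.multiprocessing_distributed else model_t.get_feat_modules()[-1]
    # feat_s[-2] also fills the teacher-feature argument, so no teacher feature is computed.
    _, _, output = module_list[1](feat_s[-2], feat_s[-2], cls_t)
else:
    output = model_s(images)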

from simkd.

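DefangChen commented on August 15, 2024

For SimKD, the student network runs one forward pass through the teacher model at inference time (helper/loops.py, lines 278-283). That code is unnecessary and can be commented out. If you measure speed without changing anything, you will reach the wrong conclusion that the student model is slower at inference than the teacher model.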
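
A minimal sketch of how such a speed comparison could be timed, assuming each inference path is wrapped as a plain Python callable. The names `time_forward`, `warmup` and `repeats` are hypothetical and not part of the SimKD repository. On a GPU the device would also need to be synchronized before the clock is read, which this sketch leaves out.

```python
import time


def time_forward(forward_fn, batch, warmup=3, repeats=10):
    """Return the mean wall-clock seconds per call of forward_fn(batch)."""
    # Warm-up calls are excluded so one-time setup costs do not skew the mean.
    for _ in range(warmup):
        forward_fn(batch)
    start = time.perf_counter()
    for _ in range(repeats):
        forward_fn(batch)
    return (time.perf_counter() - start) / repeats
```

Timing the validation path once as shipped and once with the teacher forward pass commented out should show how much of the measured time comes from the teacher.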

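DonMuv commented on August 15, 2024

> For SimKD, the student network runs one forward pass through the teacher model at inference time (helper/loops.py, lines 278-283). That code is unnecessary and can be commented out. If you measure speed without changing anything, you will reach the wrong conclusion that the student model is slower at inference than the teacher model.

Once that code is commented out, it seems the idea of the paper is no longer reflected, namely reusing the teacher model's classifier in the student model's inference stage. Is that right?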

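DefangChen commented on August 15, 2024

Of course not. Inference still uses the teacher classifier (which adds very little to the parameter count and computation), but there is no need to run a forward pass through the entire teacher model.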
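
A minimal sketch of that inference path, assuming the student backbone returns a feature map before pooling and the projector aligns its channels with the teacher's. The wrapper class `StudentWithTeacherClassifier` and the toy layer sizes are hypothetical stand-ins, not the modules used in the SimKD repository.

```python
import torch
import torch.nn as nn


class StudentWithTeacherClassifier(nn.Module):
    """Student backbone -> projector -> reused teacher classifier (hypothetical wrapper)."""

    def __init__(self, backbone, projector, teacher_classifier):
        super().__init__()
        self.backbone = backbone
        self.projector = projector
        self.classifier = teacher_classifier
        self.pool = nn.AdaptiveAvgPool2d(1)

    @torch.no_grad()
    def forward(self, images):
        feat = self.backbone(images)         # student feature map before pooling
        feat = self.projector(feat)          # align channels with the teacher feature
        pooled = self.pool(feat).flatten(1)  # (batch, channels)
        return self.classifier(pooled)       # only the teacher's final layer is evaluated


# Toy stand-ins with made-up sizes, only to show that the shapes line up.
backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
projector = nn.Conv2d(8, 16, kernel_size=1)
teacher_classifier = nn.Linear(16, 100)

model = StudentWithTeacherClassifier(backbone, projector, teacher_classifier).eval()
logits = model(torch.randn(2, 3, 32, 32))
```

The teacher backbone never appears in `forward`, so its cost is not paid at inference time; only the final linear layer is borrowed.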

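DonMuv commented on August 15, 2024

> Of course not. Inference still uses the teacher classifier (which adds very little to the parameter count and computation), but there is no need to run a forward pass through the entire teacher model.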

       if opt.distill == 'simkd':
           feat_s, _ = model_s(images, is_feat=True)
           feat_t, _ = model_t(images, is_feat=True)
           feat_t = [f.detach() for f in feat_t]
           cls_t = model_t.module.get_feat_modules()[-1] if opt.multiprocessing_distributed else model_t.get_feat_modules()[-1]
           _, _, output = module_list[1](feat_s[-2], feat_t[-2], cls_t)
       else:
           output = model_s(images)

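If the if-statement is commented out, only output = model_s(images) can run. I read that line as inference through the student model alone, without using the teacher's classifier. I am not sure where I have misunderstood, so I would appreciate your guidance.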

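qiuxiaqing commented on August 15, 2024

I looked at the code. If lines 278-283 of helper/loops.py are commented out, only output = model_s(images) runs, so the prediction comes from the student alone. You could write an inference script and check whether the saved student model's result on CIFAR-100 matches the validation result seen during training. In theory, if the student's classifier is learned well, it can also reach decent accuracy. But this indeed does not use the teacher's classifier.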

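qiuxiaqing commented on August 15, 2024

I tested the saved student model directly with "output = model_s(images)", and the result was very poor (training accuracy reached 78.45%, while test accuracy was only 0.0127%). My analysis of the cause: when the best model is saved during training, a new network s_new should be built and the weights of three parts (student backbone + spatial and channel alignment part + teacher classifier) should all be loaded into it. Only then is inference correct. So the model saving in the original train_student.py is problematic.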

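DefangChen commented on August 15, 2024

  1. What you describe as "student backbone + spatial and channel alignment part + teacher classifier" is exactly how we save the model; see train_student.py, lines 394-412.
  2. The paper has no loss that trains a "student classifier". If you save only the student backbone plus a randomly initialized classifier, the accuracy on CIFAR-100 obviously degenerates to random guessing (about 1%).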
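
The "about 1%" figure is simply 1/100 for a balanced 100-class problem. A small simulation can illustrate it, under the simplifying assumption that an untrained classifier behaves like a uniform random guess. The function `random_guess_accuracy` is a hypothetical helper written for this illustration.

```python
import random


def random_guess_accuracy(num_classes, num_samples, seed=0):
    """Fraction of uniformly random predictions that match uniformly random labels."""
    rng = random.Random(seed)
    correct = 0
    for _ in range(num_samples):
        label = rng.randrange(num_classes)
        prediction = rng.randrange(num_classes)  # assumption: untrained classifier = uniform guess
        correct += prediction == label
    return correct / num_samples


# CIFAR-100 has 100 classes, so the expected value is 1 / 100.
acc = random_guess_accuracy(num_classes=100, num_samples=100_000)
```

With this many samples the simulated accuracy should land close to 0.01, which matches the chance level quoted above.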

qiuxiaqing commented on August 15, 2024

Right, you did save the weights of the "projector" part.

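qiuxiaqing commented on August 15, 2024

The key issue is this: at save time, the weights of the three parts (student backbone + spatial and channel alignment part + teacher classifier) should be loaded into a new network s', so that inference no longer needs to load the teacher network's weights. Otherwise, combining the three parts at inference time still requires loading the teacher model's weights to obtain the FC layer weights, and that still affects inference speed.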

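DefangChen commented on August 15, 2024

You can test it: the time spent loading parameters is negligible, and this kind of detail is unrelated to what the paper discusses. You are free to save the parameters in whatever way you prefer.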
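
One possible way to save the three parts in a single checkpoint and time the load, sketched with toy stand-in modules. The dictionary keys `backbone`, `projector` and `classifier` and the layer sizes are hypothetical choices, not the format written by train_student.py.

```python
import io
import time

import torch
import torch.nn as nn

# Toy stand-ins for the student backbone, the projector and the teacher classifier.
backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
projector = nn.Conv2d(8, 16, kernel_size=1)
teacher_classifier = nn.Linear(16, 100)

# Bundle the three state dicts so inference never needs the full teacher checkpoint.
checkpoint = {
    "backbone": backbone.state_dict(),
    "projector": projector.state_dict(),
    "classifier": teacher_classifier.state_dict(),
}
buffer = io.BytesIO()  # in-memory file; a path on disk works the same way
torch.save(checkpoint, buffer)

# Time how long it takes to read the bundled weights back.
buffer.seek(0)
start = time.perf_counter()
restored = torch.load(buffer)
load_seconds = time.perf_counter() - start

# Restore the classifier into a freshly constructed layer of the same shape.
new_classifier = nn.Linear(16, 100)
new_classifier.load_state_dict(restored["classifier"])
```

Loading happens once before inference starts, so `load_seconds` is a one-time cost and does not enter the per-batch inference time.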

qiuxiaqing commented on August 15, 2024

OK, thank you for your reply.
