Gradients Vanish (net2net.torch issue, closed)

dmsedra commented on July 20, 2024
Gradients Vanish

from net2net.torch.

Comments (3)

soumith commented on July 20, 2024

Hi, I tried to verify your issue via this example:

require 'nn'
require 'nnx'
local n2n = require 'net2net'

local m = nn.Sequential()
m:add(nn.Probe('before 1st layer'))
m:add(nn.Linear(100,200))
m:add(nn.ReLU())
m:add(nn.Probe('after 1st layer, before 2nd layer'))
m:add(nn.Linear(200,400))
m:add(nn.ReLU())
m:add(nn.Probe('after 2nd layer, before 3nd layer'))
m:add(nn.Linear(400,400))
m:add(nn.ReLU())
m:add(nn.Probe('after 3nd layer'))

local inp = torch.randn(4, 100)

-- output before transform
local out = m:forward(inp):clone()


-- make the 2nd layer of m to 1000 units
n2n.wider(m, 5, 8, 1000)

local outWider = m:forward(inp):clone()

assert(out:add(-1, outWider):abs():max() < 0.001)

print('CHECKING GRADIENTS')
print('CHECKING GRADIENTS')
print('CHECKING GRADIENTS')
print('CHECKING GRADIENTS')

local out = m:forward(inp)
m:backward(inp, out:clone():normal())

The gradients don't seem to vanish:

CHECKING GRADIENTS
CHECKING GRADIENTS
CHECKING GRADIENTS
CHECKING GRADIENTS

<before 1st layer>.output
  + size = 4x100
  + mean = -0.012448332549268
  + std = 1.026612854967
  + min = -2.7935804695658
  + max = 2.4375579174905
  + time since last probe = 0.1ms

<after 1st layer, before 2nd layer>.output
  + size = 4x200
  + mean = 0.22821340252776
  + std = 0.33271552442906
  + min = 0
  + max = 1.7280632397103
  + time since last probe = 0.1ms

<after 2nd layer, before 3nd layer>.output
  + size = 4x1000
  + mean = 0.090732867737928
  + std = 0.13194987718722
  + min = 0
  + max = 0.81821365921255
  + time since last probe = 0.2ms

<after 3nd layer>.output
  + size = 4x400
  + mean = 0.037815612714769
  + std = 0.057141964765994
  + min = 0
  + max = 0.34245181026104
  + time since last probe = 0.3ms

layer<after 3nd layer>.gradInput
  + size = 4x400
  + mean = -0.028988228574532
  + std = 1.0051430135503
  + min = -3.4666892521304
  + max = 3.3759001076536
  + time since last probe = 0.2ms

layer<after 2nd layer, before 3nd layer>.gradInput
  + size = 4x1000
  + mean = 0.0027001652243362
  + std = 0.18928777455994
  + min = -1.0130841813304
  + max = 1.250792645219
  + time since last probe = 0.5ms

layer<after 1st layer, before 2nd layer>.gradInput
  + size = 4x200
  + mean = 0.00341808812378
  + std = 0.23542371323202
  + min = -0.76848756463182
  + max = 0.94700955915216
  + time since last probe = 0.5ms

layer<before 1st layer>.gradInput
  + size = 4x100
  + mean = -0.0047065170117254
  + std = 0.12592198643502
  + min = -0.36394529670349
  + max = 0.33896972226453
  + time since last probe = 0.3ms

Can you give me an example to reproduce in your case?
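The assert in the repro above relies on n2n.wider being function-preserving: after widening, the network must produce the same output as before. The plain-Lua sketch below illustrates that Net2WiderNet idea on a toy Linear -> ReLU -> Linear network stored as Lua tables. It is a simplified illustration, not net2net.torch's implementation. The helpers linear, relu, forward, wider and randLayer are hypothetical names introduced here. Each extra hidden unit copies the incoming weights and bias of a randomly chosen existing unit, and every replicated unit's outgoing weights are divided by its number of copies.

```lua
-- Toy Net2WiderNet sketch on plain Lua tables (hypothetical helpers,
-- not net2net.torch code). A layer is { W = rows (out x in), b = {...} }.

-- y = W x + b
local function linear(layer, x)
  local y = {}
  for i = 1, #layer.W do
    local s = layer.b[i]
    for j = 1, #x do s = s + layer.W[i][j] * x[j] end
    y[i] = s
  end
  return y
end

local function relu(x)
  local y = {}
  for i = 1, #x do y[i] = math.max(0, x[i]) end
  return y
end

-- Linear -> ReLU -> Linear
local function forward(l1, l2, x)
  return linear(l2, relu(linear(l1, x)))
end

-- Widen the hidden layer of (l1, l2) from #l1.W units to newWidth units.
local function wider(l1, l2, newWidth)
  local n = #l1.W
  local map = {}                       -- map[j] = original unit copied by unit j
  for j = 1, n do map[j] = j end
  for j = n + 1, newWidth do map[j] = math.random(1, n) end
  local count = {}                     -- copies of each original unit
  for j = 1, newWidth do count[map[j]] = (count[map[j]] or 0) + 1 end

  local W1, b1 = {}, {}                -- incoming weights: copied verbatim
  for j = 1, newWidth do
    W1[j] = {}
    for k = 1, #l1.W[map[j]] do W1[j][k] = l1.W[map[j]][k] end
    b1[j] = l1.b[map[j]]
  end

  local W2, b2 = {}, {}                -- outgoing weights: split among copies
  for i = 1, #l2.W do
    W2[i] = {}
    for j = 1, newWidth do W2[i][j] = l2.W[i][map[j]] / count[map[j]] end
    b2[i] = l2.b[i]
  end
  return { W = W1, b = b1 }, { W = W2, b = b2 }
end

-- Random layer with nOut units and nIn inputs.
local function randLayer(nOut, nIn)
  local W, b = {}, {}
  for i = 1, nOut do
    W[i] = {}
    for j = 1, nIn do W[i][j] = math.random() - 0.5 end
    b[i] = math.random() - 0.5
  end
  return { W = W, b = b }
end

math.randomseed(1)
local l1, l2 = randLayer(4, 3), randLayer(2, 4)
local x = { 0.3, -1.2, 0.7 }
local before = forward(l1, l2, x)
local w1, w2 = wider(l1, l2, 7)
local after = forward(w1, w2, x)

local maxDiff = 0
for i = 1, #before do
  maxDiff = math.max(maxDiff, math.abs(before[i] - after[i]))
end
print(#w1.W, maxDiff < 1e-9)  -- prints: 7  true
```

Every copy of a unit produces the same activation, so the divided outgoing weights add back up to the original contribution. Any remaining output difference is floating-point rounding.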

dmsedra commented on July 20, 2024

Hi,

I put the probes in and am observing near-zero values for gradInput. Here's my code and model.

Code:

require 'optim'
require 'nn'
require 'cifar-dataset'
require 'cudnn'
------------------------------------------------------------------------
--                              Parser                                --
------------------------------------------------------------------------
local function commandLine()
    local cmd = torch.CmdLine()

    cmd:text()
    cmd:text('Options:')
    cmd:option('-seed',         1, 'fixed input seed for repeatable experiments')
    cmd:option('-validate',     0.2, 'fraction of the training set held out for validation')
    cmd:option('-learningRate', 1e-3, 'learning rate at t=0')
    cmd:option('-decay_lr',     1e-4, 'learning rate decay')
    cmd:option('-batchSize',    50, 'mini-batch size (1 = pure stochastic)')
    cmd:option('-momentum',     0.9, 'momentum (SGD only)')
    cmd:option('-l2reg',        0, 'l2 regularization')
    cmd:option('-maxEpoch',     100, 'maximum # of epochs to train for')
    cmd:option('-shuffle',      true, 'shuffle training data')
    cmd:option('-optimizer',    'sgd', 'choose optimizer sgd|lbfgs|cg')
    cmd:option('-device',       'gpu', 'what device to use')
    cmd:option('-std',          0, 'standard deviation')
    cmd:option('-save',         'logs', 'save location')
    cmd:option('-custom',       '', 'custom model location')
    cmd:option('-dataset',      'mnist', 'mnist|cifar10')
    cmd:option('-trainSize',    10000, 'training set size')
    cmd:option('-expandPoint',  -1, 'epoch at which to widen the network (-1 = never)')
    cmd:text()

    local opt = cmd:parse(arg or {})

    torch.manualSeed(opt.seed)
    return opt
end


------------------------------------------------------------------------
--                             Data Loader                            --
------------------------------------------------------------------------

local function load_data(opt)
    if opt.dataset == 'mnist' then
        local mnist     = require 'mnist'
        local trainData = mnist.traindataset()
        local testData  = mnist.testdataset()
        local data      = {}
        data['xr']      = trainData.data:float()[{{1,opt.trainSize},{},{}}]
        data['xe']      = testData.data:float()
        data['yr']      = (trainData.label + 1)
        data['yr']      = data['yr'][{{1,opt.trainSize}}]
        data['ye']      = testData.label + 1
        opt.outputDim   = 10
        opt.inputDim    = 784

        print(data['xr']:size())
        -- shuffle the training data
        local shuffle_idx = torch.randperm(data.xr:size(1),'torch.LongTensor')
        data.xr           = data.xr:index(1,shuffle_idx)
        data.yr           = data.yr:index(1,shuffle_idx)

        -- normalization
        local xMax = data.xr:max()
        data.xr:div(xMax)
        data.xe:div(xMax)

        -- validation set
        local nValid = math.floor(data.xr:size(1) * opt.validate)
        local nTrain = data.xr:size(1) - nValid
        data['xv']   = data.xr:sub(nTrain+1,data.xr:size(1))
        data['yv']   = data.yr:sub(nTrain+1,data.xr:size(1))
        data['xr']   = data.xr:sub(1,nTrain)
        data['yr']   = data.yr:sub(1,nTrain)
        return data
    elseif opt.dataset == 'cifar10' then
        local path      = 'cifar-10-batches-t7'
        local dataTrain = Dataset.CIFAR(path, "train", 0)
        local dataValid = Dataset.CIFAR(path, "valid", 0)
        local dataTest  = Dataset.CIFAR(path, "test", 0)
        local mean,std = dataTrain:preprocess()
        dataValid:preprocess(mean,std)
        dataTest:preprocess(mean,std)

        local data = {}
        data['xr'] = dataTrain.data
        data['xe'] = dataTest.data
        data['yr'] = dataTrain.labels
        data['ye'] = dataTest.labels
        data['xv'] = dataValid.data
        data['yv'] = dataValid.labels
        local s1, s2, s3 = data['xv']:size(1),data['xr']:size(1),data['xe']:size(1)
        --print('Train: '..s2..' Test: '..s3..' Valid: '..s1)
        return data
    end
end

------------------------------------------------------------------------
--                           Configuration                            --
------------------------------------------------------------------------

local function optimConfig(opt)
    if opt.optimizer == 'sgd' then
        opt.optim_config = {
            learningRate          = opt.learningRate,
            learningRateDecay     = opt.decay_lr,
            weightDecay           = opt.l2reg,
            momentum              = opt.momentum
        }
        opt.optimizer = optim.sgd
    elseif opt.optimizer == 'lbfgs' then
        opt.optim_config = {
            learningRate          = 0.05,
            maxIter               = 10,
            nCorrection           = 10,
            verbose = true,
        }
        opt.optimizer = optim.lbfgs
    elseif opt.optimizer == 'cg' then
        opt.optim_config = {
            maxIter               = 5,
            verbose = true,
        }
        opt.optimizer = optim.cg
    end
end

local function createModel(opt)
    local prev = opt.inputDim

    -- model
    local model
    if opt.custom == '' then
        error('Please select a custom model to load')   
    else
        print 'custom'
        model = dofile('models/'..opt.custom..'.lua')

    end
    model:add(nn.LogSoftMax())

    -- loss function
    local criterion = nn.ClassNLLCriterion()

    -- transfer to cuda
    print(model)
    print(criterion)

    if opt.device == 'gpu' then
        model:cuda()
        criterion:cuda()
    end
    return model, criterion
end

------------------------------------------------------------------------
--                             Training                               --
------------------------------------------------------------------------

local function train(model, criterion, W, grad, data, opt)
    model:training()

    local inputs_gpu, targets_gpu
    if opt.device == 'gpu' then
        inputs_gpu = torch.CudaTensor()
        targets_gpu = torch.CudaTensor()
    end

    local nTrain = data.xr:size(1)

    -- shuffle the data
    if opt.shuffle then
        local shuffle_idx = torch.randperm(nTrain,'torch.LongTensor')
        data.xr           = data.xr:index(1,shuffle_idx)
        data.yr           = data.yr:index(1,shuffle_idx)
    end

    -- Train minibatch
    for t = 1, nTrain, opt.batchSize do
        ------ Minibatch generation
        local idx     = math.min(t+opt.batchSize-1, nTrain)
        local inputs  = data.xr:sub(t,idx)
        local targets = data.yr:sub(t,idx)

        if opt.device == 'gpu' then
            -- copy data from cpu to gpu
            inputs_gpu:resize(inputs:size()):copy(inputs)
            targets_gpu:resize(targets:size()):copy(targets)
        end

        -- objective function for optimization
        local function feval(x)
            assert(x==W)
            grad:zero() -- reset grads
            local f = 0
            if opt.device == 'gpu' then
                local outputs  = model:forward(inputs_gpu)
                      f        = criterion:forward(outputs, targets_gpu)
                local df_dw    = criterion:backward(outputs, targets_gpu)
                model:backward(inputs_gpu, df_dw)
            else
                local outputs  = model:forward(inputs)
                      f        = criterion:forward(outputs, targets)
                local df_dw    = criterion:backward(outputs, targets)
                model:backward(inputs, df_dw)
            end

            f = f/opt.batchSize -- Adjust for batch size
            -- grad = grad/opt.batchSize -- CAN'T DO IT, this would go crazy
            --print(grad:mean())
            return f,grad
        end
        opt.optimizer(feval,W, opt.optim_config)
    end

end


------------------------------------------------------------------------
--                             Evaluation                             --
------------------------------------------------------------------------
local function evaluation(suffix, data, model, batchSize, confusion)

    if suffix ~= 'r' and suffix ~= 'e' and suffix ~= 'v' then
        error('Unrecognized dataset specified')
    end

    model:evaluate()

    local N     = data['x' .. suffix]:size(1)
    local err   = 0

    local inputs_gpu, targets_gpu
    if opt.device == 'gpu' then
        inputs_gpu = torch.CudaTensor()
        targets_gpu = torch.CudaTensor()
    end

    for k = 1, N, batchSize do
        local idx         = math.min(k+batchSize-1,N)
        local inputs      = data['x' .. suffix]:sub(k,idx)
        local targets     = data['y' .. suffix]:sub(k,idx)

        if opt.device == 'gpu' then
            -- copy data from cpu to gpu
            inputs_gpu:resize(inputs:size()):copy(inputs)
            targets_gpu:resize(targets:size()):copy(targets)
            local outputs     = model:forward(inputs_gpu)
            confusion:batchAdd(outputs, targets_gpu)
        else
            local outputs     = model:forward(inputs)
            confusion:batchAdd(outputs, targets)
        end
    end

    confusion:updateValids()
    err    = 1 - confusion.totalValid
    confusion:zero()

    return err
end

local function reportErr(data, model, opt, confusion)
    local bestValid = math.huge
    local bestTest  = math.huge
    local bestTrain = math.huge
    local bestEpoch = math.huge
    local function report(t)
        local err_e = evaluation('e', data, model, opt.batchSize, confusion)
        local err_v = evaluation('v', data, model, opt.batchSize, confusion)
        local err_r = evaluation('r', data, model, opt.batchSize, confusion)
        print('---------------Epoch: ' .. t .. ' of ' .. opt.maxEpoch)
        print(string.format('test: %.4f | valid: %.4f | train: %.4f',
              err_e, err_v, err_r))
        if bestValid > err_v then
            -- Model that achieves the best validation error is considered the
            -- best model
            bestValid = err_v
            bestTrain = err_r
            bestTest = err_e
            bestEpoch = t
        end
        --print(string.format('Optima achieved at epoch %d: test: %.4f, valid: %.4f',
              --bestEpoch, bestTest, bestValid))
        all_errors[t] = {err_e*100, err_v*100, err_r*100}
        logger:add(all_errors[t])
        --if t%100 == 0 then
            --torch.save(paths.concat(paths.cwd(), 'params/', filename), {model,opt})
       -- end
    end
    return report
end



------------------------------------------------------------------------
--                            Main Function                           --
------------------------------------------------------------------------

local function main()
    opt = commandLine()

    filename = paths.concat(paths.cwd(), 'results/', opt.save)
    logger = optim.Logger(filename)
    logger.showPlot = false
    if opt.device == 'gpu' then
        require 'cunn'
    end
    print(opt)
    torch.setdefaulttensortype('torch.FloatTensor')
    local data = load_data(opt)

    local nTrain = data.xr:size(1)
    opt.nBatches     = math.ceil(nTrain/opt.batchSize)
    local model, criterion = createModel(opt)
    local confusion     = optim.ConfusionMatrix(10)
    local W,grad        = model:getParameters()

    print('the number of parameters is ' .. W:nElement())


    local report = reportErr(data, model, opt, confusion)
    all_errors = {}
    for t = 1,opt.maxEpoch do
        ------------ Training Call

        if t == opt.expandPoint then
            print('Expanding at '..tostring(t))
            local n2n = require 'net2net'
            n2n.wider(model,2,4,150)
            print(model)
            opt.optim_config['learningRate'] = 5e-4
        end

        local timer = torch.Timer()
        optimConfig(opt)
        train(model, criterion, W, grad, data, opt) -- performs a single epoch
        ------------ Report Errors
        report(t)

        collectgarbage()
        print(timer:time().real)
    end
    --torch.save(paths.concat(paths.cwd(), 'params/', filename), model)
end

main()

Model:

 require 'nn'
 require 'torch'
 require 'nnx'
 local model = nn.Sequential()
 model:add(nn.Reshape(784))

 model:add(nn.Probe('before first'))
 model:add(nn.Linear(784,100))
 model:add(nn.ReLU())
 model:add(nn.Probe('after first'))

 model:add(nn.Linear(100,10))
 model:add(nn.Probe('after second'))

 return model
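n2n.wider takes the positions of modules inside the nn.Sequential. In the repro earlier in this thread, positions 5 and 8 are the two nn.Linear layers. Adding nn.Probe modules shifts those positions. The plain-Lua sketch below lists this model's module types by hand, in the order they are added (plus the nn.LogSoftMax that createModel appends), and finds the Linear positions. It assumes, without confirmation from the thread, that the indices 2 and 4 used in main() were chosen for a probe-free version of the model. With the probes in place, those positions hold a Probe and a ReLU, and the Linear layers sit at 3 and 6. The table contents are transcribed from the model file, not read from Torch.

```lua
-- Module types in the order they are added to the Sequential: the model file
-- above, then the LogSoftMax appended by createModel().
local withProbes = {
  'Reshape', 'Probe', 'Linear', 'ReLU', 'Probe', 'Linear', 'Probe', 'LogSoftMax',
}
-- Hypothetical probe-free version of the same model.
local withoutProbes = { 'Reshape', 'Linear', 'ReLU', 'Linear', 'LogSoftMax' }

-- Return the 1-based positions of every module of the given type.
local function positionsOf(modules, typeName)
  local pos = {}
  for i, name in ipairs(modules) do
    if name == typeName then pos[#pos + 1] = i end
  end
  return pos
end

local a = positionsOf(withoutProbes, 'Linear')
local b = positionsOf(withProbes, 'Linear')
print(table.concat(a, ','))  -- prints: 2,4 (the indices passed to n2n.wider)
print(table.concat(b, ','))  -- prints: 3,6 (Linear positions once probes are added)
```

In a live Torch session, the same check can be run on the real container, for example by printing torch.type(model:get(i)) for i = 1, model:size().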

dmsedra commented on July 20, 2024

Oh, the only place where I do the expansion is line 341 of my script, the n2n.wider(model,2,4,150) call in main().
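One possible cause worth ruling out (an editorial note, not confirmed in this thread): main() calls model:getParameters() once, before the epoch loop, and keeps passing that same W and grad to train() after n2n.wider has resized the layers. getParameters() returns flattened tensors that share storage with the module parameters as they existed at the time of the call, so the old W cannot cover the widened model. The plain-Lua arithmetic below shows that the parameter count grows from 79510 to 119260 when the 100-unit hidden layer becomes 150 units. If the widened layers receive new weight and gradWeight tensors (an assumption about net2net.torch internals), then grad:zero() in feval clears the old buffer, backward writes into the new gradWeight tensors, and the optimizer sees a gradient that is mostly zero. That would look like vanishing gradients.

```lua
-- Parameter count (weights + biases) of an nn.Linear(nIn, nOut).
local function linearParams(nIn, nOut)
  return nIn * nOut + nOut
end

-- The MNIST model in this thread: Reshape(784) -> Linear(784, hidden) -> ReLU
-- -> Linear(hidden, 10) -> LogSoftMax. Only the two Linear layers have parameters.
local function mlpParams(hidden)
  return linearParams(784, hidden) + linearParams(hidden, 10)
end

local before = mlpParams(100) -- length of W from the single getParameters() call
local after = mlpParams(150)  -- parameters the widened model actually has
print(before, after)          -- prints: 79510  119260
```

If this is the cause, a possible fix (untested here) is to refresh the flattened views right after the n2n.wider call with W, grad = model:getParameters(). The stale momentum buffer should also be dropped: the stock optim.sgd keeps it in its state table (here opt.optim_config.dfdx), and it was sized for the old W.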
