
Fine-tuning about crnn-pytorch (closed)

holmeyoung commented on June 19, 2024
Fine-tuning

from crnn-pytorch.

Comments (6)

Holmeyoung commented on June 19, 2024

Hi, you should change net_init() in train.py to:

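import torch
import torch.nn as nn
from models.crnn import BidirectionalLSTM

# params, net and weights_init are the names train.py already imports/defines.
def net_init():
    # Number of classes of the pre-trained model:
    # len(<alphabet used for pre-training>) + 1 (the extra class is the CTC blank).
    nclass_pre = 11
    nclass = len(params.alphabet) + 1  # number of classes for your new alphabet

    # Build the net with the pre-trained nclass so the checkpoint shapes match.
    crnn = net.CRNN(params.imgH, params.nc, nclass_pre, params.nh)
    crnn.apply(weights_init)
    if params.pretrained != '':
        print('loading pretrained model from %s' % params.pretrained)
        if params.multi_gpu:
            crnn = torch.nn.DataParallel(crnn)
        crnn.load_state_dict(torch.load(params.pretrained))

    # Replace the rnn head so it outputs nclass classes.
    # With DataParallel, the CRNN itself is crnn.module.
    model = crnn.module if isinstance(crnn, nn.DataParallel) else crnn
    model.rnn = nn.Sequential(
            BidirectionalLSTM(512, params.nh, params.nh),
            BidirectionalLSTM(params.nh, params.nh, nclass))
    return crnn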
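
The multi_gpu branch wraps the model in torch.nn.DataParallel before load_state_dict, so it expects the checkpoint keys to carry the "module." prefix that DataParallel adds. If a checkpoint was saved from a wrapped model but you fine-tune without multi_gpu, the keys will not match. A minimal sketch of a helper that strips the prefix, assuming the state dict is an ordinary mapping of key to tensor (strip_module_prefix is a hypothetical name, not part of crnn-pytorch, and the key names below are only illustrative):

```python
from collections import OrderedDict


def strip_module_prefix(state_dict):
    """Return a copy of state_dict with one leading 'module.' removed from each key.

    torch.nn.DataParallel keeps the wrapped network in its `module` attribute,
    so a checkpoint saved from a wrapped model has keys like 'module.cnn.conv0.weight'
    where an unwrapped model expects 'cnn.conv0.weight'.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
```

On an unwrapped model you could then call crnn.load_state_dict(strip_module_prefix(torch.load(params.pretrained))). Keys without the prefix are passed through unchanged, so the helper is harmless on a checkpoint saved from a single-GPU run.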


Holmeyoung commented on June 19, 2024

Hi,

  1. Load the pre-trained model using the same network as the pre-trained model:
crnn = net.CRNN(params.imgH, params.nc, nclass, params.nh)

Here, nclass should not be len(params.alphabet) + 1; it should be the number of classes of the pre-trained model.

  2. Replace the last layers with your own:
crnn.rnn = nn.Sequential(
            BidirectionalLSTM(512, params.nh, params.nh),
            BidirectionalLSTM(params.nh, params.nh, nclass))

where now nclass = len(params.alphabet) + 1, i.e. the number of classes for your new alphabet.
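
The "+ 1" is the extra CTC blank class. As a hedged illustration with hypothetical alphabets, a model pre-trained on digits only would give the nclass_pre = 11 used earlier in this thread, and fine-tuning on digits plus lowercase letters would need 37 outputs:

```python
def num_classes(alphabet):
    # One output per character in the alphabet, plus one for the CTC blank label.
    return len(alphabet) + 1


pretrained_alphabet = '0123456789'                        # hypothetical pre-training alphabet
new_alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'     # hypothetical fine-tuning alphabet

nclass_pre = num_classes(pretrained_alphabet)  # 11: pass this to net.CRNN when loading
nclass = num_classes(new_alphabet)             # 37: use this for the replacement rnn head
print(nclass_pre, nclass)  # prints: 11 37
```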


niddal-imam commented on June 19, 2024

Thanks for the quick response.
So I should first load the pre-trained model by setting this in params.py:
pretrained = 'path/to/my/pre-trained'
But I did not understand the second point. What should I change?

Thanks


Holmeyoung commented on June 19, 2024

After loading the model, replace the rnn layers.
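
To illustrate why the order matters, here is a minimal sketch using a hypothetical TinyCRNN stand-in (one conv layer and a linear head instead of the repo's CRNN and BidirectionalLSTM), assuming only that PyTorch is installed. Loading an old checkpoint into a model built with the new nclass fails on the head's shapes, while building with the old nclass, loading, and then swapping the head works and keeps the pre-trained feature weights:

```python
import torch
import torch.nn as nn


def make_head(nclass, nh=16):
    # Hypothetical stand-in for the two BidirectionalLSTM layers in crnn.rnn.
    return nn.Sequential(nn.Linear(8, nh), nn.Linear(nh, nclass))


class TinyCRNN(nn.Module):
    """Hypothetical toy model with the same split as CRNN: `cnn` + `rnn`."""

    def __init__(self, nclass):
        super().__init__()
        self.cnn = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.rnn = make_head(nclass)

    def forward(self, x):
        feats = self.cnn(x).mean(dim=2)   # (B, 8, H, W) -> (B, 8, W)
        feats = feats.permute(2, 0, 1)    # (W, B, 8): sequence-first, like CRNN
        return self.rnn(feats)            # (W, B, nclass)


torch.manual_seed(0)
nclass_pre, nclass = 11, 37

# Stand-in for the pre-trained checkpoint (normally torch.load(params.pretrained)).
pretrained = TinyCRNN(nclass_pre).state_dict()

# Wrong order: the new-nclass head has different shapes, so loading raises RuntimeError.
try:
    TinyCRNN(nclass).load_state_dict(pretrained)
    mismatch = False
except RuntimeError:
    mismatch = True

# Right order: build with the pre-trained nclass, load, then swap the head.
model = TinyCRNN(nclass_pre)
model.load_state_dict(pretrained)
model.rnn = make_head(nclass)

out = model(torch.zeros(2, 1, 32, 100))  # batch of 2 grayscale 32x100 images
print(mismatch, tuple(out.shape))        # prints: True (100, 2, 37)
```

In the net_init() snippet at the top of this thread, the swap is the crnn.rnn = nn.Sequential(...) step placed after load_state_dict in train.py. The replaced head starts from fresh weights, so only the layers before it carry over from the checkpoint.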


niddal-imam commented on June 19, 2024

Thank you Holmeyoung. Should I change this in crnn.py:
self.cnn = cnn
self.rnn = nn.Sequential(
    BidirectionalLSTM(512, nh, nh),
    BidirectionalLSTM(nh, nh, nclass))
to
crnn.rnn = nn.Sequential(
    BidirectionalLSTM(512, nh, nh),
    BidirectionalLSTM(nh, nh, nclass))
?


niddal-imam commented on June 19, 2024

Thank you very much.
