Comments (7)

wise-east commented on July 22, 2024

I realized that the dimensions of 'lm_head.weight', which the model expects, and of 'lm_head.decoder.weight', which the pretrained model file (pytorch_model.bin) contains, are the same, so there must have been a key rename in the migration from pytorch_pretrained_bert to pytorch_transformers. I therefore renamed the key lm_head.decoder.weight in the state_dict loaded from the pretrained weights (downloaded via the GitHub repo's link) to lm_head.weight by adding

            if key == 'lm_head.decoder.weight': 
                new_key = 'lm_head.weight'

to modeling_utils.py in the pytorch_transformers package:

        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')

            # needed for loading weights for interact.py in transfer-learning-conv-ai 
            if key == 'lm_head.decoder.weight': 
                new_key = 'lm_head.weight'
                
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)

So I'm no longer getting the following message:

INFO:pytorch_transformers.modeling_utils:Weights of OpenAIGPTLMHeadModel not initialized from pretrained model: ['lm_head.weight']

and also lm_head.decoder.weight is removed from the list of weights not loaded.
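
If you'd rather not patch the installed package, an equivalent workaround is to rename the key directly in the downloaded checkpoint and save it back; a minimal sketch, with the file path as an assumption:

    import torch

    # Load the downloaded checkpoint, rename the mismatched key, and save it back.
    # Point the path at wherever pytorch_model.bin was extracted.
    state_dict = torch.load("pytorch_model.bin", map_location="cpu")
    state_dict["lm_head.weight"] = state_dict.pop("lm_head.decoder.weight")
    torch.save(state_dict, "pytorch_model.bin")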

To my dismay, the model, when run with interact.py, still returns the nonsensical responses shown in the original post:

>>> hi
<unk>? <unk><unk><unk><unk>? <unk><unk><unk>
>>> say something that makes sense
<unk><unk><unk>? <unk><unk>

martinritchie commented on July 22, 2024

Hey, I have run into the same problem. I am using your codebase; thank you for making the adaptation.

In train.py:

I believe you are adding the special tokens incorrectly. Your SPECIAL_TOKENS dict currently looks like this:

SPECIAL_TOKENS = {"bos_token": "<bos>", 
                  "eos_token": "<eos>",
                  "speaker1_token": "<speaker1>", 
                  "speaker2_token": "<speaker2>",
                  "pad_token": "<pad>"}

The additional_special_tokens have been added incorrectly; the speaker tokens belong in that list, so the dict should look like this:

SPECIAL_TOKENS = {"bos_token" : "<bos>", 
                  "eos_token" : "<eos>", 
                  "additional_special_tokens" : ["<speaker2>", "<speaker1>"],
                  "pad_token" : "<pad>" } 

I am unsure of a neat way to unpack that dict, so I also define:

special_tokens = ['<bos>', '<eos>', '<speaker1>', '<speaker2>']

which needs to be used as the input to tokenizer.convert_tokens_to_ids(...).
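
For reference, here is a minimal sketch of how these pieces fit together in train.py, assuming the pytorch_transformers API; the "openai-gpt" checkpoint name and variable names are placeholders rather than the repo's exact code:

    from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

    SPECIAL_TOKENS = {"bos_token": "<bos>",
                      "eos_token": "<eos>",
                      "additional_special_tokens": ["<speaker1>", "<speaker2>"],
                      "pad_token": "<pad>"}
    special_tokens = ['<bos>', '<eos>', '<speaker1>', '<speaker2>']

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")  # placeholder checkpoint name
    model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")

    tokenizer.add_special_tokens(SPECIAL_TOKENS)   # register <bos>, <eos>, <pad> and the speaker tokens
    model.resize_token_embeddings(len(tokenizer))  # grow the embedding matrix to cover the new token ids

    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)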

In interact.py, similar modifications need to be made (a short sketch follows the list below):

  1. from train import special_tokens,
  2. tokenizer.add_special_tokens(SPECIAL_TOKENS) after the tokenizer has been initialised,
  3. Add special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens) to the function sample_sequence(...)
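
A minimal sketch of those three changes, assuming interact.py's existing imports and argument parser (the sample_sequence signature is abbreviated here, not copied from the repo):

    from pytorch_transformers import OpenAIGPTTokenizer
    from train import SPECIAL_TOKENS, special_tokens      # 1. reuse the definitions from train.py

    # args.model_checkpoint is assumed to come from the script's existing argument parser
    tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS)           # 2. re-register the tokens after initialisation

    def sample_sequence(personality, history, tokenizer, model, args, current_output=None):
        # 3. ids of the special tokens, used when building and filtering the sequence
        special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
        ...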

Slightly off topic: lines 136-141 seem to be incorrectly commented out; is this for debugging purposes?

I think that is all the changes I have made. Let me know how you get on.

Edit: Are you going to submit a pull request to the original repo?

wise-east commented on July 22, 2024

Thanks for the feedback. I'll incorporate your notes into my forked repo!
I'm just curious as to why '<pad>' is left out of special_tokens = ['<bos>', '<eos>', '<speaker1>', '<speaker2>']. Is it just not necessary?

For lines 136-141, I commented them out for my own purposes: I just wanted to make use of a longer history sequence rather than arbitrary personalities.

As an update, I was able to make the entire thing work after finetuning the pretrained weights (which you can get as illustrated in my README file) with the current code, i.e. prior to the update I am about to make with your notes. I suspect something was wrong with the final LM head weights in the pretrained model and they just needed to be retrained/finetuned on whatever dataset you have.

I'll work through the corrections and look into whether I can make a pull request without causing too much of a hassle. Otherwise, I'll leave an issue open with a link to my forked repo that others can refer to for an updated version.

sshleifer commented on July 22, 2024

I didn't see this and opened the very similar #29. Checkpoint with renamed weights: https://www.dropbox.com/s/bt6n0kyqsrnyx3e/gpt_personachat_cache.tar.gz?dl=0

wise-east commented on July 22, 2024

@sshleifer Glad to see someone else wanting to update this repo. It seems like most of our adjustments are the same, although I think you pointed out a few things that I had misunderstood or overlooked. Thanks.

martinritchie commented on July 22, 2024

I'm just curious as to why '<pad>' is left out of special_tokens = ['<bos>', '<eos>', '<speaker1>', '<speaker2>']. Is it just not necessary?

I can't quite remember; I think <pad> is already included by default now.
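
One quick way to check, assuming the pytorch_transformers tokenizer (the "openai-gpt" name is just a placeholder):

    from pytorch_transformers import OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    print(tokenizer.pad_token)           # None (with a warning) if no pad token is set by default
    print(tokenizer.all_special_tokens)  # lists whatever special tokens are currently registered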

julien-c commented on July 22, 2024

Fixed by #29
