
Comments (5)

joshim5 commented on August 30, 2024

Hi @brucejwittmann, to quickly answer your questions:

  1. No, not for pretraining
  2. No
    You asked about <pad> tokens in the batch_converter, so I brought up ignore_index as a warning just in case you implement it in a way where masks could be introduced on <pad> tokens. However, the design plan you just described sounds great and ignore_index won't be needed in that case.


joshim5 commented on August 30, 2024

Hi @wjs20, you don't need to specify modelWithLMHead. Following our README, simply do:

import torch

# The hub entry point returns both the model and its alphabet.
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")

Then, when you call result = model(...) or result = model.forward(), you'll get a dictionary where result['logits'] contains the output from the language modeling head. To fine-tune ESM with a language modeling head, you should set up your loss with respect to that output.
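As a minimal sketch, continuing from the snippet above (the example sequence is arbitrary, and using the unmasked input as the target is only a placeholder for a real masked-LM setup):

import torch

# `model` and `alphabet` come from the torch.hub call above.
batch_converter = alphabet.get_batch_converter()

data = [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")]
_, _, tokens = batch_converter(data)

result = model(tokens)
logits = result["logits"]            # unnormalized scores, (batch, seq_len, alphabet_size)

# CrossEntropyLoss expects the class dimension second, so move it there.
# A real fine-tuning setup would mask some input positions and score only those.
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits.transpose(1, 2), tokens)
loss.backward()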


brucejwittmann commented on August 30, 2024

Thanks for making these models available. They are very useful!

I'm also trying to further fine-tune the provided models using masked language modeling for a specific protein family and have just a few questions on the training procedure (commenting here because they pertain to your previous answer):

  1. I noticed that <eos> is an available token in the alphabet provided, but calling batch_converter(data) returns tokens with only <cls> and <pad>. Should the <eos> token be added to the end before feeding to the model, or was this an unused token for language modeling during initial training?
  2. The accompanying publication states "Our model was pre-trained using a context size of 1024 tokens". Should <pad> tokens be appended to sequences that would otherwise be shorter than 1024 tokens? The batch_converter only adds padding tokens up to the length of the longest sequence in the batch.
  3. Based on the output, I believe this is the case, but just to confirm: Are the outputs given by results["logits"] unnormalized scores? In other words, can they be passed directly into an instance of torch.nn.CrossEntropyLoss() without further modification?


joshim5 commented on August 30, 2024

Hi @brucejwittmann, thanks for the great questions!

  1. We did not use the <eos> token during pretraining for the ESM-1 models.
  2. No, you don't need to append <pad> tokens. The batch_converter only adds these tokens so that sequences of different lengths can be included in the same batch. When we pre-train the models, the loss function ignores any <pad> positions, so the loss is the same regardless of how many <pad> tokens are present. For sequences longer than 1023 tokens, we used a random crop of 1023 tokens and then prepended a <cls> token, for a total of 1024 tokens.
  3. Yes! They are unnormalized logits that can be passed directly to torch.nn.CrossEntropyLoss(). Just make sure to ignore <pad> tokens with ignore_index so that they don't contribute to the loss (see the sketch below).
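A minimal sketch of point 3, reusing model, alphabet, and batch_converter from the earlier snippet (the sequences and hard-coded masked positions are purely illustrative; attribute names like alphabet.padding_idx and alphabet.mask_idx follow the esm Alphabet API):

import torch

# Ignore <pad> positions so they can never contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=alphabet.padding_idx)

# Two sequences of different lengths -> the batch_converter pads the shorter one.
_, _, tokens = batch_converter([("seq1", "MKTAYIAKQRQISFVK"), ("seq2", "MKTAYIA")])
targets = tokens.clone()

# Mask a couple of amino-acid positions (fixed here for illustration;
# a real pipeline would sample them at random).
masked = tokens.clone()
masked[0, 5] = alphabet.mask_idx
masked[1, 3] = alphabet.mask_idx

logits = model(masked)["logits"]     # (batch, seq_len, alphabet_size)
loss = loss_fn(logits.transpose(1, 2), targets)

To score only the masked positions, as in a standard masked-LM objective, additionally index logits and targets by the masked positions before calling the loss.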


brucejwittmann commented on August 30, 2024

Hi @joshim5, thanks for your quick and helpful response! Just to clarify the use of ignore_index: my understanding from your paper is that loss was calculated only for predictions at masked positions. Does this mean that <pad> tokens were sometimes the ones that were masked? I was planning to design my masking function so that it never masks a padding token (in other words, it knows the length of each protein and only masks amino-acid positions; see the sketch after the questions below). If I were to do that, my understanding is that ignore_index wouldn't be needed, since <pad> could never be a target. I suppose I have a few follow-up questions, then:

  1. Was loss calculated against more than the masked tokens in your original work?
  2. Were <pad> tokens ever masked in the original work? If so, is it because there is a downside to restricting masking to amino-acid tokens only?
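For concreteness, the masking scheme I have in mind looks roughly like this (the function name and the 15% rate are placeholders, not values from the paper, and it assumes the Alphabet exposes padding_idx, cls_idx, and mask_idx):

import torch

def mask_amino_acids(tokens, alphabet, mask_prob=0.15):
    """Mask only amino-acid positions; <cls> and <pad> are never selected."""
    targets = tokens.clone()
    masked = tokens.clone()

    # Positions that must never be masked.
    special = (tokens == alphabet.padding_idx) | (tokens == alphabet.cls_idx)

    probs = torch.full(tokens.shape, mask_prob)
    probs[special] = 0.0
    mask = torch.bernoulli(probs).bool()

    masked[mask] = alphabet.mask_idx
    # <pad> can never be a target here, so ignore_index would be redundant;
    # computing the loss only at the `mask` positions gives the masked-LM objective.
    return masked, targets, mask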

Thanks again!
