

jessevig commented on August 25, 2024

Unfortunately, I'm still not able to recreate the issue with the latest version of the repo.

I created a Colab notebook where I save a pre-trained model and then load the saved state_dict, and I am able to visualize the weights: https://colab.research.google.com/drive/1sgwcna3HhIbCqqIgYY3MxtoV7PWfpzF8.

Perhaps you could take a look and see if there is a difference with what you're doing? Otherwise, if you're willing to share your .pth file, I could try to debug what is going on.

Also, there could be an issue with the fine-tuned model itself that causes it to give roughly equal attention weights. Do you see that in other layers as well, besides layer 0 as in the image? And what do you see when you select "All" instead of "Sentence A -> Sentence A"?


jessevig commented on August 25, 2024

Hi, thanks for reporting this. Regarding BertForSequenceClassification not returning attentions: where did you import BertForSequenceClassification from? If it was not imported from the special forked version of pytorch-transformers (in the pytorch_transformers_attn directory), that can cause issues, so I wanted to rule it out first.
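One quick way to rule this out is to check where the class was actually defined: every Python class records its defining module in `__module__`. Below is a minimal sketch, assuming the fork's package is importable as `pytorch_transformers_attn` (the directory name mentioned above); the helper `check_import_source` is hypothetical.

```python
def check_import_source(cls, expected_prefix):
    """Return (ok, module): ok is True if cls was defined in expected_prefix or a submodule of it."""
    module = cls.__module__
    # Require an exact match or a "prefix." boundary, so that checking for
    # "pytorch_transformers" does not also accept "pytorch_transformers_attn".
    ok = module == expected_prefix or module.startswith(expected_prefix + ".")
    return ok, module


# Usage (assumed import path for the forked package):
# from pytorch_transformers_attn import BertForSequenceClassification
# print(check_import_source(BertForSequenceClassification, "pytorch_transformers_attn"))
```

The dot boundary matters here because the upstream package name is a prefix of the fork's name; a plain `startswith("pytorch_transformers")` would accept either package.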


chikubee commented on August 25, 2024

@jessevig I've used BertForSequenceClassification from the forked version only.
I also tried initializing BertModel with output_attentions=True, but something is still wrong.
Compared with the pretrained bert-base-uncased model, the attention scores for a token in the fine-tuned model are scaled up, with the fine-tuned model's scores all around 0.06/0.07.
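Scores of about 0.06/0.07 are what a token would get if it spread its attention evenly over about 15 tokens (1/15 ≈ 0.067), which is consistent with the "roughly equal attention weights" possibility. Below is a minimal sketch for checking that on a single row of attention weights, assuming the row has already been pulled out of the model output as a list of floats; the helper `uniformity_report` is hypothetical.

```python
import math


def uniformity_report(weights):
    """Compare one query token's attention distribution with a uniform one.

    weights: attention probabilities from one token to every token in the
    sequence (one row of an attention matrix); they should sum to about 1.
    Returns (uniform_value, max_abs_deviation, normalized_entropy).
    """
    n = len(weights)
    if n == 0:
        raise ValueError("weights must be non-empty")
    uniform = 1.0 / n
    # Largest gap between any weight and the perfectly uniform value 1/n.
    max_dev = max(abs(w - uniform) for w in weights)
    # Shannon entropy divided by log(n): 1.0 means perfectly uniform attention,
    # values near 0 mean attention is concentrated on a few tokens.
    entropy = -sum(w * math.log(w) for w in weights if w > 0)
    normalized_entropy = entropy / math.log(n) if n > 1 else 1.0
    return uniform, max_dev, normalized_entropy


# A perfectly flat row over 15 tokens vs. a row that focuses on one token.
print(uniformity_report([1 / 15] * 15))
print(uniformity_report([0.9] + [0.1 / 14] * 14))
```

Running this over rows from several layers and heads, for both the pretrained and the fine-tuned model, would show whether the flattening is limited to layer 0 or appears throughout.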
The problem still persists.
Thanks in advance for any leads.


jessevig commented on August 25, 2024

Good catch! There was indeed a bug and it has been fixed, so the tool now supports BertForSequenceClassification.

I believe I've loaded a pretrained model in the past in the following fashion:

model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=num_classes)

where model_dir contains config.json and pytorch_model.bin files. But this may not be appropriate in your case.
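Before calling `from_pretrained` on a local directory, it can help to confirm that the two files named above are actually there. Below is a minimal sketch using only the standard library; the helper `missing_model_files` is hypothetical, and the file names are the ones listed in the comment above (other library versions may expect different names).

```python
from pathlib import Path

# File names from the comment above; other versions of the library may use different ones.
REQUIRED_FILES = ("config.json", "pytorch_model.bin")


def missing_model_files(model_dir, required=REQUIRED_FILES):
    """Return the required file names that are absent from model_dir (empty list if none)."""
    root = Path(model_dir)
    return [name for name in required if not (root / name).is_file()]


# Usage (hypothetical directory):
# missing = missing_model_files("models/finetuned_bert")
# if missing:
#     print("model_dir is missing:", missing)
```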

Please let me know if this works! Thanks again for catching this.


chikubee commented on August 25, 2024

@jessevig Sorry, that doesn't work for me.
What I'm trying to do is visualize a BERT model fine-tuned for a sentiment analysis task.

If I load the model like this:
model_state_dict = torch.load('models/bert_model_30_07_part3.pth')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', state_dict=model_state_dict, output_attentions=True, num_classes=3)

The problem is that it doesn't return the attention scores, so visualization isn't possible.

If I load the model like this:

model = BertModel.from_pretrained('bert-base-uncased', state_dict=model_state_dict, output_attentions=True)

I get the result I originally reported.

Can you tell me exactly how I can achieve this?


chikubee commented on August 25, 2024

Thanks a lot, @jessevig, this works.
I don't understand what led to the issue, though.

The change I made after referring to your notebook was to save the model as an instance of BertForSequenceClassification imported from pytorch-transformers, and then load it as an instance of the same class from the forked pytorch_transformers_attn.
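When a checkpoint saved from one class is loaded into another, one thing worth checking is whether the parameter names line up, since names the target model doesn't recognize are not loaded into it. For example, a BERT task-head class stores its encoder under a `bert` attribute, so its saved keys start with `bert.`, while the bare encoder's keys do not. Below is a framework-agnostic sketch, assuming you already have the checkpoint's key names (e.g. from `torch.load(path).keys()`) and the target model's key names (from `model.state_dict().keys()`); the helper `compare_state_dict_keys` and the toy key lists are hypothetical. A mismatch of this kind is one possible explanation to check, not a confirmed root cause for this thread.

```python
def compare_state_dict_keys(saved_keys, model_keys, strip_prefix=None):
    """Report parameter names that would not line up when loading a checkpoint.

    saved_keys: parameter names stored in the checkpoint.
    model_keys: parameter names the target model expects.
    strip_prefix: optional prefix (e.g. "bert.") removed from saved names first,
        to mimic loading a task-head checkpoint into the bare encoder.
    Returns (missing, unexpected), both sorted: names the model expects but the
    checkpoint lacks, and names in the checkpoint the model has no slot for.
    """
    saved = set()
    for key in saved_keys:
        if strip_prefix and key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        saved.add(key)
    expected = set(model_keys)
    missing = sorted(expected - saved)
    unexpected = sorted(saved - expected)
    return missing, unexpected


# Toy key lists (hypothetical, heavily truncated).
saved = ["bert.embeddings.word_embeddings.weight", "bert.pooler.dense.weight",
         "classifier.weight", "classifier.bias"]
encoder = ["embeddings.word_embeddings.weight", "pooler.dense.weight"]
print(compare_state_dict_keys(saved, encoder))                       # nothing matches
print(compare_state_dict_keys(saved, encoder, strip_prefix="bert."))  # only the head is left over
```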

Thanks a lot for your help.


jessevig commented on August 25, 2024

Glad it eventually worked. I'll keep an eye out for similar issues to see if I can get to the root cause. Thanks again.

