
Comments (6)

davidefiocco commented on August 25, 2024

Hi @NarineK, thanks for the helpful reply!

Indeed, that return torch.softmax(preds, dim = 1)[0][1].unsqueeze(0) solved it!
(you have an extra ] in there though!)

Here are a few more changes I made, starting from the SQuAD tutorial and adapting it to the binary task:

  • I compute just one set of attributions (not sure if I should pass 0 or 1 in the additional_forward_args tuple though...):
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),  # revise this
                                    return_convergence_delta=True)

I then have just one attribution sum (summarize_attributions and ref_input_ids follow the SQuAD tutorial pattern; a rough sketch of both comes after the snippets below):

attributions_sum = summarize_attributions(attributions)

  • I compute a single score with
score = predict(input_ids, token_type_ids=token_type_ids,
                position_ids=position_ids,
                attention_mask=attention_mask)
  • I then display results with
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.max(torch.softmax(score[0], dim=0)),
                        torch.argmax(score[0]),  # revise this, not sure about it
                        torch.argmax(score[0]),  # revise this, not sure about it
                        text,
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])
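
For anyone following along: summarize_attributions and ref_input_ids above are taken from the SQuAD tutorial pattern; roughly, as a sketch (the pad-token baseline is the tutorial's convention, and tokenizer/input_ids are the variables defined earlier in the notebook):

import torch

def summarize_attributions(attributions):
    # collapse the embedding dimension so we get one attribution score per token
    attributions = attributions.sum(dim=-1).squeeze(0)
    # normalize so the scores are comparable across examples
    attributions = attributions / torch.norm(attributions)
    return attributions

# baseline: same length as the input, [CLS] and [SEP] kept, everything else replaced by [PAD]
ref_token_id = tokenizer.pad_token_id
ref_input_ids = torch.tensor(
    [[tokenizer.cls_token_id] +
     [ref_token_id] * (input_ids.shape[1] - 2) +
     [tokenizer.sep_token_id]],
    device=input_ids.device)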

Running all of the above allowed me to display some output:

[screenshot of the visualize_text output]

but I am not fully convinced that all of the above is correct (the interpretation is a bit tricky to digest, as I have fine-tuned BERT on the GLUE CoLA task), so if anybody has some feedback it would be much appreciated!

PS: You can find my current notebook in this gist: https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5

NarineK commented on August 25, 2024

Thank you @davidefiocco! Glad that you find it useful. From the error it looks like the inputs that you are trying to attribute to aren't used in the forward pass.

Actually, in custom_forward: why are you creating a new tensor with torch.tensor([torch.softmax(preds, dim = 1)[0][1]], requires_grad = True)?
I don't think that it is necessary:

torch.softmax(preds, dim = 1)[0][1]].unsqueeze(0)

should do it too.
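
Putting it together, a minimal custom_forward along those lines could look like this (just a sketch, assuming a Hugging Face-style model whose first output element is the logits; the argument names mirror the tutorial):

def custom_forward(input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
    preds = model(input_ids,
                  token_type_ids=token_type_ids,
                  position_ids=position_ids,
                  attention_mask=attention_mask)[0]
    # probability of the positive class; the [0][1] indexing assumes a single example,
    # for a batched call you would return torch.softmax(preds, dim=1)[:, 1] instead
    return torch.softmax(preds, dim=1)[0][1].unsqueeze(0)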

NarineK commented on August 25, 2024

@davidefiocco, the 0 or 1 additional forward arg indices were specifically for the SQuAD model, because that model returns a tuple of 2 tensors: one is the prediction probability for the start index and the other is the prediction probability for the end index.
In the binary classification case my forward function looks something like this:

def custom_forward(*inputs):
    out = model(*inputs)[0]
    return out
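
If it helps, wiring it up would look roughly like this (a sketch; model.bert.embeddings is an assumption about the embedding layer of a BertForSequenceClassification-style model, and since out is the raw logits you pick the class of interest with target):

from captum.attr import LayerIntegratedGradients

lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    target=1,  # attribute with respect to the positive class
                                    return_convergence_delta=True)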

davidefiocco commented on August 25, 2024

Thanks again @NarineK for your kind replies :)

I edited the notebook at https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5 to use a custom_forward more similar to yours:

def custom_forward(inputs):
    out = model(inputs)[0][0]
    return out

so my call to lig.attribute also gets simplified:

attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)
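
A quick sanity check (not strictly needed, just something I find useful) is to print what custom_forward returns before attributing, to confirm it is the score you expect:

out = custom_forward(input_ids)
print(out, out.shape)  # for a binary classifier this should be the two logits of the single example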

Not sure this is ready for PR(ime) time or working correctly, but I am glad to share it here in case it is helpful for somebody else. Let me know if you have additional feedback!

NarineK commented on August 25, 2024

@davidefiocco, looks good to me! In terms of custom_forward, you can play with it and see what it is actually returning. I used it for binary classification, but it can differ from model to model.

davidefiocco commented on August 25, 2024

Thank you, I will close this issue then. Should any follow-up arise, I will open a new issue (I am not fully convinced my solution works well, and I may want to explore other interpretation schemes beyond LayerIntegratedGradients). Thanks!
