
Comments (9)

NarineK commented on June 15, 2024

@ngopee, can it be that the output of your model is one-dimensional, and when you try to access target=1 it can't access that index of the output? Can you try with target=0, or without specifying the target?
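A minimal sketch of what goes wrong (using NumPy indexing as a stand-in; Captum indexes the PyTorch output tensor by target in the same way, so a single-output model has nothing at index 1):

```python
import numpy as np

# A model with a single output unit produces shape (batch, 1);
# asking for target=1 indexes a second column that doesn't exist.
output = np.zeros((4, 1))

try:
    output[:, 1]
except IndexError:
    print("target=1 fails: dimension 1 has size 1")

print(output[:, 0].shape)  # target=0 selects the only column: (4,)
```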


NarineK commented on June 15, 2024

Hi @ngopee, does test_input_tensor contain token indices?
If you want to use LayerConductance, you need to configure InterpretableEmbeddingBase.
LayerConductance doesn't have the logic to replace the indices with embedding vectors.
You'll need to call:

interpretable_embedding = configure_interpretable_embedding_layer(model, '<EMBEDDING-LAYER>')
input_embeddings = interpretable_embedding.indices_to_embeddings(test_input_tensor)
cond = LayerConductance(model, model.convs)
cond_vals = cond.attribute(input_embeddings, target=1)
remove_interpretable_embedding_layer(model, interpretable_embedding)

You'll find some examples here:
https://captum.ai/tutorials/Multimodal_VQA_Captum_Insights
https://captum.ai/tutorials/Bert_SQUAD_Interpret


NarineK commented on June 15, 2024

@ngopee, did it work for you? If so, can we close the issue?


ngopee commented on June 15, 2024

Tried with both target=0 and without specifying a target. Still the same error.

This is the tutorial I am trying to run to get the layer conductance: https://captum.ai/tutorials/IMDB_TorchText_Interpret


NarineK commented on June 15, 2024

It can be that model.convs is a ModuleList; would you try model.convs[0] instead? I'll also give it a try.
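For context, a minimal sketch of why attaching LayerConductance to a ModuleList can fail: nn.ModuleList is just a container with no forward() of its own, so layer-attribution methods can't hook it the way they hook a concrete layer. (The layer shapes below are illustrative, not taken from the tutorial.)

```python
import torch.nn as nn

# Illustrative ModuleList, similar in spirit to the tutorial's model.convs
convs = nn.ModuleList([nn.Conv2d(1, 100, (k, 50)) for k in (3, 4, 5)])

print(type(convs).__name__)     # ModuleList: a container, no forward() of its own
print(type(convs[0]).__name__)  # Conv2d: a concrete layer LayerConductance can hook
```

Indexing the list (model.convs[0]) yields a real module with a forward pass, which is what the layer argument of LayerConductance expects.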


NarineK commented on June 15, 2024

@ngopee, I tried this code snippet and it worked for me:

# accumulate a couple of samples in this array for visualization purposes
vis_data_records_ig = []

lc = LayerConductance(model, model.convs[0])


def interpret_sentence(model, sentence, min_len = 7, label = 0):
    text = [tok.text for tok in nlp.tokenizer(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]

    model.zero_grad()

    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    # input_indices dim: [sequence_length]
    seq_length = min_len

    # predict
    pred = forward_with_sigmoid(input_indices).item()
    pred_ind = round(pred)

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    #attributions_ig, delta = lig.attribute(input_indices, reference_indices, \
    #                                       n_steps=500, return_convergence_delta=True)
    interpretable_embedding = configure_interpretable_embedding_layer(model, 'embedding')
    input_embeddings = interpretable_embedding.indices_to_embeddings(input_indices)
    attributions_ig, delta = lc.attribute(input_embeddings, return_convergence_delta=True)
    remove_interpretable_embedding_layer(model, interpretable_embedding)
    
    print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, text, pred, pred_ind, label, delta, vis_data_records_ig)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    # storing a couple of samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            Label.vocab.itos[pred_ind],
                            Label.vocab.itos[label],
                            Label.vocab.itos[1],
                            attributions.sum(),       
                            text,
                            delta))

Do you have the latest version of Captum? Maybe that is the reason for the inconsistency.


ngopee commented on June 15, 2024

@NarineK Thank you so much for this.

I am actually getting this error when I try the code below:
IndexError: index 1 is out of bounds for dimension 1 with size 1

Spent the whole day trying to figure out whether it was a silly mistake. Your help is very much appreciated!

#Setting layer conductance parameters: Conv layer
cond = LayerConductance(model, model.convs)
interpretable_embedding = configure_interpretable_embedding_layer(model, 'embedding')

def compute_input_tensor(model, sentence, min_len = 7, label = 0):
    text = [tok.text for tok in nlp.tokenizer(sentence)]
    if len(text) < min_len:
        text += ['pad'] * (min_len - len(text))
    indexed = [TEXT.vocab.stoi[t] for t in text]
    print(indexed)

    model.zero_grad()

    input_indices = torch.tensor(indexed, device=device)
    input_indices = input_indices.unsqueeze(0)
    
    
    input_embeddings = interpretable_embedding.indices_to_embeddings(input_indices)


    cond_vals = cond.attribute(input_embeddings,target=1)
    
    return input_indices


compute_input_tensor(model, 'It was a fantastic performance !', label=1)


ngopee commented on June 15, 2024

Doesn't seem to be it.

Thank you for the help!


ngopee commented on June 15, 2024

Yes, this works perfectly! I still have no clue what I missed. I'll look into it and post a follow-up when I figure it out.

Thank you very much!!

