edge attributes about diffpool (OPEN)

rexying commented on August 23, 2024
edge attributes

from diffpool.

Comments (9)

dtchang commented on August 23, 2024

I tried using the PyTorch Geometric implementation, dense_diff_pool(). When A contains edge attributes as an additional dimension, the following computation (using torch.matmul()) throws an exception:
$A^{(l+1)} = {S^{(l)}}^{\top} A^{(l)} S^{(l)}$,
where ${S^{(l)}}^{\top}$ is the transpose of $S^{(l)}$.

RexYing commented on August 23, 2024

I didn't use edge attributes here, but they should be easy to add. If they are categorical, you can use a different weight matrix for each edge type and sum the messages over all edge types after every layer. If they are continuous, you can compute messages in a way similar to GAT.
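
As a rough illustration of the categorical case, here is a minimal sketch of a dense graph conv layer with one weight matrix per edge type. The class name `EdgeTypeDenseConv`, the tensor layout `[batch, num_nodes, num_nodes, num_edge_types]`, and the plain sum aggregation are assumptions made for this example; it is not code from the diffpool repository.

```python
import torch
import torch.nn as nn


class EdgeTypeDenseConv(nn.Module):
    """Hypothetical dense graph conv with one weight matrix per edge type."""

    def __init__(self, in_dim: int, out_dim: int, num_edge_types: int):
        super().__init__()
        # One [in_dim, out_dim] weight matrix per edge type.
        self.weight = nn.Parameter(0.1 * torch.randn(num_edge_types, in_dim, out_dim))

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x:   [batch, num_nodes, in_dim]
        # adj: [batch, num_nodes, num_nodes, num_edge_types]
        adj_t = adj.permute(0, 3, 1, 2)                      # [B, E, N, N]
        # Transform node features separately for every edge type.
        xw = torch.einsum("bni,eio->beno", x, self.weight)   # [B, E, N, out_dim]
        # Propagate along the edges of each type.
        messages = torch.matmul(adj_t, xw)                   # [B, E, N, out_dim]
        # Sum the messages over all edge types.
        return messages.sum(dim=1)                           # [B, N, out_dim]


torch.manual_seed(0)
conv = EdgeTypeDenseConv(in_dim=8, out_dim=16, num_edge_types=3)
x = torch.randn(2, 5, 8)
adj = torch.rand(2, 5, 5, 3)
out = conv(x, adj)
print(out.shape)  # torch.Size([2, 5, 16])
```

The output keeps the usual `[batch, num_nodes, out_dim]` shape, so a layer like this could in principle feed both the embedding and the assignment branches without further changes.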

RexYing commented on August 23, 2024

Could you check the dimensions of S and A? S should be num_nodes x num_next_level_nodes, and A should be num_nodes x num_nodes.

dtchang commented on August 23, 2024

When A contains multiple edge attributes (as in many datasets, including mine), its size is num_nodes x num_nodes x num_edge_attrs. I also reported this as an issue with PyG. The owner said dense_diff_pool() doesn't support multiple edge attributes, only a single edge attribute (weight). Adding such support would be valuable, and it would be great if you could help.

RexYing commented on August 23, 2024

Thanks for the suggestions!

For now, I can think of using PyTorch batched matmul:

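    A_perm = A.permute(2, 0, 1)      # [num_edge_attrs, num_nodes, num_nodes]
    S_perm = S.unsqueeze(0)          # [1, num_nodes, num_clusters]
    S_perm_T = S.T.unsqueeze(0)      # [1, num_clusters, num_nodes]

    # batched matmul broadcasts S over the edge-attribute dimension
    A_next_level = S_perm_T @ A_perm @ S_perm
    A_next_level = A_next_level.permute(1, 2, 0)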

This should give the next-level adjacency of shape [num_clusters x num_clusters x num_edge_attrs].

Does this work for you?
It assumes a single cluster assignment S that is shared across all edge attributes of the multi-edge-attribute graph.
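
A small self-contained check of the batched-matmul idea above, assuming an unbatched A of shape [num_nodes, num_nodes, num_edge_attrs] and a soft assignment S of shape [num_nodes, num_clusters]. The sizes and random inputs are made up for illustration. It compares the batched result with a per-attribute loop and with an equivalent `torch.einsum` expression.

```python
import torch

torch.manual_seed(0)
num_nodes, num_clusters, num_edge_attrs = 6, 3, 4

A = torch.rand(num_nodes, num_nodes, num_edge_attrs)
S = torch.softmax(torch.randn(num_nodes, num_clusters), dim=-1)

# Batched version: move the edge-attribute axis to the front and let
# matmul broadcast S over it, then move the axis back to the end.
A_next = (S.T.unsqueeze(0) @ A.permute(2, 0, 1) @ S.unsqueeze(0)).permute(1, 2, 0)

# Reference: coarsen one edge-attribute channel at a time.
A_ref = torch.stack(
    [S.T @ A[:, :, k] @ S for k in range(num_edge_attrs)], dim=-1
)

# The same contraction written as a single einsum.
A_einsum = torch.einsum("nc,nmk,md->cdk", S, A, S)

print(A_next.shape)                               # torch.Size([3, 3, 4])
print(torch.allclose(A_next, A_ref, atol=1e-5))
print(torch.allclose(A_next, A_einsum, atol=1e-5))
```

All three variants apply the same S to every edge-attribute channel, which is the "single clustering" assumption mentioned above.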

dtchang commented on August 23, 2024

Per your suggestion, I made the following changes in dense_diff_pool():
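    if adj.dim() == 3:
        # adj: [batch, num_nodes, num_nodes]
        out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
    else:  # adj.dim() == 4, adj: [batch, num_nodes, num_nodes, num_edge_attrs]
        adj_perm = adj.permute(0, 3, 1, 2)   # [batch, num_edge_attrs, num_nodes, num_nodes]
        s_perm = s.unsqueeze(1)              # [batch, 1, num_nodes, num_clusters]
        s_t = s.transpose(1, 2)
        s_t_perm = s_t.unsqueeze(1)          # [batch, 1, num_clusters, num_nodes]
        out_adj_perm = torch.matmul(torch.matmul(s_t_perm, adj_perm), s_perm)
        out_adj = out_adj_perm.permute(0, 2, 3, 1)  # [batch, num_clusters, num_clusters, num_edge_attrs]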
That works out fine. Thanks much.

However, the link_loss calculation now throws a RuntimeError when adj.dim() == 4:
    link_loss = adj - torch.matmul(s, s.transpose(1, 2))

What changes should I make?

dtchang commented on August 23, 2024

The following change gets rid of the RuntimeError:
    if adj.dim() == 4:
        adj = adj.unbind(3)[0]  # keep only the first edge attribute

This would produce good link_loss if the first edge attribute (type) is edge weight. Is there a better way?
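
One possible alternative, sketched here under explicit assumptions: instead of keeping only the first attribute, collapse all edge attributes into a binary connectivity matrix (an edge exists wherever any attribute is nonzero) and compare that with S Sᵀ using the Frobenius norm, as in the DiffPool link-prediction objective. The function name `link_loss_multi_attr` and the "any nonzero attribute" rule are hypothetical choices for this example, not part of PyG or the diffpool repository.

```python
import torch


def link_loss_multi_attr(adj: torch.Tensor, s: torch.Tensor,
                         normalize: bool = True) -> torch.Tensor:
    """Hypothetical link loss for adj of shape [batch, N, N, num_edge_attrs]."""
    # Assumption: an edge exists wherever any attribute channel is nonzero.
    connectivity = (adj.abs().sum(dim=-1) > 0).to(s.dtype)      # [batch, N, N]
    diff = connectivity - torch.matmul(s, s.transpose(1, 2))    # [batch, N, N]
    loss = torch.norm(diff, p=2)  # Frobenius norm over all entries
    if normalize:
        loss = loss / connectivity.numel()
    return loss


# Toy graph: two fully connected 2-node clusters (self-loops included),
# each cluster using a different edge attribute channel.
adj = torch.zeros(1, 4, 4, 2)
adj[0, :2, :2, 0] = 1.0
adj[0, 2:, 2:, 1] = 3.0

# A hard assignment that matches the two clusters exactly.
s = torch.tensor([[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]])

loss = link_loss_multi_attr(adj, s)
print(loss)  # zero for this toy case, since S S^T equals the connectivity
```

Whether collapsing the attributes like this is appropriate depends on what the attributes mean. If one channel is a genuine edge weight, using that channel directly may be the better target.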

RexYing commented on August 23, 2024

Does this use only one of the edge-type dimensions? Alternatively, you can give the weight matrices in the graph conv layer an extra dimension corresponding to the edge attribute dimension, so that you don't have to unbind and use only one dimension.

haojiang1 commented on August 23, 2024

Has this question been resolved?
I would like demo code that takes my edge weights (1-D) as input.
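
For the single-weight case, a minimal sketch in plain PyTorch: scatter the 1-D edge weights into a dense weighted adjacency matrix and apply the same coarsening step discussed earlier in the thread. The helper name `dense_weighted_adj`, the toy graph, and the random assignment matrix are assumptions for illustration. In a real model S would come from the assignment GNN.

```python
import torch


def dense_weighted_adj(edge_index: torch.Tensor, edge_weight: torch.Tensor,
                       num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: build a dense [N, N] weighted adjacency matrix."""
    adj = torch.zeros(num_nodes, num_nodes, dtype=edge_weight.dtype)
    # accumulate=True sums the weights of duplicate edges, if any.
    adj.index_put_((edge_index[0], edge_index[1]), edge_weight, accumulate=True)
    return adj


# Toy undirected graph 0 - 1 - 2, stored with both edge directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
edge_weight = torch.tensor([0.5, 0.5, 2.0, 2.0])
adj = dense_weighted_adj(edge_index, edge_weight, num_nodes=3)

torch.manual_seed(0)
x = torch.eye(3)                                # placeholder node features
s = torch.softmax(torch.randn(3, 2), dim=-1)    # placeholder soft assignment

x_pooled = s.T @ x            # [num_clusters, num_features]
adj_pooled = s.T @ adj @ s    # [num_clusters, num_clusters], edge weights carried over
print(adj_pooled.shape)       # torch.Size([2, 2])
```

Because each row of S sums to 1, the total edge weight is preserved by the coarsening step.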
