graph-com / gsat
[ICML 2022] Graph Stochastic Attention (GSAT) for interpretable and generalizable graph learning.
Home Page: https://arxiv.org/abs/2201.12987
License: MIT License
Hi, thanks for organizing the code so neatly. I am wondering whether the code is applicable to node classification tasks. If so, could you please point out which parts should be changed?
When generating the subgraph, GSAT samples from Bernoulli(sigmoid(att_log_logits)) to get a "soft" value for the Bernoulli samples (it seems to be a continuous value rather than a discrete 0/1).
I don't quite understand how Gumbel-Softmax is applied here (line 308 in commit b3d798f). As I understand it, Gumbel noise is -log(-log(u)), with u drawn from Uniform(0, 1). How is Gumbel-Softmax applied?
Thanks for your amazing work! But where in the code are the subgraphs (Gs) obtained through sampling? If you see this issue, please tell me the answer. Thanks in advance!
No problem now, thank you very much.
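In the two-class case used for edge masks (keep vs. drop), Gumbel-Softmax reduces to a sigmoid. The softmax of [logit + g1, g2] / temp equals sigmoid((logit + g1 - g2) / temp). The difference g1 - g2 of two standard Gumbel samples is standard Logistic noise, which can be drawn directly as log(u) - log(1 - u). This is the binary Concrete (relaxed Bernoulli) construction. The stdlib sketch below checks that identity numerically and then draws relaxed edge-mask values. The function names are illustrative, and the sketch is not claimed to reproduce the repository's code line for line.

```python
import math
import random


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gumbel(u: float) -> float:
    """Standard Gumbel sample obtained from u ~ Uniform(0, 1)."""
    return -math.log(-math.log(u))


def logistic(u: float) -> float:
    """Standard Logistic sample obtained from u ~ Uniform(0, 1)."""
    return math.log(u) - math.log(1.0 - u)


def two_class_gumbel_softmax(logit: float, g_keep: float, g_drop: float, temp: float) -> float:
    """'Keep' probability of a 2-way Gumbel-Softmax with logits [logit, 0]."""
    a = (logit + g_keep) / temp
    b = g_drop / temp
    m = max(a, b)  # subtract the max before exponentiating, for stability
    ea, eb = math.exp(a - m), math.exp(b - m)
    return ea / (ea + eb)


def binary_concrete(logit: float, temp: float, rng: random.Random) -> float:
    """Relaxed Bernoulli sample: sigmoid((logit + Logistic noise) / temp)."""
    u = rng.uniform(1e-10, 1.0 - 1e-10)  # keep u away from 0 and 1
    return sigmoid((logit + logistic(u)) / temp)


rng = random.Random(0)
logit, temp = 0.7, 0.5

# Feed the same Gumbel draws to both forms: the two values agree.
u1, u2 = rng.uniform(1e-10, 1 - 1e-10), rng.uniform(1e-10, 1 - 1e-10)
g1, g2 = gumbel(u1), gumbel(u2)
lhs = two_class_gumbel_softmax(logit, g1, g2, temp)
rhs = sigmoid((logit + g1 - g2) / temp)
print(abs(lhs - rhs) < 1e-12)

# Thresholding relaxed samples at 0.5 yields Bernoulli(sigmoid(logit)) draws.
samples = [binary_concrete(logit, temp, rng) for _ in range(20000)]
frac_kept = sum(s > 0.5 for s in samples) / len(samples)
print(round(frac_kept, 2), round(sigmoid(logit), 2))
```

The threshold check explains why the samples look continuous. The relaxed value itself is soft. Its position relative to 0.5, however, is an exact Bernoulli(sigmoid(logit)) draw for any positive temperature, because dividing by temp does not change the sign of logit + noise.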
Hello, me again!
I read in your paper that the info loss should be based on the distribution of the subgraphs given the original graph and the parameters.
However, your code does the following, in order: 1) compute this distribution as logits, 2) sample with the Gumbel-Softmax trick, and 3) apply the info loss to the sampled subgraph. As I understand it, you should instead 1) compute the distribution as logits, 2) transform the logits into probabilities using the same temperature as in the Gumbel-Softmax code, 3) apply the info loss to that distribution, and 4) apply the Gumbel-Softmax trick to the logits for use in the other parts of the code.
Mathematically, I think your current order adds a lot of noise to the back-propagated gradients of the info loss. I would expect the loss to be more efficient and cleaner if you followed the order I propose. That is, apply the info loss to (att_log_logits / temp).sigmoid() (with temp set to 1 in your code) rather than to self.sampling(att_log_logits, epoch, training).
What do you think? Have I missed something?
I would love to read your opinion on the matter.
P.S. Thanks again for your paper and your responsiveness to my previous issues!
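The proposed reordering can be illustrated without torch. The sketch below evaluates the per-edge Bernoulli KL term used as the info loss, in the same form that another commenter quotes from the GSAT code. It does so twice: once on the deterministic probabilities sigmoid(logit / temp), and once on relaxed samples drawn as sigmoid((logit + Logistic noise) / temp). The helper names, the logits and r are made up for illustration. The sketch shows only that the first version carries no sampling noise in its value. It does not show which version trains better.

```python
import math
import random
import statistics
from typing import Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def info_loss(att: Sequence[float], r: float, eps: float = 1e-6) -> float:
    """Mean KL(Bernoulli(a) || Bernoulli(r)) over edges, with the eps terms of the quoted code."""
    terms = [
        a * math.log(a / r + eps) + (1 - a) * math.log((1 - a) / (1 - r + eps) + eps)
        for a in att
    ]
    return sum(terms) / len(terms)


def relaxed_sample(logit: float, temp: float, rng: random.Random) -> float:
    """Binary Concrete sample: sigmoid((logit + Logistic noise) / temp)."""
    u = rng.uniform(1e-10, 1.0 - 1e-10)
    return sigmoid((logit + math.log(u) - math.log(1.0 - u)) / temp)


logits = [2.0, -1.0, 0.3, -2.5, 1.2]  # hypothetical att_log_logits for five edges
r, temp = 0.7, 1.0
rng = random.Random(0)

# Proposed order: info loss on the probabilities (deterministic given the logits).
loss_on_probs = info_loss([sigmoid(l / temp) for l in logits], r)

# Current order: info loss on a fresh relaxed sample each time (a random quantity).
losses_on_samples = [
    info_loss([relaxed_sample(l, temp, rng) for l in logits], r) for _ in range(1000)
]

print(round(loss_on_probs, 4))
print(round(statistics.mean(losses_on_samples), 4), round(statistics.stdev(losses_on_samples), 4))
```

The nonzero standard deviation of the second quantity is the extra noise the comment refers to. In a torch version, the same contrast would also carry over to the gradients.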
Just wanted to say that the paper is amazing. Thanks for publishing it and the code.
I am trying to generate the mnist75 dataset by running ./scripts/prepare_data.sh, and I am getting the following stack trace:
Fr Dez 16 17:57:00 CET 2022
start time: 2022-12-16 17:57:01.936994
dataset mnist
data_dir ./data
out_dir ./data
split train
threads 0
n_sp 75
compactness 0.25
seed 111
/home/mada/anaconda3/lib/python3.9/site-packages/skimage/_shared/utils.py:338: FutureWarning: multichannel is a deprecated argument name for slic. It will be removed in version 1.0. Please use channel_axis instead.
warnings.warn(self.warning_msg.format(
Traceback (most recent call last):
File "graph_attention_pool/extract_superpixels.py", line 128, in <module>
sp_data.append(process_image((images[i], i, n_images, args, True, True)))
File "graph_attention_pool/extract_superpixels.py", line 55, in process_image
assert n_sp_extracted == np.max(superpixels) + 1, ('superpixel indices', np.unique(superpixels)) # make sure superpixel indices are numbers from 0 to n-1
AssertionError: ('superpixel indices', array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
69, 70, 71]))
Do you know what the issue might be?
Thank you in advance!
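The printed indices run from 1 to 71. That gives 71 distinct labels, while np.max(superpixels) + 1 is 72, so the labels start at 1 rather than 0. A plausible cause, not verified against this script, is that recent scikit-image releases changed the default of the slic start_label argument to 1; the multichannel FutureWarning suggests a recent version. Either of two fixes should satisfy the check: pass start_label=0 to slic, if your version accepts it, or relabel to a contiguous 0-based range before the assertion. The second option is sketched below with a hypothetical helper.

```python
import numpy as np


def relabel_consecutive(labels: np.ndarray) -> np.ndarray:
    """Map integer labels to 0..k-1, preserving their sorted order."""
    _, inverse = np.unique(labels, return_inverse=True)
    # reshape covers both NumPy 1.x (flat inverse) and 2.x (input-shaped inverse)
    return inverse.reshape(labels.shape)


# Toy segmentation whose labels start at 1, as in the traceback above.
superpixels = np.array([[1, 1, 2], [3, 3, 2]])
relabeled = relabel_consecutive(superpixels)

n_sp_extracted = len(np.unique(relabeled))
assert n_sp_extracted == np.max(relabeled) + 1  # the check from extract_superpixels.py now holds
print(relabeled)
```

Relabelling also covers the case where slic returns non-contiguous labels. The start_label route only shifts the offset.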
Hi all, I tried to implement the info loss in my own GNN. I am using a custom convolution on a custom dataset that might have leakage, so this might be the source of the error. Still, I am trying to understand why the model behaves the way it does. I would appreciate any ideas/feedback.
My model is for link prediction on small subgraphs: for each edge I want to predict, I sample a subgraph around it.
I am implementing the info_loss just like in your code:
info_loss = (edge_att * torch.log(edge_att/r + 1e-6) + (1-edge_att) * torch.log((1-edge_att)/(1-r+1e-6) + 1e-6)).mean()
If I don't use any sort of info loss, when I train my model, my edge attention looks like this:
If I use l1 loss (just minimizing edge_att.mean()), my edge attention looks like this:
If I use l1 loss, but multiply by 1e-3, it looks like this:
However, if I use the info loss proposed in your paper, my edge attention clusters around r. For example, for r = 0.3, I get the attention distribution below. If I use r = 0.5, the dense part of the histogram moves to the middle, and if I use something like r = 0.7 or r = 0.9, all my attention weights are closer to 1.
I tried to understand the intuition behind it by plotting info_loss against att for different values of r.
So basically the info_loss is approximately zero when att is close to r, and positive everywhere else. This forces my model to keep the attention close to r (and I am not sure I understand why), and apparently this is exactly what my model is doing. What confuses me is that your paper recommends r between 0.5 and 0.9. However, in my current setting, this forces the majority of my edge attention values to be > 0.5 instead of making them sparse.
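The observed clustering is what this regulariser does in isolation. Each term is the KL divergence between Bernoulli(att) and Bernoulli(r). Up to the 1e-6 terms, it is zero at att = r and positive elsewhere, so on its own it pulls every attention value toward r rather than toward 0. Presumably the prediction loss is what is meant to move label-relevant edges away from r. The stdlib sketch below locates the minimiser of the per-edge term on a grid, using the same expression as the info_loss line above; the helper names are hypothetical.

```python
import math


def info_loss_term(att: float, r: float, eps: float = 1e-6) -> float:
    """Per-edge term of the info_loss line above: KL(Bernoulli(att) || Bernoulli(r))."""
    return att * math.log(att / r + eps) + (1 - att) * math.log((1 - att) / (1 - r + eps) + eps)


def minimiser_on_grid(r: float, steps: int = 999) -> float:
    """Attention value in (0, 1) with the smallest info-loss term on a uniform grid."""
    grid = [(i + 1) / (steps + 1) for i in range(steps)]
    return min(grid, key=lambda a: info_loss_term(a, r))


for r in (0.3, 0.5, 0.7, 0.9):
    # Columns: r, the grid minimiser (equal to r up to grid spacing), the penalty at att = 0.05.
    print(r, minimiser_on_grid(r), round(info_loss_term(0.05, r), 3))
```

The last column shows that, with r = 0.7 or 0.9, pushing an edge's attention toward 0 is penalised heavily. This matches the histograms described above.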
I wonder whether I am doing something wrong, whether info_loss should have a smaller weight, or whether my concrete_sampler should have a higher temperature to force a Bernoulli-like distribution. Or maybe my model simply doesn't need the edges: it is fine with any edge attention value and hacks its way to the same solution based only on node embeddings, for example, without message passing. Maybe I use too much dropout during training? (I apply both node and edge dropout.)
Please let me know if you have any ideas. Thanks in advance!
Hello,
I also have a second question (see #5 for the first one).
When I run your code on ba_2motifs, I obtain the following graph:
Fortunately, the validation set behaves exactly like the test set: when the validation accuracy is high, the test accuracy is high too, hence the good results.
Is it expected to be this unstable? What could I do to avoid that?
Hello,
First of all, thank you for your paper and your code, it is a pleasure to work with it.
However, I have a question about the following line (line 79 in commit ea900db):
edge_att = (att + att) / 2
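As written, the line averages att with itself and returns att unchanged. One possible intent, which is a guess and not stated in the issue, is a symmetric attention score for undirected graphs where each edge is stored in both directions. In that case each edge's score would be averaged with the score of its reverse edge. Below is a stdlib sketch with a hypothetical helper symmetrize_edge_att and edges given as (src, dst) pairs.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def symmetrize_edge_att(edges: List[Edge], att: List[float]) -> List[float]:
    """Average each directed edge's attention with that of its reverse edge.

    Assumes an undirected graph stored with both (u, v) and (v, u).
    """
    position: Dict[Edge, int] = {e: i for i, e in enumerate(edges)}
    out = []
    for (u, v), a in zip(edges, att):
        reverse_a = att[position[(v, u)]]  # KeyError if the reverse edge is absent
        out.append((a + reverse_a) / 2)
    return out


edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
att = [0.9, 0.5, 0.2, 0.4]

symmetric = symmetrize_edge_att(edges, att)  # (0, 1) and (1, 0) now share one value
unchanged = [(a + a) / 2 for a in att]       # averaging att with itself is a no-op
print(symmetric)
print(unchanged == att)
```

A tensor implementation would look up reverse edges with an index permutation instead of a dict. The dict keeps this sketch dependency-free.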
When I run the last cell:
all_viz_set = get_viz_idx(test_set, dataset_name, num_viz_samples)
visualize_results(gsat, all_viz_set, test_set, num_viz_samples, dataset_name, model_config['use_edge_attr'])
0%| | 0/10 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-4b79db204498> in <cell line: 6>()
4 all_viz_set = get_viz_idx(test_set, dataset_name, num_viz_samples)
5
----> 6 visualize_results(gsat, all_viz_set, test_set, num_viz_samples, dataset_name, model_config['use_edge_attr'])
1 frames
/usr/local/lib/python3.10/dist-packages/torch_geometric/utils/subgraph.py in subgraph(subset, edge_index, edge_attr, relabel_nodes, num_nodes, return_edge_mask)
97 edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
98 edge_index = edge_index[:, edge_mask].to('cpu')
---> 99 edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
100
101 if relabel_nodes:
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)