lsh0520 / rgcl

Rationale-aware Graph Contrastive Learning codebase

Python 99.91% Shell 0.09%
explainability generalization graph-contrastive-learning icml2022 interpretability invariant-learning graph-pretraining invariant-rationale-discovery

rgcl's Introduction

Let Invariant Discovery Inspire Graph Contrastive Learning

This is our PyTorch implementation for the paper:

Sihang Li, Xiang Wang*, An Zhang, Ying-Xin Wu, Xiangnan He and Tat-Seng Chua (2022). Let Invariant Rationale Discovery Inspire Graph Contrastive Learning. Paper on arXiv. In ICML'22, Baltimore, Maryland, USA, July 17-23, 2022.

Author: Sihang Li (sihang0520 at gmail.com)

Introduction

Without supervision signals, Rationale-aware Graph Contrastive Learning (RGCL) uses a rationale generator to reveal the features salient to graph instance discrimination as the rationale, and then creates rationale-aware views for contrastive learning. This rationale-aware pre-training scheme endows the backbone model with powerful representation ability, further facilitating fine-tuning on downstream tasks.
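As a rough illustration (a NumPy sketch, not the authors' exact implementation), the rationale-aware view construction can be pictured as: nodes with higher importance scores are more likely to be kept in the rationale view, and the remaining nodes form the complement view.

```python
import numpy as np

def rationale_views(node_scores, rho=0.8, rng=None):
    """Sketch: split a graph's nodes into a rationale view and a
    complement view based on per-node importance scores.

    node_scores : 1-D array of scores from a (hypothetical) rationale
                  generator; rho is the fraction of nodes kept.
    """
    rng = rng or np.random.default_rng(0)
    n = len(node_scores)
    keep = int(n * rho)
    p = node_scores + 1e-3                 # smooth so every node has mass
    p = p / p.sum()
    rationale = rng.choice(n, size=keep, replace=False, p=p)
    complement = np.setdiff1d(np.arange(n), rationale)
    return np.sort(rationale), complement

r, c = rationale_views(np.array([0.9, 0.8, 0.1, 0.05, 0.05]), rho=0.6)
```

The two views are then fed to the contrastive objective, so the model learns to discriminate instances from their salient substructures.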

Citation

If you use our code and datasets in your research, please cite:

@inproceedings{RGCL,
  author    = {Sihang Li and
               Xiang Wang and
               An Zhang and
               Xiangnan He and
               Tat-Seng Chua},
  title     = {Let Invariant Rationale Discovery Inspire Graph Contrastive Learning},
  booktitle = {{ICML}},
  year      = {2022}
}

Experiments

  • Transfer Learning on MoleculeNet datasets
  • Semi-supervised learning on Superpixel MNIST dataset
  • Unsupervised representation learning on TU datasets

Potential Issues

Some issues might occur due to version mismatches between package dependencies; please check your package versions if you hit errors.

Acknowledgements

The backbone implementation is based on https://github.com/Shen-Lab/GraphCL.

rgcl's People

Contributors

ljy0ustc, lsh0520


rgcl's Issues

Datasets

Hello, how can I get the datasets?

Why are the results of AD-GCL different from those in its original paper?

Also, why not adopt JOAO as a baseline? Since you cite AD-GCL, which itself compares against JOAO, you are presumably aware of it.

The original unsupervised-learning results of JOAO and AD-GCL (non-linear SVM) are attached below:

Data        NCI1   PROTEINS  DD     MUTAG  COLLAB  RDT-B  RDT-M5K  IMDB-B
JOAO        78.07  74.55     77.32  87.35  69.50   85.29  55.74    70.21
JOAOv2      78.36  74.07     77.40  87.67  69.33   86.42  56.03    70.83
AD-GCL-FIX  75.77  75.04     75.38  88.62  74.79   92.06  56.24    71.49
RGCL        78.14  75.03     78.86  87.66  70.92   90.34  56.38    71.85

About aug.py

Hello, first of all thank you very much for your excellent work.
I have some questions about the data augmentation part of the unsupervised learning code you provided.
The code only removes the dropped nodes from data.edge_index, not from the node attributes data.x.
I understand that this prevents them from passing messages to the non-dropped nodes, but I wonder whether it affects performance, since the dropped nodes still participate in pooling when the graph-level vector is generated.
I look forward to your answer.

```python
import numpy as np
import torch

def drop_nodes_prob(data, node_score, rho):
    # Rationale view: drop nodes with probability proportional to score.
    node_num, _ = data.x.size()
    drop_num = int(node_num * (1.0 - rho))

    node_prob = node_score.float()
    node_prob += 0.001                      # smoothing
    node_prob = np.array(node_prob)
    node_prob /= node_prob.sum()

    idx_nondrop = np.random.choice(node_num, node_num - drop_num,
                                   replace=False, p=node_prob)
    idx_drop = np.setdiff1d(np.arange(node_num), idx_nondrop)
    idx_nondrop.sort()

    # Note: only data.edge_index is modified; data.x keeps every node,
    # so dropped nodes are isolated rather than removed.
    edge_index = data.edge_index.numpy()
    adj = torch.zeros((node_num, node_num))
    adj[edge_index[0], edge_index[1]] = 1
    adj[idx_drop, :] = 0
    adj[:, idx_drop] = 0
    data.edge_index = adj.nonzero().t()

    return data

def drop_nodes_cp(data, node_score, rho):
    # Complement view: drop with probability proportional to (max - score).
    node_num, _ = data.x.size()
    drop_num = int(node_num * (1.0 - rho))

    node_prob = node_score.float().max() - node_score.float()
    node_prob += 0.001
    node_prob = np.array(node_prob)
    node_prob /= node_prob.sum()

    idx_nondrop = np.random.choice(node_num, node_num - drop_num,
                                   replace=False, p=node_prob)
    idx_drop = np.setdiff1d(np.arange(node_num), idx_nondrop)
    idx_nondrop.sort()

    edge_index = data.edge_index.numpy()
    adj = torch.zeros((node_num, node_num))
    adj[edge_index[0], edge_index[1]] = 1
    adj[idx_drop, :] = 0
    adj[:, idx_drop] = 0
    data.edge_index = adj.nonzero().t()

    return data
```
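Fully removing a dropped node, rather than only isolating it, would mean filtering data.x and relabelling edge endpoints as well. A minimal NumPy sketch of that alternative (an assumption about what one might do, not the repository's actual behaviour):

```python
import numpy as np

def drop_nodes_fully(x, edge_index, idx_drop):
    """Remove dropped nodes from both the (N, F) feature matrix and the
    (2, E) edge_index, relabelling survivors to 0..K-1. With this variant
    the dropped nodes can no longer contribute to graph-level pooling."""
    n = x.shape[0]
    idx_keep = np.setdiff1d(np.arange(n), idx_drop)
    remap = -np.ones(n, dtype=int)
    remap[idx_keep] = np.arange(len(idx_keep))
    # keep only edges whose endpoints both survive, then relabel them
    mask = np.isin(edge_index[0], idx_keep) & np.isin(edge_index[1], idx_keep)
    return x[idx_keep], remap[edge_index[:, mask]]
```

Whether this changes downstream performance relative to isolating the nodes is exactly the question raised above.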


Training with subgraph augmentation ?

Hi! Thank you for your work! However, on TUDatasets the code seems to fail when using subgraphs as augmentation:

Traceback (most recent call last):
  File "rgcl.py", line 172, in <module>
    for data in dataloader:
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/opt/conda/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user/.local/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 239, in __getitem__
    data = self.get(self.indices()[idx])
  File "/workspace/external_src/RGCL/unsupervised_TU/aug.py", line 251, in get
    assert False
AssertionError

In the get method in aug.py, only 'drop_ra' seems to be implemented. Is drop_ra equivalent to node dropping? Also, I'm not sure I understand the role of n = np.random.randint(2) and m = np.random.randint(2) in the get method. Thanks in advance for your answer!
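One plausible reading of those two draws (my guess, not confirmed by the authors) is that each randint(2) independently selects one of two augmentation branches, one draw per contrastive view:

```python
import numpy as np

# Hypothetical sketch: the branch names come from the aug.py snippet
# posted in this thread; the pairing with n and m is an assumption.
AUGS = ["drop_nodes_prob", "drop_nodes_cp"]

def pick_views(rng):
    n = rng.integers(2)   # augmentation branch for view 1
    m = rng.integers(2)   # augmentation branch for view 2 (independent)
    return AUGS[n], AUGS[m]
```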

About loss

Hi, thank you for your great work; it helps me a lot.
Regarding the loss function in the semi-supervised task, I found that during training it seems to differ from the one in the paper, and only part of it is used. Is my understanding correct, and what is the effect of doing this?
I also don't quite understand how the two subgraphs are sampled. How does the rationale score guide the sampling? Is the final contribution of each node a specific value? How are the probabilities of the rationale and non-rationale parts obtained from it, and how is the sampling done once the probabilities are known?
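For what it's worth, the aug.py snippet posted in another issue suggests an answer to the sampling question: node scores are smoothed and normalised into a sampling distribution, and the non-rationale (complement) view simply inverts the scores. A self-contained sketch of that conversion:

```python
import numpy as np

def sampling_probs(node_score, complement=False):
    """Turn per-node rationale scores into a sampling distribution,
    mirroring drop_nodes_prob / drop_nodes_cp in the posted aug.py."""
    s = np.asarray(node_score, dtype=float)
    if complement:
        s = s.max() - s       # non-rationale view inverts the scores
    s = s + 0.001             # smoothing, as in aug.py
    return s / s.sum()

p_rationale = sampling_probs([0.9, 0.5, 0.1])
p_complement = sampling_probs([0.9, 0.5, 0.1], complement=True)
# nodes are then drawn without replacement, e.g.
# np.random.choice(3, size=2, replace=False, p=p_rationale)
```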

Question about enviroment

Hi, thank you for the paper and the code! Could you share your environment, in particular the torch version information? The issues in the GraphCL repo did not help me fix this problem (old-version issues), so could you tell me the version information directly? Thanks a lot.

Details of the training of the rationale generator

Thank you for your enlightening work. I am curious about the training of the rationale generator, which does not seem to be described in detail in the paper. In Section 4.1, it appears the rationale generator is trained without supervision, but popular IRD training methods require labels. Could you explain how the rationale generator is trained unsupervised? Thank you.
