Forked from yang-han/p-reg.

Home Page: https://arxiv.org/abs/2009.02027

Rethinking Graph Regularization for Graph Neural Networks

This is the source code to reproduce the experimental results for Rethinking Graph Regularization for Graph Neural Networks.

The code for graph-level experiments is in the ./graph_level/ sub-folder.

Dependencies

python==3.7.6
pytorch==1.5.0
pytorch_geometric==1.4.3
numpy==1.18.1

Code Description

main.py

The entry point. It loads the datasets and models, then trains and evaluates the model.

conv.py

IConv is modified from torch_geometric.nn.GCNConv to implement the propagation of the output, i.e., $\hat{A}Z$ in the paper.

Compared to the original GCNConv, IConv removes the bias term and replaces the trainable weight matrix with a fixed (untrainable) identity matrix.
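With the weight matrix fixed to the identity and no bias, the layer reduces to $\hat{A}Z$ with $\hat{A} = \hat{D}^{-1/2}(A+I)\hat{D}^{-1/2}$. A dense NumPy sketch for intuition (IConv implements this sparsely via message passing):

```python
import numpy as np

def propagate(A, Z):
    """Compute A_hat @ Z, where A_hat = D^-1/2 (A + I) D^-1/2.

    A: (n, n) binary adjacency matrix; Z: (n, c) node outputs.
    Dense, illustrative version of what IConv does with sparse
    message passing (no bias, identity weight matrix).
    """
    A_tilde = A + np.eye(A.shape[0])           # add self-loops
    d = A_tilde.sum(axis=1)                    # degrees of A + I
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))     # D^-1/2
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt  # normalized adjacency
    return A_hat @ Z
```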

models.py

GCN, GAT and MLP are implemented in a standard way and provided in this file.

PREGGCN, PREGGAT and PREGMLP have an additional method propagation(), which further propagates the output of the vanilla GCN, GAT and MLP models.

A typical Propagation-regularization can be computed as:

```python
soft_cross_entropy(
    F.softmax(
        model.propagation(data.x, data.edge_index),
        dim=1
    ),
    F.softmax(
        model(data.x, data.edge_index),
        dim=1
    )
)
```

phi.py

soft_cross_entropy(), kl_div() and squared_error() are provided in phi.py as different $\phi$ functions.
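For intuition, here are illustrative NumPy versions of the three $\phi$ functions; the actual implementations in phi.py operate on PyTorch tensors, and the exact reductions there may differ:

```python
import numpy as np

def soft_cross_entropy(pred, target):
    """Cross-entropy with a soft (probability) target, averaged over nodes."""
    return -np.mean(np.sum(target * np.log(pred), axis=1))

def kl_div(pred, target):
    """KL divergence D(target || pred), averaged over nodes."""
    return np.mean(np.sum(target * (np.log(target) - np.log(pred)), axis=1))

def squared_error(pred, target):
    """Squared difference between the two distributions, averaged over nodes."""
    return np.mean(np.sum((pred - target) ** 2, axis=1))
```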

loss.py

LabelSmoothingLoss, confidence_penalty and laplacian_reg are provided in loss.py as baselines.

utils.py

Some useful functions are implemented in this file.

generate_split() is used to generate the random splits for each dataset. The splits used in our experiments are provided in the ./splits/ folder.

Mask is the structure in which the random splits are stored.

load_dataset() and load_split() are provided to load the datasets and random splits.
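Conceptually, a split is three disjoint boolean masks over the nodes. A hypothetical sketch of what generate_split() might produce (the actual Mask structure in utils.py may differ):

```python
import numpy as np

def generate_split(num_nodes, num_train, num_val, seed=0):
    """Randomly partition nodes into train/val/test boolean masks."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)
    train_mask = np.zeros(num_nodes, dtype=bool)
    val_mask = np.zeros(num_nodes, dtype=bool)
    test_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[perm[:num_train]] = True
    val_mask[perm[num_train:num_train + num_val]] = True
    test_mask[perm[num_train + num_val:]] = True
    return train_mask, val_mask, test_mask
```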

Reproducing Experimental Results

Random splits (in Table 1)

Passing --num_splits 5 to main.py means using the first 5 randomly generated splits provided in the ./splits/ folder. Set --mu 0 to use the vanilla models without P-reg.

models: ['PREGGCN', 'PREGGAT', 'PREGMLP']
datasets: ['cora', 'citeseer', 'pubmed', 'cs', 'physics', 'computers', 'photo']

The command to train and evaluate a model is:

```shell
python main.py --dataset $dataset --model $model --mu $mu --num_seeds $num_seeds --num_splits $num_splits
```

For example, to run GCN+P-reg (mu=0.5) on the CORA dataset with 5 random splits and 5 seeds per split:

```shell
python main.py --dataset cora --model preggcn --mu 0.5 --num_seeds 5 --num_splits 5
```

For complete commands to run all experiments, please refer to random_run.sh.
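A full sweep amounts to looping the command above over all models and datasets. A hypothetical sketch of such a loop (the mu values in random_run.sh are tuned per dataset, so the fixed 0.5 here is only illustrative; `echo` prints the commands rather than running them):

```shell
for model in preggcn preggat pregmlp; do
  for dataset in cora citeseer pubmed cs physics computers photo; do
    echo python main.py --dataset "$dataset" --model "$model" \
      --mu 0.5 --num_seeds 5 --num_splits 5
  done
done
```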

Planetoid standard split (in Table 2)

Passing --num_splits 1 to main.py means using the standard split of the Planetoid datasets. Set --mu 0 to use the vanilla models without P-reg.

models: ['PREGGCN', 'PREGGAT']
datasets: ['cora', 'citeseer', 'pubmed']

Commands to reproduce experimental results on CORA, CiteSeer and PubMed datasets:

```shell
# CORA GAT+P-reg mu=0.45 standard split 10 seeds
python main.py --num_splits 1 --num_seeds 10 --dataset cora --model preggat --mu 0.45
# CiteSeer GCN+P-reg mu=0.35 standard split 10 seeds
python main.py --num_splits 1 --num_seeds 10 --dataset citeseer --model preggcn --mu 0.35
# PubMed GCN+P-reg mu=0.15 standard split 10 seeds
python main.py --num_splits 1 --num_seeds 10 --dataset pubmed --model preggcn --mu 0.15
```

Tips

  1. --model PREGGCN --mu 0 means using the vanilla GCN model. (Similarly, set --mu 0 to use the vanilla GAT or MLP.)
  2. --num_splits 1 means using the standard split provided with the Planetoid datasets (CORA, CiteSeer and PubMed), while --num_splits 5 means using the first 5 randomly generated splits provided in the ./splits/ folder (available for all 7 datasets).
  3. In main.py, replace soft_cross_entropy with kl_div or squared_error (provided in phi.py) to experiment with different $\phi$ functions.
  4. In main.py, replace nll_loss with LabelSmoothingLoss (provided in loss.py) to experiment with Label Smoothing. Add confidence_penalty or laplacian_reg (provided in loss.py) to the original loss term to experiment with the Confidence Penalty or the Laplacian Regularizer.
  5. In our experiments, we use hidden_size=64 for GCN and MLP, and hidden_size=16 for GAT.
  6. In our experiments, we use weight_decay=5e-4 for CORA, CiteSeer and PubMed, and weight_decay=0 for CS, Physics, Computers and Photo. This is determined by the vanilla model performance.
  7. By default, training stops once validation accuracy has not increased for 200 epochs (patience=200).
  8. The code of other state-of-the-art methods is taken either from the corresponding official repositories or from the pytorch-geometric benchmarking code. Details are attached below.
  9. The code for graph-level experiments is in the ./graph_level/ folder.
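The patience rule in tip 7 can be sketched as follows (illustrative only, not the repo's exact training loop; `evaluate_epoch` is a hypothetical callback that trains one epoch and returns validation accuracy):

```python
def train_with_patience(evaluate_epoch, patience=200, max_epochs=10000):
    """Stop once validation accuracy has not improved for `patience` epochs.

    evaluate_epoch(epoch) -> validation accuracy after training that epoch.
    Returns (best_accuracy, epochs_run).
    """
    best_acc, best_epoch = -1.0, 0
    for epoch in range(max_epochs):
        acc = evaluate_epoch(epoch)
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch   # new best: reset the clock
        elif epoch - best_epoch >= patience:
            break                               # no improvement for `patience` epochs
    return best_acc, epoch + 1
```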

Citation

@article{yang2020rethinking,
    title={Rethinking Graph Regularization For Graph Neural Networks},
    author={Han Yang and Kaili Ma and James Cheng},
    journal={arXiv preprint arXiv:2009.02027},
    year={2020}
}
