
lirongwu / graphmixup


Code for ECML-PKDD 2022 paper "GraphMixup: Improving Class-Imbalanced Node Classification by Reinforcement Mixup and Self-supervised Context Prediction"

License: MIT License



GraphMixup

This is a PyTorch implementation of GraphMixup. The code includes the following modules:

  • Dataset Loader (Cora, BlogCatalog, and Wiki-CS)

  • Various Architectures (GCN, SAGE, GAT, and SEM)

  • Five baselines for comparison (Origin, Over-Sampling, Re-weight, SMOTE, and Embed-SMOTE)

  • Training paradigm (joint learning, pre-training, and fine-tuning) for node classification on three datasets

  • Visualization and evaluation metrics

Main Requirements

  • networkx==2.5
  • numpy==1.19.2
  • scikit-learn==0.24.1
  • scipy==1.5.2
  • torch==1.6.0
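
As a convenience, the pinned versions above can be compared against the current environment. The snippet below is only a sketch and is not part of the repository: the helper name `check_requirements` is hypothetical, and it assumes Python 3.8 or newer so that `importlib.metadata` is available.

```python
from importlib import metadata

# Pinned versions copied from the requirements list above.
PINNED = {
    "networkx": "2.5",
    "numpy": "1.19.2",
    "scikit-learn": "0.24.1",
    "scipy": "1.5.2",
    "torch": "1.6.0",
}


def check_requirements(pinned=PINNED):
    """Return {package: (expected, installed_or_None)} for every mismatch.

    Hypothetical helper, not shipped with the repository.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Some wheels carry a local suffix such as "+cu101"; ignore it.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_requirements().items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```

An empty result means every listed package is installed at the pinned version.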

Description

  • train.py

    • train() -- Train a new model for the node classification task on the Cora, BlogCatalog, and Wiki-CS datasets
    • test() -- Test the learned model for the node classification task on the Cora, BlogCatalog, and Wiki-CS datasets
    • save_model() -- Save the pre-trained model
    • load_model() -- Load model for fine-tuning
  • data_load.py

    • load_cora() -- Load Cora Dataset
    • load_BlogCatalog() -- Load BlogCatalog Dataset
    • load_wiki_cs() -- Load Wiki-CS Dataset
  • models.py

    • GraphConvolution() -- GCN Layer
    • SageConv() -- SAGE Layer
    • SemanticLayer() -- Semantic Feature Layer
    • GraphAttentionLayer() -- GAT Layer
    • PairwiseDistance() -- Perform self-supervised Local-Path Prediction
    • DistanceCluster() -- Perform self-supervised Global-Path Prediction
  • utils.py

    • src_upsample() -- Perform interpolation in the input space
    • src_smote() -- Perform interpolation in the embedding space
    • mixup() -- Perform mixup in the semantic relation space
  • QLearning.py

    • GNN_env() -- Calculate rewards, perform actions, and update states
    • isTerminal() -- Determine whether the termination conditions have been met
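
To make the role of an environment with rewards, actions, states, and a termination check more concrete, here is a toy tabular Q-learning loop. It is a sketch under stated assumptions and not the code in QLearning.py: `ToyScaleEnv`, its `score` function, and every hyperparameter are hypothetical. The state is an index into a grid of candidate upsampling scales, the two actions move one step down or up, and a made-up score stands in for validation performance of the GNN.

```python
import random


class ToyScaleEnv:
    """Hypothetical stand-in for an environment such as GNN_env."""

    def __init__(self, n_scales=11, best=7):
        self.n_scales = n_scales  # number of candidate scales (grid size)
        self.best = best          # index of the scale with the best toy score
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def score(self, state):
        # Placeholder for "train the model and measure validation accuracy".
        return -abs(state - self.best)

    def is_terminal(self):
        # Toy termination condition: the best scale has been reached.
        return self.state == self.best

    def step(self, action):
        before = self.score(self.state)
        move = 1 if action == 1 else -1  # action 1 = scale up, 0 = scale down
        self.state = min(max(self.state + move, 0), self.n_scales - 1)
        # Reward +1 if the score improved, -1 otherwise.
        reward = 1.0 if self.score(self.state) > before else -1.0
        return self.state, reward, self.is_terminal()


def q_learning(env, episodes=200, max_steps=20, alpha=0.5, gamma=0.9,
               epsilon=0.1, seed=0):
    """Tabular Q-learning with an epsilon-greedy behaviour policy."""
    rng = random.Random(seed)
    q = [[0.0, 0.0] for _ in range(env.n_scales)]
    for _ in range(episodes):
        state = env.reset()
        for _step in range(max_steps):
            if rng.random() < epsilon:
                action = rng.randrange(2)                        # explore
            else:
                action = max((0, 1), key=lambda a: q[state][a])  # exploit
            nxt, reward, done = env.step(action)
            # Bootstrap from the next state unless the episode has ended.
            target = reward + (0.0 if done else gamma * max(q[nxt]))
            q[state][action] += alpha * (target - q[state][action])
            state = nxt
            if done:
                break
    return q
```

After training, the greedy action in every state below `best` should be "scale up", which is the behaviour one would hope for from an agent that tunes a single scalar.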

Running the code

  1. Install the required dependency packages

  2. To get results on a specific dataset, first run pre-training with suitable hyperparameters:

python train.py --dataset data_name --setting pre-train

where data_name is one of the three datasets (Cora, BlogCatalog, and Wiki-CS). The pre-trained model is saved to the corresponding folder under ./checkpoint for evaluation.

  3. To fine-tune the pre-trained model, run:
python train.py --dataset data_name --setting fine-tune --load model_path

where model_path is the path to the saved pre-trained model.
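
For readers who want to see how flags like these fit together, the following is a minimal argparse sketch. It is not the parser in train.py: the function names, defaults, and the rule that fine-tuning requires `--load` are assumptions made for illustration. Only the flag names and the `--setting` values come from this README (including the baseline settings listed further down).

```python
import argparse

# Values of '--setting' that are named in this README.
SETTINGS = ["pre-train", "fine-tune", "raw", "over-sampling",
            "re-weight", "smote", "embed_smote"]


def build_parser():
    """Build an illustrative parser; defaults here are assumptions."""
    parser = argparse.ArgumentParser(description="Illustrative GraphMixup-style CLI")
    parser.add_argument("--dataset", required=True, help="dataset name, e.g. cora")
    parser.add_argument("--setting", choices=SETTINGS, default="raw")
    parser.add_argument("--load", default=None,
                        help="path to a pre-trained checkpoint")
    return parser


def parse(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    # Assumed rule: fine-tuning needs a checkpoint to start from.
    if args.setting == "fine-tune" and args.load is None:
        parser.error("--setting fine-tune requires --load model_path")
    return args
```

With this sketch, `parse(["--dataset", "cora", "--setting", "fine-tune"])` exits with a usage error because no checkpoint path was given.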

  4. We provide five baselines for comparison in this code. They can be configured via the '--setting' argument:
  • Origin: Vanilla backbone models with '--setting raw'
  • Over-Sampling: Repeat nodes in the minority classes with '--setting over-sampling'
  • Re-weight: Give samples from minority classes a larger weight when calculating the loss with '--setting re-weight'
  • SMOTE: Interpolation in the input space with '--setting smote'
  • Embed-SMOTE: Perform SMOTE in the intermediate embedding space with '--setting embed_smote'

Taking Embed-SMOTE as an example:

python train.py --dataset cora --setting embed_smote
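
The interpolation idea shared by the SMOTE and Embed-SMOTE baselines can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the repository's `src_smote()`: the function name and signature are hypothetical, and it only generates feature vectors, leaving aside how edges for the synthetic nodes would be created.

```python
import numpy as np


def smote_interpolate(x, labels, minority_class, n_new, seed=0):
    """Create n_new synthetic rows for minority_class (hypothetical helper).

    Each synthetic row lies on the segment between a randomly chosen
    minority sample and its nearest neighbour from the same class.
    """
    rng = np.random.default_rng(seed)
    minority = x[labels == minority_class]
    if len(minority) < 2:
        raise ValueError("need at least two minority samples to interpolate")
    # Pairwise Euclidean distances within the minority class.
    dist = np.linalg.norm(minority[:, None, :] - minority[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)   # a sample is not its own neighbour
    nearest = dist.argmin(axis=1)    # nearest same-class neighbour per sample
    anchors = rng.integers(0, len(minority), size=n_new)
    lam = rng.random((n_new, 1))     # interpolation weight in [0, 1)
    return minority[anchors] + lam * (minority[nearest[anchors]] - minority[anchors])
```

Passing raw node features as `x` corresponds to interpolation in the input space, while passing hidden representations from an encoder corresponds to interpolation in the embedding space.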

Citation

If you find this project useful for your research, please cite it using the following BibTeX entry.

@inproceedings{wu2023graphmixup,
  title={Graphmixup: Improving class-imbalanced node classification by reinforcement mixup and self-supervised context prediction},
  author={Wu, Lirong and Xia, Jun and Gao, Zhangyang and Lin, Haitao and Tan, Cheng and Li, Stan Z},
  booktitle={Machine Learning and Knowledge Discovery in Databases: European Conference, ECML PKDD 2022, Grenoble, France, September 19--23, 2022, Proceedings, Part IV},
  pages={519--535},
  year={2023},
  organization={Springer}
}


graphmixup's Issues

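About the GAT results

Hello, while running comparison experiments I found that the results are very poor when GAT is used as the encoder. What could be the reason?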
