
Predicting multigraph brain population from a single graph



MultigraphGNet

MultigraphGNet generates or augments multi-view brain graphs using only a single-view brain graph.

Please contact [email protected] for inquiries. Thanks.

MultigraphGNet architecture

Introduction

This work is accepted at the PRIME-MICCAI 2022 workshop, Singapore.

Predicting Brain Multigraph Population From a Single Graph Template for Boosting One-Shot Classification

Furkan Pala and Islem Rekik

BASIRA Lab, Faculty of Computer and Informatics, Istanbul Technical University, Istanbul, Turkey

Abstract: A central challenge in training one-shot learning models is the limited representativeness of the available shots of the data space. Particularly in the field of network neuroscience where the brain is represented as a graph, such models may lead to low performance when classifying brain states (e.g., typical vs. autistic). To cope with this, most of the existing works involve a data augmentation step to increase the size of the training set, its diversity and representativeness. Though effective, such augmentation methods are limited to generating samples with the same size as the input shots (e.g., generating brain connectivity matrices from a single shot matrix). To the best of our knowledge, the problem of generating brain multigraphs capturing multiple types of connectivity between pairs of nodes (i.e., anatomical regions) from a single brain graph remains unsolved. In this paper, we unprecedentedly propose a hybrid graph neural network (GNN) architecture, namely Multigraph Generator Network or briefly MultigraphGNet, comprising two subnetworks: (1) a many-to-one GNN which integrates an input population of brain multigraphs into a single template graph, namely a connectional brain template (CBT), and (2) a reverse one-to-many U-Net network which takes the learned CBT in each training step and outputs the reconstructed input multigraph population. Both networks are trained in an end-to-end way using a cyclic loss. Experimental results demonstrate that our MultigraphGNet boosts the performance of an independent classifier when trained on the augmented brain multigraphs in comparison with training on a single CBT from each class. We hope that our framework can shed some light on the future research of multigraph augmentation from a single graph.
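To make the cycle described in the abstract concrete, here is a minimal NumPy sketch of the data flow and of a cyclic reconstruction loss. It is not the authors' implementation: toy_dgn and toy_rdgn are hypothetical stand-ins (an element-wise mean over views and a plain copy) for the learned many-to-one GNN and one-to-many U-Net, and the L1 form of the loss is an assumption.

```python
import numpy as np


def toy_dgn(multigraph: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the many-to-one GNN.

    Collapses one subject of shape (N_ROI, N_ROI, N_VIEWS) into a single
    (N_ROI, N_ROI) template by averaging over views. The real network is learned.
    """
    return multigraph.mean(axis=2)


def toy_rdgn(cbt: np.ndarray, n_views: int) -> np.ndarray:
    """Hypothetical stand-in for the one-to-many U-Net.

    Expands an (N_ROI, N_ROI) template back to (N_ROI, N_ROI, n_views) by
    copying it into every view. The real network is learned.
    """
    return np.repeat(cbt[:, :, np.newaxis], n_views, axis=2)


def cyclic_l1_loss(population: np.ndarray) -> float:
    """Assumed cyclic loss: mean absolute error between each input multigraph
    and its reconstruction toy_rdgn(toy_dgn(x)), averaged over a population
    of shape (N_SUBJECTS, N_ROI, N_ROI, N_VIEWS)."""
    n_views = population.shape[-1]
    errors = [np.abs(x - toy_rdgn(toy_dgn(x), n_views)).mean() for x in population]
    return float(np.mean(errors))


# One subject with two 2x2 views: all zeros and all twos.
population = np.zeros((1, 2, 2, 2))
population[..., 1] = 2.0
print(cyclic_l1_loss(population))  # 1.0
```

Because a single template discards view-specific detail, simply copying it back only reaches zero loss when all views are identical; recovering the differences between views is the job of the learned one-to-many network.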

Code

This project was implemented using Python and PyTorch.

Installation

Clone this repo. Then, create a virtual environment using the following command:

python3 -m venv MultigraphGNet

Activate the virtual environment using the following command (Linux and macOS):

source MultigraphGNet/bin/activate

Activate the virtual environment using the following command (Windows):

MultigraphGNet\Scripts\activate.bat

Use pip to install the dependencies. For a CPU-only installation, use the following command:

pip3 install torch scipy scikit-learn matplotlib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html

For a CUDA installation, use the following commands, which target CUDA 11.3. To install for a different CUDA version, please check the PyTorch Geometric documentation and the PyTorch website.

pip3 install scipy scikit-learn matplotlib
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
pip3 install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
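To confirm that the installation worked, a small helper script along these lines can be run inside the activated environment. It is a hypothetical check, not part of the repository; it only maps the pip package names above to their import names (e.g. scikit-learn is imported as sklearn, torch-scatter as torch_scatter).

```python
import importlib.util

# Import names of the packages installed above.
REQUIRED = [
    "torch", "scipy", "sklearn", "matplotlib",
    "torch_scatter", "torch_sparse", "torch_cluster",
    "torch_spline_conv", "torch_geometric",
]


def missing_packages(names):
    """Return the import names that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("Missing:", missing or "none")
    if "torch" not in missing:
        import torch
        # True only for a CUDA build of PyTorch on a machine with a usable GPU.
        print("CUDA available:", torch.cuda.is_available())
```

Using importlib.util.find_spec checks availability without importing every package, so the script stays fast even when PyTorch Geometric's extensions are large.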

Dataset

We represent the data for a subject as a three-dimensional tensor of shape (N_ROI, N_ROI, N_VIEWS), where N_ROI is the number of regions of interest (ROIs) in the brain and N_VIEWS is the number of views, each capturing a different type of relation between the ROIs.
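As an illustration of this layout, the sketch below stacks per-view connectivity matrices along the last axis and checks the result. The helper names are hypothetical, and the requirement that every view be symmetric (an undirected brain graph) is an assumption rather than something the repository states.

```python
from typing import Sequence

import numpy as np


def stack_views(views: Sequence[np.ndarray]) -> np.ndarray:
    """Stack N_VIEWS (N_ROI, N_ROI) connectivity matrices into one subject
    tensor of shape (N_ROI, N_ROI, N_VIEWS); view v is subject[:, :, v]."""
    return np.stack(views, axis=-1)


def is_valid_subject(subject: np.ndarray) -> bool:
    """Check the assumed layout: a 3-D tensor whose first two axes are equal
    and whose views are all symmetric (undirected connectivity)."""
    if subject.ndim != 3 or subject.shape[0] != subject.shape[1]:
        return False
    return all(
        np.allclose(subject[:, :, v], subject[:, :, v].T)
        for v in range(subject.shape[2])
    )


a = np.array([[0.0, 0.5], [0.5, 0.0]])  # view 1 for 2 ROIs
b = np.array([[0.0, 0.9], [0.9, 0.0]])  # view 2 for 2 ROIs
subject = stack_views([a, b])
print(subject.shape)              # (2, 2, 2)
print(is_valid_subject(subject))  # True
```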

To create a simulated dataset, you can use the create_simulated_data.py script.
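The exact generation procedure lives in create_simulated_data.py; purely as an illustration, a two-class population in the layout above could be simulated as follows. The half-normal edge weights, the class means, the function name and all sizes are assumptions chosen for the example.

```python
import numpy as np


def simulate_class(n_subjects, n_roi, n_views, mean, std, rng):
    """Simulate one class as an array of shape (n_subjects, n_roi, n_roi, n_views).

    Edge weights are |Normal(mean, std)| samples (an assumption); every view is
    then symmetrized and its diagonal set to zero.
    """
    x = np.abs(rng.normal(mean, std, size=(n_subjects, n_roi, n_roi, n_views)))
    x = (x + x.transpose(0, 2, 1, 3)) / 2  # swap the two ROI axes and average
    diag = np.arange(n_roi)
    x[:, diag, diag, :] = 0.0              # no self-connections
    return x


rng = np.random.default_rng(0)
class1 = simulate_class(20, 35, 4, mean=0.5, std=0.1, rng=rng)  # hypothetical sizes
class2 = simulate_class(20, 35, 4, mean=0.7, std=0.1, rng=rng)
print(class1.shape)  # (20, 35, 35, 4)
```

Each array could then be written with np.save and referenced from the dataset paths in config.py.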

Config

In the config.py file, you can change the number of samples and regions of interest in the simulated dataset. You can also train on your own dataset by changing the paths in the DatasetClass1 and DatasetClass2 classes. However, if your dataset is stored in a format other than NumPy arrays (.npy), you need to write your own function to read it.
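For example, if each subject were stored as one CSV file per view (a hypothetical layout), a reader along these lines could convert it to the expected NumPy layout. The function names and file layout are illustrative only; how the loaded arrays are wired into DatasetClass1 and DatasetClass2 depends on config.py.

```python
import io

import numpy as np


def load_subject_from_csv(view_files):
    """Read one subject whose views are stored as separate comma-separated
    (N_ROI x N_ROI) matrices and return an (N_ROI, N_ROI, N_VIEWS) array.

    view_files may contain file paths or file-like objects, one per view.
    """
    views = [np.loadtxt(f, delimiter=",", ndmin=2) for f in view_files]
    shapes = {v.shape for v in views}
    if len(shapes) != 1:
        raise ValueError(f"views have different shapes: {shapes}")
    return np.stack(views, axis=-1)


def load_population(subjects):
    """Stack several subjects into (N_SUBJECTS, N_ROI, N_ROI, N_VIEWS)."""
    return np.stack([load_subject_from_csv(files) for files in subjects])


# In-memory stand-ins for the two view files of a 2-ROI subject.
view1 = io.StringIO("0,0.5\n0.5,0")
view2 = io.StringIO("0,0.9\n0.9,0")
subject = load_subject_from_csv([view1, view2])
print(subject.shape)  # (2, 2, 2)
```

Checking that all views share one shape catches a mislabelled or truncated file early, before it surfaces as a confusing error inside training.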

Running the MultigraphGNet code

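First, run the train_dgn_rdgn.py script to train the DGN (the many-to-one network that produces the CBT) and the RDGN (the reverse one-to-many network that reconstructs the views) using K-fold cross-validation. You can specify the number of folds and the random seed used for the train-test split.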
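The role of the seed can be seen in a minimal seeded K-fold split, sketched below with NumPy. This is a generic illustration, not the code in train_dgn_rdgn.py; scikit-learn, already among the dependencies, offers an equivalent split through sklearn.model_selection.KFold(n_splits=..., shuffle=True, random_state=...), although the exact permutation differs.

```python
import numpy as np


def kfold_indices(n_samples, n_folds, seed):
    """Yield (train_idx, test_idx) for K-fold cross-validation.

    Samples are shuffled once with a seeded generator, so the same seed always
    gives the same split; each sample appears in exactly one test fold.
    """
    perm = np.random.default_rng(seed).permutation(n_samples)
    folds = np.array_split(perm, n_folds)  # fold sizes differ by at most one
    for k in range(n_folds):
        train_idx = np.concatenate([folds[j] for j in range(n_folds) if j != k])
        yield train_idx, folds[k]


for fold, (train_idx, test_idx) in enumerate(kfold_indices(12, n_folds=5, seed=0)):
    print(fold, len(train_idx), len(test_idx))
```

With 12 samples and 5 folds, the test folds have sizes 3, 3, 2, 2 and 2. Changing the seed changes which subjects land in each fold but not these sizes.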

Once training is done, you can find the test results in the dgn_rdgn_results/ directory, along with visualizations of the reconstructed views for each fold.

Then, run the train_classifier.py script. To evaluate the effectiveness of our framework, it trains two independent SVM classifiers on:

  • one global CBT from each class (one-shot CBT baseline);
  • samples augmented by our trained RDGN network.
    • We augment k samples; you can set k by changing config.K.
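The README does not say how multigraphs are turned into SVM inputs. A common choice, assumed here, is to keep the strict upper triangle of each view and concatenate the views into one feature vector. The helper names are hypothetical, and the random arrays below merely stand in for the k samples per class produced by the RDGN.

```python
import numpy as np


def vectorize(multigraph):
    """Flatten one (N_ROI, N_ROI, N_VIEWS) multigraph into a 1-D feature vector
    made of the strict upper triangle of each view, view after view."""
    rows, cols = np.triu_indices(multigraph.shape[0], k=1)
    return multigraph[rows, cols, :].T.reshape(-1)


def build_training_set(class1_samples, class2_samples):
    """Turn the augmented samples of each class into an (X, y) pair with
    labels 0 (class 1) and 1 (class 2)."""
    x1 = np.stack([vectorize(s) for s in class1_samples])
    x2 = np.stack([vectorize(s) for s in class2_samples])
    X = np.concatenate([x1, x2])
    y = np.concatenate([np.zeros(len(x1), dtype=int), np.ones(len(x2), dtype=int)])
    return X, y


k, n_roi, n_views = 10, 35, 4  # hypothetical values
rng = np.random.default_rng(0)
X, y = build_training_set(rng.random((k, n_roi, n_roi, n_views)),
                          rng.random((k, n_roi, n_roi, n_views)))
print(X.shape, y.shape)  # (20, 2380) (20,)
```

X and y could then be passed to an SVM such as scikit-learn's sklearn.svm.SVC. Keeping only the upper triangle avoids counting each undirected edge twice and drops the diagonal.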

You can find the classification results in the classifier_results/ directory.

Results on real data

Visual inspection

We select a random test subject, which is a multi-view brain graph, and pass it through the trained DGN network to obtain a subject-specific CBT. Then, we feed this CBT to the trained RDGN network to reconstruct the original multiple views. In the figure above, we demonstrate this process as well as the mean absolute errors between the original and reconstructed views.
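The per-view error reported in the figure can be computed along these lines. Averaging over all N_ROI x N_ROI entries (including the diagonal) and the helper name are assumptions, not details taken from the repository.

```python
import numpy as np


def per_view_mae(original, reconstructed):
    """Mean absolute error for each view between two tensors of shape
    (N_ROI, N_ROI, N_VIEWS); returns an array of length N_VIEWS."""
    if original.shape != reconstructed.shape:
        raise ValueError("original and reconstructed tensors must have the same shape")
    return np.abs(original - reconstructed).mean(axis=(0, 1))


original = np.zeros((2, 2, 2))
reconstructed = original.copy()
reconstructed[:, :, 0] = 0.5  # only the first view is off
print(per_view_mae(original, reconstructed).tolist())  # [0.5, 0.0]
```

Reporting one error per view, rather than a single pooled number, shows which types of connectivity the reconstruction recovers less well.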

Classification results

In the figure above, we present the classification results on the real test set for two approaches:

  • One-shot CBT where we train two independent SVM classifiers using the global CBTs from each class generated by the trained DGN
  • Samples augmented using the trained RDGN network.

Results are averaged over the test set, 5 folds, 5 seeds and k = 10, 25 and 50 (the number of augmented samples).
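The averaging over folds, seeds and k can be sketched with a small hypothetical helper. It assumes each run has already been reduced to one test-set score, keyed by (fold, seed, k); the scores themselves come from your own runs.

```python
from collections import defaultdict
from statistics import mean


def summarize(results):
    """Average test-set scores over folds and seeds.

    results maps (fold, seed, k) -> score on the test set. Returns the overall
    mean and a dict with the mean score for each k.
    """
    per_k = defaultdict(list)
    for (_fold, _seed, k), score in results.items():
        per_k[k].append(score)
    overall = mean(results.values())
    return overall, {k: mean(scores) for k, scores in sorted(per_k.items())}


# The run grid described in the text: 5 folds x 5 seeds x k in {10, 25, 50}.
runs = [(fold, seed, k) for fold in range(5) for seed in range(5) for k in (10, 25, 50)]
print(len(runs))  # 75
```

Because every k has the same number of runs in this grid, the overall mean equals the mean of the per-k means.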

YouTube videos to install and run the code and understand how MultigraphGNet works

To install and run MultigraphGNet, check the following YouTube video:

https://www.youtube.com/watch?v=iNNFNlML_CU&ab_channel=BASIRALab

To learn about how MultigraphGNet works, check the following YouTube video:

https://www.youtube.com/watch?v=LQZBVwo_iNU&ab_channel=BASIRALab

arXiv Link

You can access the published paper at https://arxiv.org/abs/2209.06005

Please cite the following paper when using MultigraphGNet:

  @inproceedings{pala2022,
    title={Predicting Brain Multigraph Population From a Single Graph Template for Boosting One-Shot Classification},
    author={Pala, Furkan and Rekik, Islem},
    booktitle={International Workshop on PRedictive Intelligence In MEdicine},
    year={2022},
    organization={Springer}
  }
