
hookk / gnns-easy-to-use


This project forked from zhao-tong/gnns-easy-to-use


A PyTorch implementation of graph neural networks (GCN, GraphSAGE, and GAT) that can be imported and used directly.


PyTorch GNNs

This package contains an easy-to-use PyTorch implementation of GCN, GraphSAGE, and Graph Attention Network (GAT). It can be imported and used much like logistic regression from sklearn. Two versions of the supervised GNNs are provided: one implemented purely in PyTorch, the other implemented with DGL and PyTorch.

Note: The unsupervised version is built upon our GraphSAGE-pytorch implementation, and the DGL version is built upon the examples given by DGL.

Authors of this code package:

Tong Zhao ([email protected]), Tianwen Jiang ([email protected]).

Important dependencies

  • python==3.6.8
  • pytorch==1.0.1.post2
  • dgl==0.4.2

Usage

Parameters (GNNs_unsupervised):

adj_matrix: scipy.sparse.csr_matrix
    The adjacency matrix of the graph, where nonzero entries indicate edges.
    The value of each nonzero entry gives the number of edges between the two nodes.

features: numpy.ndarray, optional
    A 2-D numpy array storing the raw features of the nodes, where the i-th row
    is the raw feature vector of node i.
    When raw features are not given, one-hot degree features are used.

labels: list or 1-D numpy.ndarray, optional
    The class label of each node. Used for supervised learning.

supervised: bool, optional, default False
    Whether to use supervised learning.

model: {'gat', 'graphsage'}, default 'gat'
    The GNN model to be used.
    - 'graphsage' is GraphSAGE: https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
    - 'gat' is graph attention network: https://arxiv.org/pdf/1710.10903.pdf

n_layer: int, optional, default 2
    Number of layers in the GNN

emb_size: int, optional, default 128
    Size of the node embeddings to be learnt

random_state: int, optional, default 1234
    Random seed

device: {'cpu', 'cuda', 'auto'}, default 'auto'
    The device to use.

epochs: int, optional, default 5
    Number of epochs for training

batch_size: int, optional, default 20
    Number of nodes per batch for training

lr: float, optional, default 0.7
    Learning rate

unsup_loss_type: {'margin', 'normal'}, default 'margin'
    Loss function to be used for unsupervised learning
    - 'margin' is a hinge loss with margin of 3
    - 'normal' is the unsupervised loss function described in the paper of GraphSAGE

print_progress: bool, optional, default True
    Whether to print the training progress
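For intuition, here is a minimal NumPy sketch of the two unsup_loss_type options for one anchor node. The 'normal' function follows the unsupervised objective from the GraphSAGE paper, J(z_u) = -log σ(z_uᵀz_v) - Q·E_{v_n∼P_n(v)} log σ(-z_uᵀz_{v_n}), with the expectation estimated by averaging over sampled negatives. The 'margin' function is an assumption: a standard hinge on dot-product scores with margin 3. The package's exact formulation is not documented here and may differ. The function names are hypothetical and not part of the package API.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))
    return -np.logaddexp(0.0, -x)


def graphsage_unsup_loss(z_u, z_pos, z_neg, Q=None):
    """GraphSAGE-style unsupervised loss for a single anchor node (hypothetical helper).

    z_u:   (d,) embedding of the anchor node u
    z_pos: (d,) embedding of a node v that co-occurs with u (e.g. on a random walk)
    z_neg: (k, d) embeddings of k negative samples drawn from P_n(v)
    Q:     number of negative samples; defaults to k
    """
    if Q is None:
        Q = z_neg.shape[0]
    # Pull the positive pair together
    pos_term = -log_sigmoid(z_u @ z_pos)
    # Push negatives apart; the mean is a Monte Carlo estimate of the expectation
    neg_term = -Q * np.mean(log_sigmoid(-(z_neg @ z_u)))
    return pos_term + neg_term


def margin_loss(z_u, z_pos, z_neg, margin=3.0):
    """Assumed hinge form: penalize each negative scoring within `margin` of the positive."""
    pos_score = z_u @ z_pos
    neg_scores = z_neg @ z_u
    return np.mean(np.maximum(0.0, margin - pos_score + neg_scores))
```

The practical difference is that the hinge loss drops to exactly zero once every negative scores at least `margin` below the positive. The log-sigmoid loss never reaches zero and keeps pushing scores apart.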

Example Usage

A detailed example of using the unsupervised GNNs under different settings on the Cora dataset can be found in example_usage.py.

To run unsupervised GraphSAGE on CUDA:

from GNNs_unsupervised import GNN
gnn = GNN(adj_matrix, features=raw_features, supervised=False, model='graphsage', device='cuda')
# train the model
gnn.fit()
# get the node embeddings with the trained model
embs = gnn.generate_embeddings()
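The snippet above assumes that adj_matrix and raw_features already exist. Below is a minimal sketch of building an adj_matrix in the documented format (a scipy.sparse.csr_matrix whose entries count edges) from an undirected edge list. It also shows a one-hot degree encoding of the kind described for the case where features is omitted. Both helper names are hypothetical. The sketch assumes the edge list has no self-loops, and the package's own degree encoding may bucket or cap degrees differently.

```python
import numpy as np
import scipy.sparse as sp


def build_adj_matrix(edges, num_nodes):
    """Symmetric csr_matrix where entry (i, j) counts the edges between i and j (hypothetical helper)."""
    src = np.array([u for u, v in edges], dtype=np.int64)
    dst = np.array([v for u, v in edges], dtype=np.int64)
    # Add both directions so the graph is treated as undirected
    row = np.concatenate([src, dst])
    col = np.concatenate([dst, src])
    data = np.ones(len(row), dtype=np.float64)
    # Converting COO -> CSR sums duplicate (row, col) pairs, so repeated edges become counts
    return sp.coo_matrix((data, (row, col)), shape=(num_nodes, num_nodes)).tocsr()


def one_hot_degree_features(adj, max_degree=None):
    """One-hot encode each node's degree (hypothetical helper; the package's encoding may differ)."""
    # Degree = row sum, so multi-edges are counted with multiplicity
    degrees = np.asarray(adj.sum(axis=1)).ravel().astype(np.int64)
    if max_degree is None:
        max_degree = int(degrees.max())
    # Clip very high degrees into the last bucket
    degrees = np.minimum(degrees, max_degree)
    features = np.zeros((adj.shape[0], max_degree + 1), dtype=np.float32)
    features[np.arange(adj.shape[0]), degrees] = 1.0
    return features


# Usage: two parallel edges between nodes 0 and 1, one edge between nodes 1 and 2
adj_matrix = build_adj_matrix([(0, 1), (0, 1), (1, 2)], num_nodes=3)
raw_features = one_hot_degree_features(adj_matrix)
```

Passing max_degree caps the width of the feature matrix, which matters for graphs with a few very high-degree hubs.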

TODO: Docs and examples for the supervised GNNs will be added soon.
