
gns's Introduction

Global Neighbor Sampling for Mixed CPU-GPU Training on Giant Graphs (GNS)

Jialin Dong*, Da Zheng*, Lin F. Yang, George Karypis

Contact

Da Zheng ([email protected]), Jialin Dong ([email protected])

Feel free to report bugs or tell us your suggestions!

Overview

We propose a new, effective sampling framework that accelerates GNN mini-batch training on giant graphs by removing the main bottleneck of mixed CPU-GPU training. GNS creates a global cache of nodes to facilitate neighbor sampling and periodically updates the cache, thereby reducing data movement between CPU and GPU.
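To make the data-movement saving concrete, here is a minimal sketch (not the repository's code; gather_features, cache_pos, and the other names are illustrative) of how a GPU-resident feature cache serves cache hits from GPU memory and copies only cache misses from CPU memory:

    import torch

    def gather_features(input_nids, cpu_feats, cache_pos, cache_feats, device):
        # cache_pos[v] = row of node v in cache_feats, or -1 if v is not cached
        pos = cache_pos[input_nids]
        hit = pos >= 0
        out = torch.empty(len(input_nids), cpu_feats.shape[1], device=device)
        out[hit.to(device)] = cache_feats[pos[hit].to(device)]            # hits served from the GPU cache
        out[(~hit).to(device)] = cpu_feats[input_nids[~hit]].to(device)   # only misses cross PCIe
        return out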

We empirically demonstrate the advantages of the proposed algorithm in convergence rate, computational time, and scalability to giant graphs; our method achieves a significant speedup when training on large-scale datasets. We also analyze GNS theoretically and show that, even with a small cache size, it enjoys a convergence rate comparable to node-wise sampling.

GNS shares the advantages of several existing sampling methods while avoiding their drawbacks.

Like node-wise neighbor sampling (e.g., GraphSAGE), it samples neighbors of each node independently and can therefore be implemented and parallelized efficiently. Thanks to the cache, GNS tends to avoid the neighborhood explosion of multi-layer GNNs.

GNS maintains a global and static distribution for sampling the cache, which requires only a one-time computation that is easily amortized over training. In contrast, LADIES computes a sampling distribution for every layer in every mini-batch, which makes its sampling procedure expensive.
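As an illustration of the one-time computation, the sketch below draws the cache from a static distribution. The degree-proportional choice here is an assumption made purely for the example; the paper defines the distribution GNS actually uses.

    import dgl
    import torch

    def static_cache_probs(g):
        # computed once per graph and reused at every cache refresh
        deg = g.in_degrees().float()
        return deg / deg.sum()

    def refresh_cache(probs, cache_size):
        # periodically re-sample which nodes are cached, from the same static distribution
        return torch.multinomial(probs, cache_size, replacement=False)

    g = dgl.rand_graph(1000, 5000)          # stand-in for a real dataset graph
    probs = static_cache_probs(g)           # one-time computation, amortized over training
    cache_nids = refresh_cache(probs, 128)  # repeated cheaply every few epochs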

Even though GNS constructs a mini-batch with more nodes than LADIES, forward and backward computation on a mini-batch is not the major bottleneck in many GNN models for mixed CPU-GPU training.

Although both GNS and LazyGCN deploy caching to accelerate mixed CPU-GPU training, they use the cache very differently. GNS uses the cache to reduce the number of nodes in a mini-batch, which cuts both computation and data movement between CPU and GPUs; the cache captures the majority of the connectivity of the nodes in the graph.

LazyGCN caches and reuses the sampled graph structure and node data. This requires a large mega-batch size to achieve good accuracy, which makes it difficult to scale to giant graphs. Because LazyGCN uses node-wise sampling or layer-wise sampling to sample mini-batches, it suffers from the problems inherent to these two sampling algorithms.

Requirements

  • python >= 3.6.0
  • pytorch >= 1.6.0
  • dgl >= 0.7.0
  • numpy >= 1.16.0
  • scipy >= 1.1.0
  • scikit-learn >= 0.20.0
  • tqdm >= 2.2.4

Datasets

All datasets used in our paper are available for download.

Available in https://github.com/GraphSAINT/GraphSAINT

  • Yelp
  • Amazon

Available in DGL library

  • OAG-paper
  • OGBN-products
  • OGBN-Papers100M

Training

Check out a customized DGL from here.

git clone https://github.com/zheng-da/dgl.git
cd dgl
git checkout new_sampling

Follow the instructions here to install DGL from source.

The following commands train GraphSAGE with GNS.

python3 GNS_sampling_prob.py --dataset ogbn-products    # training on OGBN-products
python3 GNS_sampling_prob.py --dataset oag_max_paper.dgl     # training on OAG-paper
python3 GNS_sampling_prob.py --dataset ogbn-papers100M   # training on OGBN-Papers100M
python3 GNS_yelp_amazon.py --dataset yelp   # training on Yelp
python3 GNS_yelp_amazon.py --dataset amazon   # training on Amazon

Results

The test-set F1-score is summarized below.

[results table: test-set F1-score of GNS on each dataset]

Citation

@inproceedings{GNS-KDD21,
  title={Global Neighbor Sampling for Mixed {CPU}-{GPU} Training on Giant Graphs},
  author={Dong, Jialin and Zheng, Da and Yang, Lin F. and Karypis, George},
  booktitle={Proceedings of the 27th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining},
  year={2021}
}

gns's People

Contributors

jadadong, zheng-da


Forkers

hirayaku, r0mer0m

gns's Issues

Problems with dataset oag-papers

Hi, I downloaded oag_max_paper.dgl from here and used your preprocessing code. I get 309 classes instead of 146. I then analyzed the labels and found 293 classes that contain nodes. Is there a mistake in my download, or did you apply some other processing to the data?

Which tool do you use to profile the time of "slice data" and "sample"?

I found that in this repo you use pyinstrument for profiling; do you use this tool to produce Figure 1 in the paper as well?

I set the arguments of the DGL GraphSAGE example as stated in the paper (num-hidden=256, num-layers=3, batch-size=1000, num-workers=4, fanout=15,10,5, dataset=ogbn-products, --data-cpu --inductive), but the pyinstrument result shown in the second attachment differs a lot from Figure 1.

To be more specific: I assume that the "sample" time is composed of the dataloader's __next__ and __iter__, but it takes ~3 s, which is much longer than the tensor copy (~2.3 s for the Tensor.to entry in the second attachment) or the forward/backward computation (~1.8 s). I am using a machine (72-core Xeon CPU @ 2.60 GHz, 251 GB RAM, 2080 Ti) comparable to the AWS EC2 g4dn.16xlarge instance you used in the experiments, so I am wondering whether you also define the "sample" time as __next__ + __iter__?

[attachments: two pyinstrument profiling screenshots]
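For reference, a sketch of one explicit way to measure the dataloader-only ("sample") time, assuming dataloader is the DGL NodeDataLoader from the GraphSAGE example; the paper's exact methodology may differ:

    import time

    def measure_sample_time(dataloader):
        # accumulate time spent only in the dataloader's __iter__ / __next__,
        # i.e. neighbor sampling and mini-batch (block) construction
        tic = time.time()
        it = iter(dataloader)              # __iter__
        sample_time = time.time() - tic
        while True:
            tic = time.time()
            try:
                batch = next(it)           # __next__
            except StopIteration:
                break
            sample_time += time.time() - tic
        return sample_time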

Several small bugs

  1. In GNS_yelp_amazon.py, the profiler is initialized and started several times if num_epochs != 1 (a possible fix is sketched at the end of this issue):

    for epoch in range(args.num_epochs):
        profiler = Profiler()
        profiler.start()
        if epoch % args.buffer_rs_every == 0:
            if args.buffer_size != 0:
                # initial the buffer

  2. Argument error for the code below; I guess a max-fanout [60, 60, 60] is missing?

    sampler_test = dgl.dataloading.MultiLayerNeighborSampler(
        [60, 60, 60], args.buffer, args.buffer_size, g)

  3. Just a suggestion: maybe check for the existence of the "./results1" directory, or provide a results_path in the argparser? (A sketch is at the end of this issue.)

    f = open('results1/history_' + args.dataset + '_' + str(args.buffer_size) + '_' + str(

  4. After fixing the bugs mentioned above, the program ends with an exception. But all results (logs and pictures) are saved, so maybe this is not important?

Using backend: pytorch
loading graph
processing graph
Pack graph
get cache sampling probability
create the model
/home/data/xzhanggb/envs/miniconda3/envs/gns/lib/python3.7/site-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
start training
/home/data/xzhanggb/envs/miniconda3/envs/gns/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3441: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/data/xzhanggb/envs/miniconda3/envs/gns/lib/python3.7/site-packages/numpy/core/_methods.py:189: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Step 00000 | Loss 1.0434 | Train Acc 0.1505 | Speed (samples/sec) nan | GPU 167.2 MiB
Epoch 00000 | Step 00100 | Loss 0.1851 | Train Acc 0.4962 | Speed (samples/sec) 31462.7002 | GPU 210.5 MiB
Epoch 00000 | Step 00200 | Loss 0.1794 | Train Acc 0.5803 | Speed (samples/sec) 32311.8589 | GPU 219.1 MiB
Epoch 00000 | Step 00300 | Loss 0.1719 | Train Acc 0.5919 | Speed (samples/sec) 31085.2463 | GPU 219.1 MiB
Epoch 00000 | Step 00400 | Loss 0.1670 | Train Acc 0.5874 | Speed (samples/sec) 31076.3003 | GPU 219.1 MiB
Epoch 00000 | Step 00500 | Loss 0.1627 | Train Acc 0.5765 | Speed (samples/sec) 30894.4708 | GPU 219.1 MiB
Epoch Time(s): 19.5753
Epoch 00001 | Step 00000 | Loss 0.1696 | Train Acc 0.5864 | Speed (samples/sec) 30803.0536 | GPU 219.5 MiB
Epoch 00001 | Step 00100 | Loss 0.1605 | Train Acc 0.6020 | Speed (samples/sec) 30148.8415 | GPU 219.5 MiB
Epoch 00001 | Step 00200 | Loss 0.1633 | Train Acc 0.5847 | Speed (samples/sec) 29517.9529 | GPU 219.5 MiB
Epoch 00001 | Step 00300 | Loss 0.1599 | Train Acc 0.5800 | Speed (samples/sec) 28822.6234 | GPU 219.5 MiB
Epoch 00001 | Step 00400 | Loss 0.1579 | Train Acc 0.6054 | Speed (samples/sec) 28289.1415 | GPU 219.5 MiB
Epoch 00001 | Step 00500 | Loss 0.1553 | Train Acc 0.6036 | Speed (samples/sec) 27906.0660 | GPU 219.5 MiB
Epoch Time(s): 22.8442
Exception ignored in: <function WeakKeyDictionary.__init__.<locals>.remove at 0x7ff30db62950>
Traceback (most recent call last):
  File "/home/data/xzhanggb/envs/miniconda3/envs/gns/lib/python3.7/weakref.py", line 356, in remove
  File "/home/data/xzhanggb/envs/miniconda3/envs/gns/lib/python3.7/site-packages/pyinstrument/stack_sampler.py", line 137, in _sample
  File "/home/data/xzhanggb/envs/miniconda3/envs/gns/lib/python3.7/site-packages/pyinstrument/stack_sampler.py", line 189, in build_call_stack
TypeError: 'NoneType' object is not callable
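Returning to bug 1 above, a possible fix (a sketch, not the maintainers' patch; run_one_epoch stands in for the existing epoch body) is to pair every profiler start() with a stop() and report it before the next epoch creates a new one:

    from pyinstrument import Profiler

    for epoch in range(num_epochs):        # num_epochs stands in for args.num_epochs
        profiler = Profiler()              # fresh profiler per epoch ...
        profiler.start()
        run_one_epoch()                    # placeholder for the existing training-loop body
        profiler.stop()                    # ... and always stop it before the next epoch
        print(profiler.output_text(unicode=True))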
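And a sketch of suggestion 3 (the --results-path argument name is hypothetical, not part of the current script): create the output directory before opening files in it.

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--results-path', type=str, default='results1',
                        help='directory for history files and plots')
    args, _ = parser.parse_known_args()

    os.makedirs(args.results_path, exist_ok=True)   # avoids FileNotFoundError on open()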

Questions about magic numbers in dataset splitting

I have several questions about the statistics related to sampling and dataset splitting:

  1. Could you please explain why you chose these specific Train/Val/Test splitting ratios for the datasets?


  2. I found that the #inputs statistic of OGBN-Papers100M differs a lot if you use a larger training set. For example, on the full graph, NS with a 15-10-5 fanout generates only ~117406 inputs per mini-batch (because this dataset is quite dense). Do you have the corresponding #cached-nodes statistics for OGBN-Papers100M when using a larger training set?


  3. In Section 4.1 of the paper, you said that "The sampling fan-outs of each layer are 15, 10 for the third and second layer. We sample nodes in the first layer (input layer) only from the cache." Could you please elaborate on how you set the fan-out for the first layer, e.g. is there an upper/lower bound?

Thank you for your time!

Problem in OAG-paper dataset splitting

In the code below, the length of valid_labal_idx is actually the number of all nodes that come with labels (train + val + test). So train_size should be int(len(valid_labal_idx) * (0.43/0.53)) instead of int(len(valid_labal_idx) * 0.43). The current code actually yields a split with smaller training and validation sets and a much bigger test set, namely Train : Val : Test = 0.53*0.43 : 0.53*0.05 : 0.53*(1 - 0.43 - 0.05).

GNS/GNS_sampling_prob.py

Lines 501 to 508 in b5b2ce8

if 'oag' in args.dataset:
    labels = graph.ndata['field']
    graph.ndata['feat'] = graph.ndata['emb']
    label_sum = labels.sum(1)
    valid_labal_idx = th.nonzero(label_sum > 0, as_tuple=True)[0]
    train_size = int(len(valid_labal_idx) * 0.43)
    val_size = int(len(valid_labal_idx) * 0.05)
    test_size = len(valid_labal_idx) - train_size - val_size
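A sketch of the proposed fix, scaling by the labeled fraction (about 0.53) so that the training set is 43% of all nodes rather than 43% of labeled nodes; whether val_size needs the analogous scaling is an assumption:

    # proposed in the issue above; 0.53 is the fraction of nodes that carry labels
    train_size = int(len(valid_labal_idx) * (0.43 / 0.53))
    val_size = int(len(valid_labal_idx) * 0.05)      # possibly also 0.05 / 0.53 (assumption)
    test_size = len(valid_labal_idx) - train_size - val_size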
