
gae-pytorch's Introduction

gae-pytorch

Graph Auto-Encoder in PyTorch

This is a PyTorch implementation of the Variational Graph Auto-Encoder model described in the paper:

T. N. Kipf, M. Welling, Variational Graph Auto-Encoders, NIPS Workshop on Bayesian Deep Learning (2016)

The code in this repo is based on or refers to https://github.com/tkipf/gae, https://github.com/tkipf/pygcn and https://github.com/vmasrani/gae_in_pytorch.

Requirements

  • Python 3
  • PyTorch 0.4
  • Install the requirements via pip install -r requirements.txt

How to run

python gae/train.py

gae-pytorch's People

Contributors

zfjsail

gae-pytorch's Issues

pos_weight should be a Tensor?

When I train on the Cora dataset, I get the following error in binary_cross_entropy_with_logits. Shouldn't pos_weight be a Tensor? Thanks!

Traceback (most recent call last):
  File "train.py", line 83, in <module>
    gae_for(args)
  File "train.py", line 62, in gae_for
    norm=norm, pos_weight=pos_weight)
  File "/gae-pytorch/gae/optimizer.py", line 7, in loss_function
    cost = norm * F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
  File "/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 2077, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
TypeError: binary_cross_entropy_with_logits(): argument 'pos_weight' (position 4) must be Tensor, not numpy.float64
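The error means pos_weight reaches F.binary_cross_entropy_with_logits as a NumPy scalar, but the function expects a Tensor. One fix is to wrap the value with torch.tensor before the loss call. The sketch below shows this. It assumes the pos_weight/norm weighting from the original TensorFlow GAE (reconstructed here, not copied from this repo). The toy adjacency matrix and the helper weighted_recon_loss are hypothetical; the helper only mirrors the call shown in optimizer.py.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Hypothetical toy reconstruction target (4 nodes, self-loops included).
adj_label = np.array([[1, 1, 0, 0],
                      [1, 1, 1, 0],
                      [0, 1, 1, 1],
                      [0, 0, 1, 1]], dtype=np.float32)

n = adj_label.shape[0]
num_pos = adj_label.sum()
# Assumed GAE-style weighting: ratio of negative to positive entries (a NumPy scalar).
pos_weight_np = (n * n - num_pos) / num_pos
norm = n * n / float((n * n - num_pos) * 2)

# The fix: convert pos_weight to a Tensor before passing it to the loss.
pos_weight = torch.tensor(pos_weight_np, dtype=torch.float32)

def weighted_recon_loss(preds, labels, norm, pos_weight):
    """Hypothetical helper mirroring the loss call in optimizer.py."""
    return norm * F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)

logits = torch.zeros(n, n)  # an untrained decoder that outputs logit 0 everywhere
labels = torch.from_numpy(adj_label)
loss = weighted_recon_loss(logits.view(-1), labels.view(-1), norm, pos_weight)
```

With all logits at 0, each term is a (weighted) log 2. This weighting scales the total back to exactly log 2, which makes a convenient sanity check.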

ROC-AUC-score calculation function error

In utils.py, the function get_roc_score() contains this line:

labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds))])

The second array should be np.zeros(len(preds_neg)), because the negative-edge scores come from preds_neg, not preds.

I was using your code to understand VGAE and found this small mistake. Thank you.
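The sketch below shows the corrected label construction. It assumes preds and preds_neg are lists of scores for held-out positive and negative edges. To avoid a scikit-learn dependency, the AUC is computed with a simple pairwise (Mann-Whitney) count instead of roc_auc_score. Both helper names are hypothetical.

```python
import numpy as np

def build_eval_arrays(preds, preds_neg):
    """Concatenate scores and build matching 0/1 labels (hypothetical helper)."""
    preds_all = np.hstack([preds, preds_neg])
    # Each label block must match the length of its own score list.
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])
    return preds_all, labels_all

def pairwise_auc(preds_all, labels_all):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = preds_all[labels_all == 1]
    neg = preds_all[labels_all == 0]
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

preds = [0.9, 0.8, 0.4]
preds_neg = [0.7, 0.3]  # deliberately shorter than preds
preds_all, labels_all = build_eval_arrays(preds, preds_neg)
auc = pairwise_auc(preds_all, labels_all)
```

When the two lists have the same length, the original line gives the same result. Otherwise labels_all and preds_all stop lining up.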

should use encode() to get hidden_emb

In train.py, hidden_emb is obtained with mu.data.numpy(), but mu is None when GCNModelAE is used as the model. hidden_emb should be obtained from model.encode() instead.
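The sketch below shows the suggested pattern. It assumes the model's forward() returns (reconstruction, mu, logvar), with None in place of mu and logvar for the non-variational model, and that the model exposes encode(). TinyAE is a hypothetical stand-in, not the repo's GCNModelAE: a single Linear layer replaces the GCN encoder.

```python
import torch
import torch.nn as nn

class TinyAE(nn.Module):
    """Hypothetical non-variational auto-encoder; Linear stands in for the GCN layers."""

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.enc = nn.Linear(in_dim, hid_dim)

    def encode(self, x):
        # Deterministic node embeddings Z.
        return self.enc(x)

    def forward(self, x):
        z = self.encode(x)
        # Inner-product decoder logits; a plain AE has no mu/logvar, so return None for both.
        return z @ z.t(), None, None

model = TinyAE(in_dim=5, hid_dim=3)
x = torch.randn(4, 5)
recovered, mu, logvar = model(x)

# mu is None here, so mu.data.numpy() would raise; take embeddings from encode() instead.
model.eval()
with torch.no_grad():
    hidden_emb = model.encode(x).numpy()
```

For the variational model, encode() would presumably return the mean embedding. That matches what mu.data.numpy() yields today, so the same line works for both models.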

A question about KLD

KLD = -0.5 / n_nodes * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))

Either the / n_nodes factor should be removed, or torch.mean should be replaced with torch.sum.
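The double normalization can be made explicit as follows. The expression uses 2 * logvar and logvar.exp().pow(2), so logvar holds a log standard deviation, σ_ij = exp(logvar_ij). Write N = n_nodes and d for the latent dimension. The first line is the per-node KL divergence to a standard normal prior; the second is what the line above computes.

```latex
\mathrm{KL}_i = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + 2\log\sigma_{ij} - \mu_{ij}^2 - \sigma_{ij}^2\right),
\qquad \sigma_{ij} = \exp(\mathrm{logvar}_{ij})

\mathrm{KLD} = -\frac{0.5}{N}\cdot\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{d}\left(1 + 2\log\sigma_{ij} - \mu_{ij}^2 - \sigma_{ij}^2\right)
             = \frac{1}{N^2}\sum_{i=1}^{N}\mathrm{KL}_i
```

Either proposed change yields the per-node average (1/N) Σ_i KL_i instead of the 1/N² scaling.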

About sampling

In a VAE, sampling is z_mean + torch.exp(0.5 * z_log_var) * epsilon, but why is it z_mean + torch.exp(z_log_var) * epsilon in VGAE? Does this make any difference?
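The two forms give the same sample if the second encoder output is read consistently: exp(0.5 * log σ²) = exp(log σ) = σ. The KLD expression in the previous issue uses 2 * logvar and logvar.exp().pow(2), which fits a variable holding log σ despite its name. The sketch below checks this numerically with scalars. The function names are hypothetical, and the log-std convention for VGAE is an assumption drawn from that expression.

```python
import math
import random

def sample_vae(mu, log_var, eps):
    # VAE convention: the encoder outputs log(sigma^2), so sigma = exp(0.5 * log_var).
    return mu + math.exp(0.5 * log_var) * eps

def sample_vgae(mu, log_std, eps):
    # Assumed VGAE convention: the encoder outputs log(sigma), so sigma = exp(log_std).
    return mu + math.exp(log_std) * eps

mu, sigma = 0.3, 1.7
eps = random.gauss(0.0, 1.0)
z1 = sample_vae(mu, math.log(sigma ** 2), eps)
z2 = sample_vgae(mu, math.log(sigma), eps)
```

The only difference is which quantity the network learns to output. The sampling distribution is the same as long as the KL term uses the matching convention.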

adj_label

Why are self-loops added to adj_train to obtain adj_label?

adj_label = adj_train + sp.eye(adj_train.shape[0])
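One plausible reading, not stated in the repo, involves the encoder's preprocessing. Assume the encoder uses the standard GCN renormalization D^(-1/2)(A + I)D^(-1/2), as in the original GAE preprocessing. Then each embedding z_i already depends on node i itself, and adding I to the target keeps the reconstruction target consistent with that graph: the decoder is also asked to score every (i, i) pair as a link. The sketch below builds both matrices with NumPy dense arrays instead of scipy.sparse. The 3-node graph is hypothetical.

```python
import numpy as np

# Hypothetical path graph 0-1-2 without self-loops.
adj_train = np.array([[0, 1, 0],
                      [1, 0, 1],
                      [0, 1, 0]], dtype=np.float64)

# Reconstruction target: every node is also treated as linked to itself.
adj_label = adj_train + np.eye(adj_train.shape[0])

# Encoder-side GCN renormalization (assumed): D^{-1/2} (A + I) D^{-1/2}.
adj_ = adj_train + np.eye(adj_train.shape[0])
deg_inv_sqrt = 1.0 / np.sqrt(adj_.sum(axis=1))
adj_norm = deg_inv_sqrt[:, None] * adj_ * deg_inv_sqrt[None, :]
```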

Custom Datasets + Module Loading Error

Can I use this framework on my own custom datasets?

Also, I'm running into an issue. When I try to run:
python gae/train.py

I get:

Traceback (most recent call last):
  File "gae/train.py", line 12, in <module>
    from gae.model import GCNModelVAE
ModuleNotFoundError: No module named 'gae'

Yet this command works interactively.
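The difference comes from how Python sets sys.path[0]. Running python gae/train.py sets it to the script's own directory (gae/), from which the top-level package gae is not importable. An interactive session started in the repository root has the current directory on sys.path, so gae/ is visible there. One workaround is to put the repository root on sys.path before the package imports, assuming train.py sits one level below that root. The sketch below does this; repo_root_for and ensure_on_path are hypothetical helpers. Running python -m gae.train from the repository root is a common alternative.

```python
import os
import sys

def repo_root_for(script_path):
    """Return the directory one level above the script's folder (hypothetical helper)."""
    script_dir = os.path.dirname(os.path.abspath(script_path))
    return os.path.dirname(script_dir)

def ensure_on_path(path):
    # Prepend so the local 'gae' package is found before any other copy.
    if path not in sys.path:
        sys.path.insert(0, path)

# At the top of gae/train.py this would be used as:
#     ensure_on_path(repo_root_for(__file__))
#     from gae.model import GCNModelVAE
```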
