
README

This article introduces a new approach that uses generative adversarial networks for semi-supervised learning on graphs. Below are my reading notes, followed by some simple experiments built on them.

Link

Semi-supervised Learning on Graphs with Generative Adversarial Nets https://github.com/THUDM/GraphSGAN

Motivation

Use Generative Adversarial Nets (GANs) to help semi-supervised learning on graphs. The curse of dimensionality can make label propagation leak across different clusters more easily, so the authors want to generate more fake nodes in the gaps between clusters.

Idea

Use a GAN to generate fake nodes in the low-density gap areas between clusters.

GANs

$$\min\limits_{G}\max\limits_{D}V(G,D)=E_{x\sim p_d(x)}\log D(x)+E_{z\sim p_z(z)}\log (1-D(G(z)))$$

This objective means the discriminator D should assign high probability to real samples and low probability to generated (fake) ones. At the same time, the generator G minimizes the second term, pushing D(G(z)) toward 1 so that its fakes fool D.
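As a minimal sketch of this objective (my own illustration with NumPy, not anything from the paper's code), where d_real and d_fake are the discriminator's probability outputs on a batch:

import numpy as np

def d_loss(d_real, d_fake):
    # D maximizes E[log D(x)] + E[log(1 - D(G(z)))],
    # i.e. minimizes the negation of the value function
    return -(np.log(d_real).mean() + np.log(1.0 - d_fake).mean())

def g_loss(d_fake):
    # G minimizes E[log(1 - D(G(z)))]; in practice the non-saturating
    # variant -E[log D(G(z))] is often preferred for better gradients
    return np.log(1.0 - d_fake).mean()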

Gap density

Try to weaken the effect of propagation across density gaps by adding fake nodes.

Problem: a plain GAN cannot exploit the graph topology, so the G function cannot place fake nodes in the low-density areas.

Solution: use embedding techniques such as DeepWalk, LINE, or NetMF to learn a latent distributed representation. Once the graph topology has been learned this way, the resulting information can be used to train the GAN.

General techniques

  1. Batch Normalization (BN) for gradient stability

  2. Weight Normalization for a trainable weight scale

  3. Additive Gaussian noise in the D function (the classifier in this paper) during training; a sketch combining techniques 1-3 follows this list

  4. Neighbour fusion (implemented in the Code section below)
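A minimal sketch of how the first three techniques might sit together in one discriminator block, assuming PyTorch (this is my own illustration, not the paper's code; NoisyBlock and sigma are hypothetical names):

import torch
import torch.nn as nn

class NoisyBlock(nn.Module):
    # Hypothetical discriminator block: weight norm + batch norm +
    # additive Gaussian noise applied only during training.
    def __init__(self, d_in, d_out, sigma=0.05):
        super().__init__()
        self.fc = nn.utils.weight_norm(nn.Linear(d_in, d_out))  # technique 2
        self.bn = nn.BatchNorm1d(d_out)                         # technique 1
        self.sigma = sigma  # noise scale (assumed value)

    def forward(self, x):
        h = self.bn(self.fc(x))
        if self.training:  # technique 3: noise only at training time
            h = h + self.sigma * torch.randn_like(h)
        return torch.relu(h)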

Architecture

Classifier (D function): softmax with an additional fake class

Loss function:

  1. D function: $L_D = loss_{sup} + \lambda_0 loss_{un} + \lambda_1 loss_{ent} + loss_{pt}$

    1. For labeled nodes: minimize the cross entropy of the true labels
    2. For real vs. fake: real nodes should not be mapped into the low-density area while fake nodes should be, i.e. minimize p(M|x) and maximize p(M|g(z))
    3. For unlabeled nodes: push the predicted distribution over the real classes toward a definite label (entropy term)
    4. For the low-density area: widen the gap by maximizing the cosine distance among generated samples (the pull-away term, sketched after this list)
  2. G function: $L_G = loss_{fm} + \lambda_2 loss_{pt}$

    1. Generated nodes should lie in the low-density area: minimize the distance to the central point
    2. Generated samples should not collapse onto that single central point: the same pull-away term as item 4 in the D function
  3. The hyper-parameters $\lambda_i$ control the relative strength of the different terms.
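A minimal sketch of the pull-away term $loss_{pt}$, assuming PyTorch (my own illustration of the standard formulation, not the paper's code):

import torch
import torch.nn.functional as F

def pull_away_loss(h):
    # h: (batch, dim) feature vectors of the generated samples.
    # Penalize high pairwise cosine similarity so samples spread out
    # instead of collapsing onto the single central point.
    hn = F.normalize(h, dim=1)
    cos = hn @ hn.t()                 # pairwise cosine similarities
    n = h.size(0)
    off_diag = cos.pow(2).sum() - n   # drop the diagonal entries (cos = 1)
    return off_diag / (n * (n - 1))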

Code

data preprocess

To generate the corpus for the Word2Vec model by random walks, I set the number of walk rounds to 10, because the average number of edges per node is 2. Meanwhile, I set the walk length to 400 because the average number of nodes per class is 400.

import random
from gensim.models import Word2Vec

def walker(G, walk_length, start_node):
    # One truncated random walk starting from start_node;
    # node ids are stored as strings so gensim can treat them as tokens
    walk = [str(start_node)]
    while len(walk) < walk_length:
        cur = int(walk[-1])
        cur_nbrs = list(G.neighbors(cur))
        if len(cur_nbrs) > 0:
            walk.append(str(random.choice(cur_nbrs)))
        else:
            break
    return walk

# From https://github.com/shenweichen/GraphEmbedding/blob/master/ge/walker.py
def _simulate_walks(G, nodes, num_walks, walk_length):
    # num_walks rounds over all nodes, shuffled each round
    walks = []
    for _ in range(num_walks):
        random.shuffle(nodes)
        for v in nodes:
            walks.append(walker(G, walk_length=walk_length, start_node=v))
    return walks

def walk2vec(G, num_walks, walk_length):
    # Train skip-gram (sg=1) with hierarchical softmax (hs=1) on the walks
    walks = _simulate_walks(G, list(G.nodes()), num_walks, walk_length)
    return Word2Vec(walks, sg=1, hs=1)

# The dataset (Cora) has 2708 nodes and 5429 edges with 7 classes
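The adj and features used in the test below are not built in the snippets above; here is a hedged sketch of loading them, assuming the standard Cora distribution (cora.content with tab-separated binary word features and a label per line, cora.cites with one edge per line):

import numpy as np

def load_cora(content_path='cora.content', cites_path='cora.cites'):
    # Parse node features and labels, then build a symmetric adjacency matrix
    ids, rows, labels = [], [], []
    with open(content_path) as fin:
        for line in fin:
            parts = line.strip().split('\t')
            ids.append(parts[0])
            rows.append([int(x) for x in parts[1:-1]])  # 1433 binary word features
            labels.append(parts[-1])
    idx = {pid: i for i, pid in enumerate(ids)}
    features = np.array(rows, dtype=float)
    adj = np.zeros((len(ids), len(ids)))
    with open(cites_path) as fin:
        for line in fin:
            a, b = line.strip().split('\t')
            if a in idx and b in idx:
                adj[idx[a], idx[b]] = adj[idx[b], idx[a]] = 1
    return adj, features, labels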

Neighbour Fusion

import numpy as np

def neighbour_fusion(node, nblist, fmatrix, alpha):
    # Mix the mean of the neighbours' feature vectors with the node's
    # own features, weighted by alpha
    nb_sum = np.zeros(len(fmatrix[0]))
    for idx in nblist:
        nb_sum = nb_sum + fmatrix[idx]
    return nb_sum * (1 - alpha) / len(nblist) + alpha * fmatrix[node]

For a test (adj and features come from the dataset loading step):

import networkx as nx

G = nx.Graph(adj)
Ln = list(G.neighbors(1))
f = neighbour_fusion(1, Ln, features, 0.9)
a = list(G.nodes())
c = walk2vec(G, 10, 400)
f = np.append(f, c.wv['1'])  # concatenate fused features with the embedding

Training log

  1. walk2vec(G, 10, 400): acc = 67%. This falls far short of the 83% reported in the paper. My first guess was that the graph-structure information was being over-sampled, so I tuned the parameters of my DeepWalk implementation.

walk2vec(G, 10, 100): acc = 68.5%

walk2vec(G, 10, 10): acc = 71.2%

walk2vec(G, 10, 5): acc = 64.5%

walk2vec(G, 10, 7): acc = 65%

walk2vec(G, 10, 15): acc = 65.4%

  2. This did not seem to help; I still suspected the network representation was insufficient, so I varied the embedding size next.

word2vec(size = ):

size = 1000: acc = 70%

size = 1500: acc = 70%

size = 2000: acc = 73%

size = 2500: acc = 72%

size = 3000: acc = 71%

size = 5000: acc = 67.5%

The best result is still nowhere near the paper's 83%. Observing that the generator's loss was hard to reduce while the classifier's loss dropped quickly, I turned to tuning the algorithm's hyper-parameters.

  3. The DeepWalk above is a simple home-grown implementation on top of word2vec. To compare different network embeddings, the next step is to use the OpenNE toolkit to systematically generate different embeddings and try them, e.g. by loading its output as sketched below.
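A hedged sketch of plugging OpenNE output into the same pipeline, assuming OpenNE writes embeddings in the usual word2vec text format (a header line "n dim", then one "node_id v1 v2 ..." line per node; check the OpenNE README for the exact invocation):

import numpy as np

def load_openne_embedding(path):
    # Read a word2vec-style text embedding file into a dict: node id -> vector
    emb = {}
    with open(path) as fin:
        next(fin)  # skip the "n dim" header line
        for line in fin:
            parts = line.strip().split()
            emb[parts[0]] = np.array(parts[1:], dtype=float)
    return emb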
