
lnn's Issues

Error in vector multiplication

Hi,

Thanks for sharing the code implementing the methodology.
I have been reviewing the code and I think there is an error in lines 63 and 64 of lnn/lnn.py (at commit e5e913a):

S1 += (dis/bw[i])*exp(-dis*dis.transpose()/(2*bw[i]**2))
S2 += (dis.transpose()*dis/(bw[i]**2))*exp(-dis*dis.transpose()/(2*bw[i]**2))

and should use elementwise multiplication instead:

            S1 += np.multiply(dis/bw[i],exp(-dis*dis.transpose()/(2*bw[i]**2))) 
            S2 += np.multiply(dis.transpose()*dis/(bw[i]**2), exp(-dis*dis.transpose()/(2*bw[i]**2)))
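
For np.matrix objects the * operator means matrix multiplication, so if the exponential factor is itself a 1x1 matrix (for example, when exp resolves to NumPy's exp rather than math.exp), multiplying the (1, d) row dis by it raises a shape error; np.multiply always performs the intended elementwise scaling. A minimal sketch of the difference (the variable names here are just illustrative):

import numpy as np

dis = np.matrix([[1.0, 2.0]])              # a (1, d) row, like dis in lnn.py
w = np.exp(-dis * dis.transpose() / 2.0)   # np.exp keeps this a (1, 1) np.matrix

try:
    dis * w                                # matrix product of (1, 2) and (1, 1)
except ValueError as err:
    print("matrix * fails:", err)          # shapes not aligned

print(np.multiply(dis, w))                 # elementwise scaling, shape (1, 2)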

I provide the updated code below; it gives the correct answer on the test from the demo file.

import numpy.random as nr
import numpy as np
import scipy.spatial as ss
from math import log,pi,exp

def LNN_2_entropy(x,k=5,tr=30,bw=0):

    '''
        Estimate the entropy H(X) from samples {x_i}_{i=1}^N
        Using Local Nearest Neighbor (LNN) estimator with order 2
        Input: x: 2D list of size N*d_x
        k: k-nearest neighbor parameter
        tr: number of samples used for computation
        bw: option for bandwidth choice, 0 = kNN bandwidth, otherwise you can specify the bandwidth
        Output: a single number, the estimate of H(X)
        '''

    assert k <= len(x)-1, "Set k smaller than num. samples - 1"
    assert tr <= len(x)-1, "Set tr smaller than num. samples - 1"
    N = len(x)
    d = len(x[0])
    
    local_est = np.zeros(N)
    S_0 = np.zeros(N)
    S_1 = np.zeros(N)
    S_2 = np.zeros(N)
    tree = ss.cKDTree(x)
    if (bw == 0):
        bw = np.zeros(N)
    for i in range(N):
        lists = tree.query(x[i],tr+1,p=2)
        knn_dis = lists[0][k]
        list_knn = lists[1][1:tr+1]
        
        if (bw[i] == 0):
            bw[i] = knn_dis

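        # Accumulate Gaussian-kernel-weighted moment sums over the tr nearest
        # neighbors, with weights w_j = exp(-||x_j - x_i||^2 / (2*bw[i]^2)):
        # S0 = sum of weights, S1 = weighted first moment of (x_j - x_i)/bw[i],
        # S2 = weighted second moment of (x_j - x_i)/bw[i].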
        S0 = 0
        S1 = np.matrix(np.zeros(d))
        S2 = np.matrix(np.zeros((d,d)))
        for neighbor in list_knn:
            dis = np.matrix(x[neighbor] - x[i])
            S0 += exp(-dis*dis.transpose()/(2*bw[i]**2))
            S1 += np.multiply(dis/bw[i],exp(-dis*dis.transpose()/(2*bw[i]**2))) #(dis/bw[i])*exp(-dis*dis.transpose()/(2*bw[i]**2))
            S2 += np.multiply(dis.transpose()*dis/(bw[i]**2), exp(-dis*dis.transpose()/(2*bw[i]**2)))#(dis.transpose()*dis/(bw[i]**2))*exp(-dis*dis.transpose()/(2*bw[i]**2))

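        # Local Gaussian fit around x[i]: mean mu = S1/S0 and covariance
        # Sigma = S2/S0 - mu^T * mu (in units of bw[i]); local_est[i] below is
        # the negative log of the fitted local density at x[i].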
        Sigma = S2/S0 - S1.transpose()*S1/(S0**2)
        det_Sigma = np.linalg.det(Sigma)
        if (det_Sigma < (1e-4)**d):
            local_est[i] = 0
        else:
            offset = (S1/S0)*np.linalg.inv(Sigma)*(S1/S0).transpose()
            # offset is a 1x1 matrix; [0,0] extracts the scalar
            local_est[i] = -log(S0) + log(N-1) + 0.5*d*log(2*pi) + d*log(bw[i]) + 0.5*log(det_Sigma) + 0.5*offset[0,0]

    if (np.count_nonzero(local_est) == 0):
        return 0
    else: 
        return np.mean(local_est[np.nonzero(local_est)])
    
def _3LNN_2_mi(data,split,k=5,tr=30):
    '''
        Estimate the mutual information I(X;Y) from samples {x_i,y_i}_{i=1}^N
        Using I(X;Y) = H_{LNN}(X) + H_{LNN}(Y) - H_{LNN}(X;Y)
        where H_{LNN} is the LNN entropy estimator with order 2
        Input: data: 2D list of size N*(d_x + d_y)
        split: should be d_x, splitting the data into two parts, X and Y
        k: k-nearest neighbor parameter
        tr: number of samples used for computation
        Output: a single number, the estimate of I(X;Y)
    '''
    assert split >=1, "x must have at least one dimension"
    assert split <= len(data[0]) - 1, "y must have at least one dimension"
    x = data[:,:split]
    y = data[:,split:]

    N = len(data)
    H_x = LNN_2_entropy(x,k,tr)
    H_y = LNN_2_entropy(y,k,tr) 
    H_xy = LNN_2_entropy(data,k,tr)

    return H_x + H_y - H_xy


# generate some random data
r = 0.5
data = nr.multivariate_normal([0,0],[[1,r],[r,1]],500)
print("Entropy: ")
print("Ground Truth = ", np.log(2*np.pi*np.exp(1))+0.5*np.log(1-r*r))
print(LNN_2_entropy(data,tr=50))
print("Mutual Information: ")
print("Ground Truth = ", -0.5*np.log(1.0-r*r))
print(_3LNN_2_mi(data=data,split=1,tr=50))

Output:

Entropy: 
Ground Truth =  2.694036030183455
2.6112571933965274
Mutual Information: 
Ground Truth =  0.14384103622589045
0.1939686547758117
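
As an additional sanity check (a sketch, not part of the original demo; it assumes the functions above are already defined), the experiment can be repeated over several independent draws to see how the estimate spreads around the ground truth:

import numpy as np
import numpy.random as nr

r = 0.5
truth = -0.5 * np.log(1.0 - r * r)
estimates = []
for seed in range(10):
    nr.seed(seed)                   # fix the seed so each draw is reproducible
    data = nr.multivariate_normal([0, 0], [[1, r], [r, 1]], 500)
    estimates.append(_3LNN_2_mi(data=data, split=1, tr=50))
print("truth = %.4f" % truth)
print("mean estimate = %.4f, std = %.4f" % (np.mean(estimates), np.std(estimates)))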

I look forward to your comments. Kind regards,

Inti
