wgao9 / lnn
Local Nearest Neighbor Information Estimator
Hi,
Thanks for sharing the code implementing the methodology.
I have been reviewing the code and I think there may be an error in lines 63 and 64 (as of commit e5e913a). I believe they should read:
S1 += np.multiply(dis/bw[i],exp(-dis*dis.transpose()/(2*bw[i]**2)))
S2 += np.multiply(dis.transpose()*dis/(bw[i]**2), exp(-dis*dis.transpose()/(2*bw[i]**2)))
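For clarity, here is how I read the quantities these two lines are meant to accumulate (notation mine: u_j = (x_j - x_i)/h_i is the rescaled displacement of neighbor j and h_i = bw[i] the local bandwidth):

S_0 = \sum_j e^{-\|u_j\|^2 / 2}, \qquad
S_1 = \sum_j u_j \, e^{-\|u_j\|^2 / 2}, \qquad
S_2 = \sum_j u_j u_j^{\top} \, e^{-\|u_j\|^2 / 2},

so that the local covariance used later is

\hat{\Sigma}_i = \frac{S_2}{S_0} - \frac{S_1 S_1^{\top}}{S_0^{2}}.

That is, the Gaussian kernel weight should multiply each displacement term, which is what np.multiply does in the corrected lines.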
I have included the full updated code below; with this change it gives the correct answer on the test from the demo file.
import numpy.random as nr
import numpy as np
import scipy.spatial as ss
from math import log, pi, exp

def LNN_2_entropy(x, k=5, tr=30, bw=0):
    '''
    Estimate the entropy H(X) from samples {x_i}_{i=1}^N
    using the Local Nearest Neighbor (LNN) estimator with order 2.
    Input:  x: 2D list of size N*d_x
            k: k-nearest neighbor parameter
            tr: number of samples used for computation
            bw: option for bandwidth choice, 0 = kNN bandwidth, otherwise you can specify the bandwidth
    Output: one number, H(X)
    '''
    assert k <= len(x)-1, "Set k smaller than num. samples - 1"
    assert tr <= len(x)-1, "Set tr smaller than num. samples - 1"
    N = len(x)
    d = len(x[0])
    local_est = np.zeros(N)
    S_0 = np.zeros(N)
    S_1 = np.zeros(N)
    S_2 = np.zeros(N)
    tree = ss.cKDTree(x)
    if (bw == 0):
        bw = np.zeros(N)
    for i in range(N):
        lists = tree.query(x[i], tr+1, p=2)
        knn_dis = lists[0][k]
        list_knn = lists[1][1:tr+1]
        if (bw[i] == 0):
            bw[i] = knn_dis
        S0 = 0
        S1 = np.matrix(np.zeros(d))
        S2 = np.matrix(np.zeros((d, d)))
        for neighbor in list_knn:
            dis = np.matrix(x[neighbor] - x[i])
            S0 += exp(-dis*dis.transpose()/(2*bw[i]**2))
            S1 += np.multiply(dis/bw[i], exp(-dis*dis.transpose()/(2*bw[i]**2)))  # (dis/bw[i])*exp(-dis*dis.transpose()/(2*bw[i]**2))
            S2 += np.multiply(dis.transpose()*dis/(bw[i]**2), exp(-dis*dis.transpose()/(2*bw[i]**2)))  # (dis.transpose()*dis/(bw[i]**2))*exp(-dis*dis.transpose()/(2*bw[i]**2))
        Sigma = S2/S0 - S1.transpose()*S1/(S0**2)
        det_Sigma = np.linalg.det(Sigma)
        if (det_Sigma < (1e-4)**d):
            local_est[i] = 0
        else:
            offset = (S1/S0)*np.linalg.inv(Sigma)*(S1/S0).transpose()
            local_est[i] = -log(S0) + log(N-1) + 0.5*d*log(2*pi) + d*log(bw[i]) + 0.5*log(det_Sigma) + 0.5*offset[0][0]
    if (np.count_nonzero(local_est) == 0):
        return 0
    else:
        return np.mean(local_est[np.nonzero(local_est)])
def _3LNN_2_mi(data, split, k=5, tr=30):
    '''
    Estimate the mutual information I(X;Y) from samples {x_i,y_i}_{i=1}^N
    using I(X;Y) = H_{LNN}(X) + H_{LNN}(Y) - H_{LNN}(X,Y),
    where H_{LNN} is the LNN entropy estimator with order 2.
    Input:  data: 2D list of size N*(d_x + d_y)
            split: should be d_x, splitting the data into two parts, X and Y
            k: k-nearest neighbor parameter
            tr: number of samples used for computation
    Output: one number, I(X;Y)
    '''
    assert split >= 1, "x must have at least one dimension"
    assert split <= len(data[0]) - 1, "y must have at least one dimension"
    x = data[:, :split]
    y = data[:, split:]
    N = len(data)
    H_x = LNN_2_entropy(x, k, tr)
    H_y = LNN_2_entropy(y, k, tr)
    H_xy = LNN_2_entropy(data, k, tr)
    return H_x + H_y - H_xy
# generate some random data
r = 0.5
data = nr.multivariate_normal([0,0],[[1,r],[r,1]],500)
print("Entropy: ")
print("Ground Truth = ", np.log(2*np.pi*np.exp(1))+0.5*np.log(1-r*r))
print(LNN_2_entropy(data,tr=50))
print("Mutual Information: ")
print("Ground Truth = ", -0.5*np.log(1.0-r*r))
print(_3LNN_2_mi(data=data,split=1,tr=50))
Output:
Entropy:
Ground Truth = 2.694036030183455
2.6112571933965274
Mutual Information:
Ground Truth = 0.14384103622589045
0.1939686547758117
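If it is useful, the same functions could also be checked against independent coordinates, where the true mutual information is 0. A minimal sketch of such a check, reusing the functions above (I have not included its output here):

# Optional sanity check on independent coordinates: the true mutual
# information is 0, so _3LNN_2_mi should return a value close to 0
# (up to estimation noise).
indep = nr.multivariate_normal([0, 0], [[1, 0], [0, 1]], 500)
print("Independent case, Ground Truth = 0.0")
print(_3LNN_2_mi(data=indep, split=1, tr=50))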
I look forward to your comments. Kind regards,
Inti