
karlnapf / ds3_kernel_testing


Material for the practical of the DS3 course on "Representing and comparing probabilities with kernels"

Python 0.01% E 46.55% Fortran 52.89% Jupyter Notebook 0.55%

ds3_kernel_testing's People

Contributors

djsutherland, karlnapf


ds3_kernel_testing's Issues

Reproducing results with Shogun

Hi,

I attended the DS3 summer school and have been going over what we learnt by reproducing it with Shogun (my goal is to use Shogun for hypothesis testing in my own research). However, I cannot get the same results from Shogun as from the code developed at DS3.

Even something as simple as computing the MMD statistic gives different results in Shogun than with the summer-school MMD implementation, as the following example shows:

import numpy as np
from tqdm.auto import tqdm  # tqdm_notebook is deprecated in recent tqdm releases
from scipy.spatial.distance import squareform, pdist, cdist
import shogun as sg

data = np.load("blobs.npz")
X_learn = data["X"]
Y_learn = data["Y"]

def two_sample_permutation_test(test_statistic, X, Y, num_permutations, prog_bar=True):
    assert X.ndim == Y.ndim
    
    statistics = np.zeros(num_permutations)
    
    range_ = range(num_permutations)
    if prog_bar:
        range_ = tqdm(range_)
    for i in range_:
        # concatenate samples
        if X.ndim == 1:
            Z = np.hstack((X,Y))
        elif X.ndim == 2:
            Z = np.vstack((X,Y))
            
        # IMPLEMENT: permute samples and compute test statistic
        perm_inds = np.random.permutation(len(Z))
        Z = Z[perm_inds]
        X_ = Z[:len(X)]
        Y_ = Z[len(X):]
        my_test_statistic = test_statistic(X_, Y_)
        statistics[i] = my_test_statistic
    return statistics

def quadratic_time_mmd(X,Y,kernel):
    assert X.ndim == Y.ndim == 2
    K_XX = kernel(X,X)
    K_XY = kernel(X,Y)
    K_YY = kernel(Y,Y)
       
    n = len(K_XX)
    m = len(K_YY)
    
    # IMPLEMENT: unbiased MMD statistic (could also use biased, doesn't matter if we use permutation tests)
    np.fill_diagonal(K_XX, 0)
    np.fill_diagonal(K_YY, 0)
    mmd = np.sum(K_XX) / (n*(n-1))  + np.sum(K_YY) / (m*(m-1))  - 2*np.sum(K_XY)/(n*m)
    return mmd


def gauss_kernel(X, Y=None, sigma=1.0):
    """
    Computes the standard Gaussian kernel k(x,y)=exp(- ||x-y||**2 / (2 * sigma**2))

    X - 2d array, samples on left hand side
    Y - 2d array, samples on right hand side, can be None in which case they are replaced by X
    
    returns: kernel matrix
    """

    # IMPLEMENT: compute squared distances and kernel matrix
    sq_dists = sq_distances(X,Y)
    K = np.exp(-sq_dists / (2 * sigma**2))
    return K

def sq_distances(X,Y=None):
    assert(X.ndim==2)

    # IMPLEMENT: compute pairwise distance matrix. Don't use explicit loops, but the above scipy functions
    # if X=Y, use more efficient pdist call which exploits symmetry
    if Y is None:
        sq_dists = squareform(pdist(X, 'sqeuclidean'))
    else:
        assert(Y.ndim==2)
        assert(X.shape[1]==Y.shape[1])
        sq_dists = cdist(X, Y, 'sqeuclidean')

    return sq_dists


log_sigma=-2
num_permutations=200

# Shogun implementation
feat_p=sg.RealFeatures(X_learn.T.astype(np.float64))
feat_q=sg.RealFeatures(Y_learn.T.astype(np.float64))
kernel=sg.GaussianKernel(2 * (10**log_sigma)**2)

mmd=sg.QuadraticTimeMMD(feat_p,feat_q)
mmd.set_kernel(kernel)

mmd.set_statistic_type(sg.ST_UNBIASED_FULL)
statistic=mmd.compute_statistic()

mmd.set_null_approximation_method(sg.NAM_PERMUTATION)
mmd.set_num_null_samples(num_permutations)


# DS3 summer school implementation
my_kernel = lambda X,Y : gauss_kernel(X,Y,sigma=10**log_sigma)
my_mmd = lambda X,Y : quadratic_time_mmd(X,Y, my_kernel)
my_statistic = my_mmd(X_learn, Y_learn)
statistics = two_sample_permutation_test(my_mmd, X_learn, Y_learn, num_permutations, prog_bar=False)
# p-value: fraction of permuted statistics at least as large as the observed one
p_value = np.mean(statistics >= my_statistic)

print(statistic)
print(my_statistic)
print(mmd.compute_p_value(statistic))
print(p_value)
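Even if the two statistics agreed, the two printed p-values would generally not match exactly: both are Monte Carlo estimates from independent random permutations (the DS3 code uses `np.random.permutation`, while Shogun draws its own), so they fluctuate from run to run. A minimal sketch of a seeded permutation test makes the DS3 side reproducible. `permutation_p_value` and `mean_diff` are hypothetical helpers written for illustration, and the mean-difference statistic is only a cheap stand-in for the MMD.

```python
import numpy as np


def permutation_p_value(stat_fn, X, Y, num_permutations, seed=0):
    """Hypothetical helper: permutation-test p-value with a seeded generator.

    Returns the observed statistic and the fraction of permuted statistics
    that are >= it (the same estimate as the DS3 code, but reproducible).
    """
    rng = np.random.default_rng(seed)
    Z = np.vstack((X, Y))
    n = len(X)
    observed = stat_fn(X, Y)
    null = np.empty(num_permutations)
    for i in range(num_permutations):
        Zp = Z[rng.permutation(len(Z))]  # shuffle the pooled sample
        null[i] = stat_fn(Zp[:n], Zp[n:])
    return observed, np.mean(null >= observed)


def mean_diff(X, Y):
    # Cheap stand-in statistic for the demo: squared distance between sample means
    return np.sum((X.mean(axis=0) - Y.mean(axis=0)) ** 2)


data_rng = np.random.default_rng(1)
X = data_rng.normal(size=(50, 2))
Y = data_rng.normal(loc=1.0, size=(50, 2))

stat_a, p_a = permutation_p_value(mean_diff, X, Y, 200, seed=42)
stat_b, p_b = permutation_p_value(mean_diff, X, Y, 200, seed=42)
print(p_a == p_b)  # True: same seed, same permutations, same p-value
```

Passing `my_mmd` instead of `mean_diff` gives a reproducible version of the DS3 test, so any remaining gap to Shogun's p-value is not just permutation noise.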

Any idea why the two sets of outputs differ? The MMD computation should be the same in the toolbox and in the summer-school code. Could it be the kernel?
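One way to rule out the DS3 side is to check the vectorized estimator against a direct double-loop transcription of the unbiased MMD, and to check the kernel-width conversion used above. The sketch below is self-contained, so it re-implements simplified copies of `gauss_kernel` and `quadratic_time_mmd`. `width_kernel` is a hypothetical function that encodes the *assumed* parameterization k(x, y) = exp(-||x - y||^2 / width) for `sg.GaussianKernel`, which is what the conversion `2 * (10**log_sigma)**2` in the example relies on; it is not Shogun code.

```python
import numpy as np
from scipy.spatial.distance import cdist


def gauss_kernel(X, Y, sigma=1.0):
    # k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)), matching the DS3 docstring
    return np.exp(-cdist(X, Y, "sqeuclidean") / (2 * sigma**2))


def width_kernel(X, Y, width):
    # Hypothetical width parameterization exp(-||x - y||^2 / width);
    # ASSUMED to be how sg.GaussianKernel interprets its argument
    return np.exp(-cdist(X, Y, "sqeuclidean") / width)


def mmd_unbiased_vectorized(X, Y, kernel):
    # Same steps as quadratic_time_mmd: drop the i == j terms, then normalize
    K_XX, K_YY, K_XY = kernel(X, X), kernel(Y, Y), kernel(X, Y)
    n, m = len(X), len(Y)
    np.fill_diagonal(K_XX, 0)
    np.fill_diagonal(K_YY, 0)
    return (K_XX.sum() / (n * (n - 1)) + K_YY.sum() / (m * (m - 1))
            - 2 * K_XY.sum() / (n * m))


def mmd_unbiased_loops(X, Y, k):
    # Direct transcription of the unbiased estimator, one kernel call per pair
    n, m = len(X), len(Y)
    xx = sum(k(X[i], X[j]) for i in range(n) for j in range(n) if i != j)
    yy = sum(k(Y[i], Y[j]) for i in range(m) for j in range(m) if i != j)
    xy = sum(k(X[i], Y[j]) for i in range(n) for j in range(m))
    return xx / (n * (n - 1)) + yy / (m * (m - 1)) - 2 * xy / (n * m)


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))
Y = rng.normal(loc=0.5, size=(15, 2))
sigma = 1.0


def k_pair(x, y):
    # Pairwise version of the same Gaussian kernel for the loop-based estimator
    return np.exp(-np.sum((x - y) ** 2) / (2 * sigma**2))


def k_matrix(A, B):
    return gauss_kernel(A, B, sigma)


vec = mmd_unbiased_vectorized(X, Y, k_matrix)
loop = mmd_unbiased_loops(X, Y, k_pair)
print(vec, loop)  # should agree up to floating-point rounding

# sigma -> width conversion used in the issue: width = 2 * sigma**2
print(np.allclose(gauss_kernel(X, Y, sigma), width_kernel(X, Y, 2 * sigma**2)))
```

If both checks pass, the DS3 estimator and the sigma-to-width conversion are consistent, and the difference has to come from how the toolbox defines or scales its statistic.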

Edit: I have also tried linear kernels (using the linear_kernel function from the summer school), and I still get different MMD values.
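With a linear kernel there is a closed-form reference value: the *biased* MMD^2 estimate equals the squared distance between the two sample means, ||mean(X) - mean(Y)||^2. The sketch below assumes the summer-school `linear_kernel` computes k(x, y) = <x, y>. The `linear_kernel` and `mmd_biased` shown here are hypothetical stand-ins, not the course code. Note that `quadratic_time_mmd` in the example is the *unbiased* estimator, so its value will not equal this closed form; comparing implementations only makes sense when both sides use the same estimator.

```python
import numpy as np


def linear_kernel(X, Y):
    # Hypothetical stand-in for the summer-school linear_kernel: k(x, y) = <x, y>
    return X @ Y.T


def mmd_biased(X, Y, kernel):
    # Biased (V-statistic) MMD^2: diagonals kept, sums divided by n^2, m^2 and n*m
    n, m = len(X), len(Y)
    return (kernel(X, X).sum() / n**2 + kernel(Y, Y).sum() / m**2
            - 2 * kernel(X, Y).sum() / (n * m))


rng = np.random.default_rng(0)
X = rng.normal(size=(30, 3))
Y = rng.normal(loc=0.3, size=(40, 3))

closed_form = np.sum((X.mean(axis=0) - Y.mean(axis=0)) ** 2)
print(mmd_biased(X, Y, linear_kernel), closed_form)  # equal up to rounding
```

The identity holds because the double sums factor into inner products of the sample means, so this gives an independent number to compare against whatever each implementation reports.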
