
torch-two-sample's Introduction

torch-two-sample


A PyTorch library for differentiable two-sample tests

Description

This package implements a total of six two-sample tests:

  • The classical Friedman-Rafsky test [FR79].
  • The classical k-nearest neighbours (kNN) test [FR83].
  • The differentiable Friedman-Rafsky test [DK17].
  • The differentiable k-nearest neighbours (kNN) test [DK17].
  • The maximum mean discrepancy (MMD) test [GBR+12].
  • The energy test [SzekelyR13].

Please refer to the documentation for more information about the project. You can also have a look at the following notebook that showcases how to use the code to train a generative model on MNIST.
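
As a quick illustration of the API, here is a minimal sketch of using the differentiable MMD statistic as a training criterion. The sample tensors, their dimensions, and the alphas kernel parameters below are placeholders chosen for illustration, so treat this as a sketch rather than a complete example:

import torch
from torch_two_sample import MMDStatistic

n_1, n_2 = 128, 128                                   # sizes of the two samples
sample_1 = torch.randn(n_1, 20)                       # e.g. a batch of real data
sample_2 = torch.randn(n_2, 20, requires_grad=True)   # e.g. generator output

# The statistic is constructed with the two sample sizes and called with the
# samples themselves; alphas are the kernel hyperparameters of the MMD test.
mmd = MMDStatistic(n_1, n_2)
loss = mmd(sample_1, sample_2, alphas=[0.5])
loss.backward()                                       # the statistic is differentiable

The non-differentiable statistics (e.g. FRStatistic) follow the same construct-then-call pattern and additionally expose a pval method for permutation-based p-values, as the issues below illustrate.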

Installation

After installing PyTorch, you can install the package with:

python setup.py install

Testing

To run the tests, simply run:

python setup.py test

Note that you will need to have Shogun installed for one of the test cases.

Bibliography

  • [DK17] J. Djolonga and A. Krause. Learning Implicit Generative Models Using Differentiable Graph Tests. ArXiv e-prints, September 2017. arXiv:1709.01006.
  • [FR79] Jerome H Friedman and Lawrence C Rafsky. Multivariate generalizations of the Wald-Wolfowitz and Smirnov two-sample tests. Annals of Statistics, pages 697–717, 1979.
  • [FR83] Jerome H Friedman and Lawrence C Rafsky. Graph-theoretic measures of multivariate association and prediction. Annals of Statistics, pages 377–391, 1983.
  • [GBR+12] Arthur Gretton, Karsten M Borgwardt, Malte J Rasch, Bernhard Schölkopf, and Alexander Smola. A kernel two-sample test. Journal of Machine Learning Research, 13(Mar):723–773, 2012.
  • [SST+12] Kevin Swersky, Ilya Sutskever, Daniel Tarlow, Richard S Zemel, Ruslan R Salakhutdinov, and Ryan P Adams. Cardinality restricted Boltzmann machines. In Advances in Neural Information Processing Systems (NIPS), pages 3293–3301, 2012.
  • [SzekelyR13] Gábor J Székely and Maria L Rizzo. Energy statistics: a class of statistics based on distances. Journal of Statistical Planning and Inference, 143(8):1249–1272, 2013.
  • [TSZ+12] Daniel Tarlow, Kevin Swersky, Richard S Zemel, Ryan Prescott Adams, and Brendan J Frey. Fast exact inference for recursive cardinality models. In Uncertainty in Artificial Intelligence (UAI), 2012.

torch-two-sample's People

Contributors

bradyneal, calincru, josipd


torch-two-sample's Issues

RuntimeError in PyTorch 0.4.1

pytorch 0.4.1

RuntimeError Traceback (most recent call last)
in ()
37 loss = loss_fn(Variable(batch), generator(noise), alphas=alphas)
38 print(" loss is ", batch)
---> 39 loss.backward()
40 print(" backward")
41 optimizer.step()

~/Downloads/venv/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
91 products. Defaults to False.
92 """
---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph)
94
95 def register_hook(self, hook):

~/Downloads/venv/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
88 Variable._execution_engine.run_backward(
89 tensors, grad_tensors, retain_graph, create_graph,
---> 90 allow_unreachable=True) # allow_unreachable flag
91
92

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Computing p-values with non-differentiable statistics

Hi,

I am trying to use the non-differentiable statistics to compute the p-values. My code is as follows:

sample_1 = Variable(torch.FloatTensor(sample_1))
sample_2 = Variable(torch.FloatTensor(sample_2))
statistic = FRStatistic(sample_1, sample_2)
_, mst = statistic.__call__(sample_1, sample_2, ret_matrix=True)
pvalue = statistic.pval(mst)

sample_1 and sample_2 are PyTorch variables of size [100x22]. However, I get the following two errors (they are exactly the same regardless of which statistic I use):

Traceback (most recent call last):
  File "code/compare.py", line 147, in <module>
    compare.KnnStatistic(original_data, produced_data, args.cardinality, 100)
  File "code/compare.py", line 102, in KnnStatistic
    _, mst = statistic.__call__(sample_1, sample_2, ret_matrix=True)
  File "/Users/dimkou/Documents/deep_learning/deep/lib/python3.6/site-packages/torch_two_sample-0.1-py3.6-macosx-10.13-x86_64.egg/torch_two_sample/statistics_nondiff.py", line 186, in __call__
    assert n_1 == self.n_1 and n_2 == self.n_2
  File "/Users/dimkou/Documents/deep_learning/deep/lib/python3.6/site-packages/torch/autograd/variable.py", line 125, in __bool__
    torch.typename(self.data) + " is ambiguous")
RuntimeError: bool value of Variable objects containing non-empty torch.ByteTensor is ambiguous

Commenting out that line in the source code and recompiling the package works around that error.

The second one is:

Traceback (most recent call last):
  File "code/compare.py", line 133, in <module>
    compare.FRStatistic(original_data, produced_data, args.cardinality, 100)
  File "code/compare.py", line 85, in FRStatistic
    pvalue = statistic.pval(mst)
  File "/Users/dimkou/Documents/deep_learning/deep/lib/python3.6/site-packages/torch_two_sample-0.1-py3.6-macosx-10.13-x86_64.egg/torch_two_sample/statistics_nondiff.py", line 141, in pval
    self.n_1, self.n_2, n_permutations)
  File "torch_two_sample/permutation_test.pyx", line 57, in torch_two_sample.permutation_test.permutation_test_mat
  File "/Users/dimkou/Documents/deep_learning/deep/lib/python3.6/site-packages/torch/autograd/variable.py", line 130, in __int__
    return int(self.data)
  File "/Users/dimkou/Documents/deep_learning/deep/lib/python3.6/site-packages/torch/tensor.py", line 389, in __int__
    raise TypeError("only 1-element tensors can be converted "
TypeError: only 1-element tensors can be converted to Python scalars

The returned mst is a FloatTensor of size [200x200]. Am I using the package incorrectly? How can I fix the second error?

Edit: OK, the issue behind the first error was also the problem for the second one. Setting:

self.n_1 = sample_1.size(0)
self.n_2 = sample_2.size(0)

in both statistics seems to fix the errors without messing up the functionality of the code. Shall I submit a PR?
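
For reference, here is a usage sketch that should avoid both errors, under the assumption (suggested by the assertion in statistics_nondiff.py) that the statistics are meant to be constructed with the integer sample sizes rather than with the samples themselves, and that FRStatistic is exported at the package level; the random tensors stand in for real data:

import torch
from torch_two_sample import FRStatistic

# Two samples of shape [n, d]; random placeholders with the shapes from the report.
sample_1 = torch.randn(100, 22)
sample_2 = torch.randn(100, 22)

n_1, n_2 = sample_1.size(0), sample_2.size(0)      # plain Python integers
statistic = FRStatistic(n_1, n_2)
# ret_matrix=True additionally returns the matrix needed by the permutation test.
test_stat, mst = statistic(sample_1, sample_2, ret_matrix=True)
pvalue = statistic.pval(mst)                       # permutation-based p-value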

RuntimeErrors with "torch_two_sample.statistics_diff.SmoothKNNStatistic"

Hello!
Thank you very much for sharing these functions with the PyTorch community!

I am successfully using torch_two_sample.statistics_diff.MMDStatistic both for training/backprop and for evaluation, running on the GPU.

I am trying to use the SmoothKNN statistic as an alternative criterion, set up in the same way as the MMD (with the additional True boolean for CUDA and the k parameter).

The code is otherwise identical to the MMD version, but with the SmoothKNN function it gives "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation".

Does anyone have a fix for this?

Thanks in advance!
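
Not a fix, but one way to locate the offending operation is PyTorch's autograd anomaly detection, which makes the backward error point at the forward operation that was later modified in place. The constructor and call signatures below are assumptions based on the description above (sample sizes, cuda flag, k; alphas passed at call time) and may differ in your version; the tensors are placeholders:

import torch
from torch_two_sample import SmoothKNNStatistic

n = 64
real_batch = torch.randn(n, 20)
fake_batch = torch.randn(n, 20, requires_grad=True)

# Assumed constructor arguments: n_1, n_2, cuda flag, k.
loss_fn = SmoothKNNStatistic(n, n, False, 1)

# Anomaly detection adds bookkeeping to every forward op so that the eventual
# backward error reports where the in-place modification happened.
with torch.autograd.detect_anomaly():
    loss = loss_fn(real_batch, fake_batch, alphas=[0.1])
    loss.backward()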

Numerous build errors

Hi,

I just ran python setup.py test and got mostly failures. It's not clear which dependencies I am missing, or whether that's the only problem.

  1. How can I determine the dependencies?
  2. If the issue is with dependencies, perhaps this is a docs issue as well (just saying this to help :) )

==================================== short test summary info =====================================
FAILED tests/inference_trees_test.py::test_cardinality - AttributeError: module 'torch' has no ...
FAILED tests/inference_trees_test.py::test_chain - AttributeError: module 'torch' has no attrib...
FAILED tests/inference_trees_test.py::test_cycle - AttributeError: module 'torch' has no attrib...
FAILED tests/inference_trees_test.py::test_dumbbell - AttributeError: module 'torch' has no att...
FAILED tests/statistics_diff_test.py::test_mean_std - AttributeError: module 'torch' has no att...
FAILED tests/statistics_diff_test.py::test_mmd - ModuleNotFoundError: No module named 'modshogun'
FAILED tests/statistics_nondiff_test.py::test_k_smallest - TypeError: forward() missing 2 requi...
FAILED tests/statistics_nondiff_test.py::test_mst - TypeError: forward() missing 1 required pos...
FAILED tests/statistics_nondiff_test.py::test_fr - IndexError: invalid index of a 0-dim tensor....
FAILED tests/statistics_nondiff_test.py::test_1nn - IndexError: invalid index of a 0-dim tensor...
FAILED tests/statistics_nondiff_test.py::test_1nn_smooth - IndexError: invalid index of a 0-dim...
FAILED tests/util_test.py::test_pdist - assert False
============================ 12 failed, 3 passed, 8 warnings in 9.92s ============================

I am running pytorch 1.4.0 (py3.7_cpu_0 [cpuonly]), managed by conda 4.7.12, with Python 3.7.7, on a Linux OS.

ModuleNotFoundError: No module named 'torch_two_sample.permutation_test'

I receive this error while trying to do this:
from torch_two_sample import SmoothFRStatistic
There was no error in the installation. Inside the torch_two_sample folder, I noticed that there is a permutation_test.pyx file, but not a .py file.
The complete stack trace:
Traceback (most recent call last):
File "", line 1, in
File "/media/someone/A00AA3760AA347DC/torch-two-sample/torch_two_sample/init.py", line 1, in
from .statistics_diff import (
File "/media/someone/A00AA3760AA347DC/torch-two-sample/torch_two_sample/statistics_diff.py", line 9, in
from .permutation_test import permutation_test_tri, permutation_test_mat
ModuleNotFoundError: No module named 'torch_two_sample.permutation_test'
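
The traceback paths suggest the import is being attempted from inside the source checkout, where the Cython module has never been compiled. One possible resolution, assuming setup.py declares the permutation_test extension (which the .pyx file suggests) and Cython is available, is to build the extension in place:

python setup.py build_ext --inplace

Alternatively, importing the installed package from outside the source directory (after python setup.py install) should pick up the compiled module from site-packages.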

Invalid indexing operations

Hi,

I just ran python setup.py test and got numerous IndexErrors like

IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

I am running pytorch 1.4.0 (py3.7_cpu_0 [cpuonly]), managed by conda 4.7.12, with Python 3.7.7, on a Linux OS.

Please help,
Thanks!
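
For context, this error comes from an indexing idiom that newer PyTorch releases reject. A minimal illustration of the pattern and of the replacement the error message suggests (not the library's actual code):

import torch

t = torch.tensor([1.0, 2.0]).sum()   # a 0-dim (scalar) tensor
# value = t[0]                       # on recent PyTorch this raises the IndexError above
value = t.item()                     # the replacement the error message suggests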
