
podgorskiy / gpnd


Generative Probabilistic Novelty Detection with Adversarial Autoencoders

deep-learning deep-neural-networks deep-novelty-detection novelty-detection anomaly-detection adversarial-learning autoencoder pytorch generative-adversarial-network gan

gpnd's Introduction

Generative Probabilistic Novelty Detection with Adversarial Autoencoders

Stanislav Pidhorskyi, Ranya Almohsen, Donald A Adjeroh, Gianfranco Doretto

Lane Department of Computer Science and Electrical Engineering, West Virginia University
Morgantown, WV 26508
{stpidhorskyi, ralmohse, daadjeroh, gidoretto}@mix.wvu.edu

The e-print of the article is available on arXiv.

NeurIPS Proceedings.

@inproceedings{pidhorskyi2018generative,
  title={Generative probabilistic novelty detection with adversarial autoencoders},
  author={Pidhorskyi, Stanislav and Almohsen, Ranya and Adjeroh, Donald A and Doretto, Gianfranco},
  booktitle={Advances in neural information processing systems},
  pages={6822--6833},
  year={2018}
}

Content

  • partition_mnist.py - code for preparing MNIST dataset.
  • train_AAE.py - code for training the autoencoder.
  • novelty_detector.py - code for running the novelty detector.
  • net.py - contains definitions of network architectures.

How to run

You will need to run partition_mnist.py first.

Then run schedule.py. It runs as many concurrent experiments as there are GPUs available. Results are written to the results.csv file.


Alternatively, you can call the functions in train_AAE.py and novelty_detector.py directly.

To train the autoencoder, call the train function from train_AAE.py:

train_AAE.train(
  folding_id,
  inliner_classes,
  ic
)

Args:

  • folding_id: Id of the fold. For MNIST, 5 folds are generated, so folding_id must be in the range [0..4].
  • inliner_classes: List of classes considered inliers.
  • ic: inlier class set index (used to save model with unique filename).
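For MNIST, the sweep in the revised schedule.py posted in the issues below trains one model per (fold, digit) pair: a single digit is the inlier class, and its index is reused as ic. The minimal sketch below only enumerates those argument tuples; mnist_runs is a hypothetical helper, not part of the repository.

```python
from typing import Iterator, List, Tuple


def mnist_runs(folds: int = 5, classes_count: int = 10) -> Iterator[Tuple[int, List[int], int]]:
    """Yield hypothetical (folding_id, inliner_classes, ic) argument tuples.

    One run per (fold, digit) pair: the digit is the only inlier class and
    its index is reused as ic, the inlier class set index.
    """
    for folding_id in range(folds):  # folding_id takes values 0 .. folds - 1
        for digit in range(classes_count):
            yield folding_id, [digit], digit


runs = list(mnist_runs())
print(len(runs))          # 50 runs for 5 folds x 10 digits
print(runs[0], runs[-1])  # (0, [0], 0) (4, [9], 9)
```

Each tuple could then be unpacked into train_AAE.train(*args); note that the revised schedule.py additionally passes cfg=cfg.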

After the autoencoder has been trained, call the main function from novelty_detector.py:

novelty_detector.main(
  folding_id,
  inliner_classes,
  ic,
  total_classes,
  mul,
  folds=5
)

Args:

  • folding_id: Id of the fold. For MNIST, 5 folds are generated, so folding_id must be in the range [0..4].
  • inliner_classes: List of classes considered inliers.
  • ic: inlier class set index (used to save model with unique filename).
  • total_classes: Total count of classes (deprecated, moved to config).
  • mul: Multiplier for power correction. Default value: 0.2.
  • folds: Number of folds (deprecated, moved to config).

Generated/Reconstructed images

MNIST Reconstruction. First row: real images; second row: reconstructed images.



MNIST Generation.



COIL100 Reconstruction, single category. First row: real images; second row: reconstructed images. Only 57 images were used for training.



COIL100 Generation. First row: real images; second row: reconstructed images. Only 57 images were used for training.



COIL100 Reconstruction, 7 categories. First row: real images; second row: reconstructed images. Only about 60 images per category were used for training.



COIL100 Generation. First row: real images; second row: reconstructed images. Only about 60 images per category were used for training.



PDF of the latent space for MNIST. Size of the latent space: 32.


gpnd's Issues

Typo in "novelty_detector.py"

Thanks to the author for sharing! While running the code, I found a typo in "novelty_detector.py", lines 526-531:

        minp, maxP = -maxP, -minP
        X1 = [-x for x in X1]
        Y1 = [-x for x in Y1]
        auprout = 0.0
        recallTemp = 1.0
        for e in np.arange(minP, maxP, 0.2):

Here the lowercase minp is assigned -maxP, but np.arange is then called with the uppercase minP, so its start value is never negated.
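A minimal reproduction of the reported typo, assuming the intent is to negate and swap both bounds. The two helper functions are hypothetical wrappers around the quoted lines, not code from the repository.

```python
import numpy as np


def negated_range_buggy(minP, maxP, step=0.2):
    # As reported: the lowercase `minp` receives -maxP and is never used,
    # so np.arange still starts from the original, un-negated minP.
    minp, maxP = -maxP, -minP
    return np.arange(minP, maxP, step)


def negated_range_fixed(minP, maxP, step=0.2):
    # Presumed intent: negate both bounds and swap them.
    minP, maxP = -maxP, -minP
    return np.arange(minP, maxP, step)


print(negated_range_buggy(1.0, 2.0))  # empty: np.arange(1.0, -1.0, 0.2)
print(negated_range_fixed(1.0, 2.0))  # starts at -2.0, stops before -1.0
```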

The code for other datasets

Hello, sorry to bother you. When will the code for the other datasets be released? Also, can the currently released code be used with my own dataset?

About defining z in train_AAE.py

Dear podgorskiy,

Thank you for your contribution to the open-source and academic communities. I am trying to use your code to perform novelty detection. Some parts of the code confuse me, so I'd like to ask whether I am wrong or misunderstanding something.

My question is about train_AAE.py. On line 173, the variable z is defined as follows:

z = torch.randn((batch_size, zsize)).view(-1, zsize, 1, 1)

Since x_fake is derived from x, I think z should be defined as:

z = E(x)

Thank you in advance for taking the time to review this.

Error in schedule.py

When running schedule.py with multiprocessing, an error occurs because of contention when the worker processes start.

Following the hint in the error message, wrapping the entry-point code in an if __name__ == '__main__': block and calling multiprocessing.freeze_support() there solves the problem.

Error message:
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.

The revised code is as follows:

from save_to_csv import save_results
import logging
import sys
import utils.multiprocessing
from defaults import get_cfg_defaults
import os
import multiprocessing as mult

# With the "spawn" start method (the default on Windows), worker processes
# re-import this module without executing the __main__ block below, so every
# global that f() reads must be defined at module level.
mul = 0.2
classes_count = 10


def f(setting):
    import train_AAE
    import novelty_detector

    fold_id = setting['fold']
    inliner_classes = setting['digit']
    cfg = setting['cfg']
    train_AAE.train(fold_id, [inliner_classes], inliner_classes, cfg=cfg)

    res = novelty_detector.main(fold_id, [inliner_classes], inliner_classes, classes_count, mul, cfg=cfg)
    return res


if __name__ == '__main__':
    # Only the parent process runs this block. freeze_support() matters only
    # for frozen executables and is a no-op otherwise.
    mult.freeze_support()

    full_run = True

    logger = logging.getLogger("logger")
    logger.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    settings = []

    if len(sys.argv) > 1:
        cfg_file = 'configs/' + sys.argv[1]
    else:
        cfg_file = 'configs/' + input("Config file:")
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_file)
    cfg.freeze()

    # One experiment per (fold, digit) pair.
    for fold in range(5 if full_run else 1):
        for i in range(classes_count):
            settings.append(dict(fold=fold, digit=i, cfg=cfg))

    gpu_count = utils.multiprocessing.get_gpu_count()

    results = utils.multiprocessing.map(f, gpu_count, settings)

    save_results(results, os.path.join(cfg.OUTPUT_FOLDER, cfg.RESULTS_NAME))


Reduce hardcoded input and output shapes

train_AAE.py defines a zsize of 32, but the variable is not used consistently: in some places the literal 32 appears directly, which makes it difficult to adapt the code to different image sizes.

It would help if variables such as channels, height, width and zsize (where zsize differs from height/width) were used instead of hardcoded values wherever these quantities appear. This should apply to the network definitions as well.

This would make it significantly easier to port your work to other datasets and/or TensorFlow/Keras.
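A minimal sketch of what the issue asks for: derive every size from one set of shape variables instead of literal constants. It assumes downsampling convolutions with a 4x4 kernel, stride 2 and padding 1; these values are an illustration, not taken from net.py, and ShapeConfig and its methods are hypothetical names.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ShapeConfig:
    """Hypothetical container for the shape parameters named in the issue."""
    channels: int = 1
    height: int = 32
    width: int = 32
    zsize: int = 32

    @staticmethod
    def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
        # Standard convolution output-size formula: floor((n + 2p - k) / s) + 1.
        return (size + 2 * padding - kernel) // stride + 1

    def feature_map_size(self, n_layers: int) -> Tuple[int, int]:
        """Spatial size after n_layers downsampling convolutions."""
        h, w = self.height, self.width
        for _ in range(n_layers):
            h, w = self.conv_out(h), self.conv_out(w)
        return h, w

    def flat_features(self, n_layers: int, out_channels: int) -> int:
        """Input width of a linear layer placed after the conv stack."""
        h, w = self.feature_map_size(n_layers)
        return out_channels * h * w


print(ShapeConfig().feature_map_size(3))                    # (4, 4)
print(ShapeConfig(height=64, width=64).feature_map_size(3))  # (8, 8)
```

Changing height and width in one place then propagates to every layer that depends on them, which is the portability the issue is after.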

Results on the Caltech-256 dataset

https://github.com/khalooei/ALOCC-CVPR2018 follows the same lines as your paper. In that paper, the authors report state-of-the-art results on the Caltech-256 dataset, but I failed to reproduce them. I tried various settings, architectures and hyperparameters, but accuracy never exceeds 56% in authentic/fake classification. The authors have not yet pushed their code, and they seem to be busy, so they have not shared any specific details either. Have you tried your algorithm on Caltech-256? My concern is: without any pretrained network, how can a GAN trained from scratch reach an accuracy of 90%? Did you ever try that?
