
pt-dec's Introduction

pt-dec


A PyTorch implementation of a version of the Deep Embedded Clustering (DEC) algorithm. Compatible with PyTorch 1.0.0 and Python 3.6 or 3.7, with or without CUDA.

This follows (or attempts to follow; note that this implementation is unofficial) the algorithm described in "Unsupervised Deep Embedding for Clustering Analysis" by Junyuan Xie, Ross Girshick, and Ali Farhadi (https://arxiv.org/abs/1511.06335).

Examples

An example using MNIST data can be found in examples/mnist/mnist.py, which achieves around 85% accuracy.

Here is an example confusion matrix, with true labels on the y-axis and predicted labels on the x-axis.

(confusion matrix image)

Usage

This is distributed as a Python package, ptdec, and can be installed with python setup.py install after installing ptsdae from https://github.com/vlukiyanov/pt-sdae. The PyTorch nn.Module class representing DEC is DEC in ptdec.dec; the train function from ptdec.model is used to train it.
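As a self-contained sketch of the soft cluster assignment a DEC module computes on top of an encoder (the class name, encoder architecture, and dimensions below are illustrative assumptions, not the package's actual API; see ptdec.dec and ptdec.cluster for the real classes):

```python
import torch
import torch.nn as nn

class SoftAssignment(nn.Module):
    # Schematic of DEC's Student's t soft assignment over learnable
    # cluster centres; illustrative only, not the ptdec class itself.
    def __init__(self, cluster_number: int, embedding_dim: int, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.cluster_centers = nn.Parameter(torch.randn(cluster_number, embedding_dim))

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # Squared distance of each embedded point to each cluster centre
        norm_squared = torch.sum((batch.unsqueeze(1) - self.cluster_centers) ** 2, 2)
        numerator = 1.0 / (1.0 + norm_squared / self.alpha)
        power = float(self.alpha + 1) / 2
        numerator = numerator ** power
        # Normalise so each row is a probability distribution over clusters
        return numerator / torch.sum(numerator, dim=1, keepdim=True)

# Hypothetical encoder: 784-dimensional input down to a 10-dimensional embedding
encoder = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))
assignment = SoftAssignment(cluster_number=10, embedding_dim=10)
q = assignment(encoder(torch.randn(32, 784)))  # (32, 10) soft cluster probabilities
```

In the package itself the encoder would come from a pretrained ptsdae stacked autoencoder, and ptdec.model.train optimises both the encoder and the cluster centres.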

Other implementations of DEC

pt-dec's People

Contributors

codacy-badger, dymil, vlukiyanov, xingzhizhou


pt-dec's Issues

accuracy is low

Hi, I am getting an accuracy of 21% on the MNIST dataset after running your code. Is the problem with the code, or is there something wrong with my setup? Can you confirm that you are getting better accuracy with the same code?

Why use CachedMNIST?

Hello vlukiyanov.

First, thanks for sharing your code. I just have a question about the dataset class you implemented. There are many MNIST dataset classes for PyTorch, but I haven't seen this style before, i.e. transforming the raw bytes. Why do you use this CachedMNIST?

Clusters get worse after running DEC

I'm experiencing a strange issue where running DEC makes the clustering worse. At first I encountered it in my project with real data, but to debug the issue I managed to reproduce it with synthetic data as well.

In the attached notebook (SyntheticDEC.zip), I generate 1000 points in R^2 using a multivariate normal distribution, in two clusters. I then train an autoencoder to encode to 2 dimensions and decode back the input. Finally, I run DEC for 100 epochs. The problem can be reproduced by running the notebook, which has no external dependencies.
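The synthetic setup described above can be sketched in a few lines; the cluster means and unit variances here are assumptions, since the notebook itself is not reproduced:

```python
import random

random.seed(0)

# Two well-separated 2-D Gaussian clusters, 500 points each (1000 total),
# mimicking the issue's synthetic data; means and spread are assumed values.
centers = [(-5.0, -5.0), (5.0, 5.0)]
points, labels = [], []
for label, (cx, cy) in enumerate(centers):
    for _ in range(500):
        points.append((random.gauss(cx, 1.0), random.gauss(cy, 1.0)))
        labels.append(label)
```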

Initial hidden state: (scatter plot image)

After 100 epochs of DEC: (scatter plot image)

After 200 epochs: (scatter plot image)

Loss jumping up and down for pretraining and fine-tuning

Hi, I am trying out your model with another dataset (NTU RGB-D 60), but the losses for pretraining and fine-tuning are behaving strangely. First they decrease for some epochs, then jump up, then start decreasing again; this goes on for the entire run. I have also tried changing the optimizer from SGD to Adam.

cluster_centers not updating

Hi vlukiyanov,

Thank you for your PyTorch implementation; it has been a big help. However, I noticed that in this version the cluster_centers are not updating during training.

This is the code I used to test:

        a = list(model.parameters())[-2].clone()  # hidden layer weights
        b = list(model.parameters())[-1].clone()  # cluster centers
        optimizer.zero_grad()
        loss.backward()
        optimizer.step(closure=None)
        print(torch.equal(a, list(model.parameters())[-2]),
              torch.equal(b, list(model.parameters())[-1]))

It prints False, True: the hidden layer weights changed, but the cluster centers did not.

I am very new to PyTorch, but I think this might be a bug in PyTorch. I checked that there is a gradient for the cluster centers, but somehow the parameters are not updating.

Maybe there's something I don't understand about PyTorch, but I just wanted to point this out to you.
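For reference, a minimal self-contained check (plain PyTorch, not the pt-dec model) of whether a registered Parameter moves after optimizer.step():

```python
import torch

# Minimal check that an nn.Parameter registered with an optimizer is
# updated by step(); mirrors the comparison in the snippet above.
center = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD([center], lr=0.1)
before = center.clone()
loss = (center - torch.ones(3)).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(torch.equal(before, center))  # False: the parameter was updated
```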

Soft assignment different from paper

Hi, vlukiyanov,

Thanks for your great PyTorch implementation of DEC. However, I noticed a place where your implementation differs from the original paper.

In cluster.py, the ClusterAssignment.forward() function is implemented as:

    def forward(self, batch: torch.Tensor):
        norm_squared = torch.sum((batch.unsqueeze(1) - self.cluster_centers) ** 2, 2)
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = numerator ** power
        return numerator / torch.sum(numerator, dim=1, keepdim=True)

However, in the original paper there is a "minus" in the power term:

    q_ij = (1 + ||z_i - mu_j||^2 / alpha)^(-(alpha+1)/2) / sum_j' (1 + ||z_i - mu_j'||^2 / alpha)^(-(alpha+1)/2)

Would you please explain why this is different from the original paper?
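Note that taking the reciprocal and then a positive power is the same as taking the negative power directly, so the two forms can be compared numerically (alpha and the squared distance below are arbitrary example values):

```python
# Compare the reciprocal-then-positive-power form with the paper's
# explicit negative exponent, for example values of alpha and distance.
alpha = 1.0
norm_squared = 2.5  # example squared distance ||z_i - mu_j||^2
power = float(alpha + 1) / 2
repo_form = (1.0 / (1.0 + norm_squared / alpha)) ** power
paper_form = (1.0 + norm_squared / alpha) ** (-power)
print(abs(repo_form - paper_form) < 1e-12)  # True: the two forms agree
```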

Regards,
Bowen

Working with color image data

Hey, thank you so much for your awesome implementation! I have one problem that I could use some help with. How would you get this to work with color images, such as STL-10?

In the paper I see that they mention "concatenating HOG feature and a 8-by-8 color map to use as input to all algorithms." Unfortunately, I'm finding the original code a bit difficult to sift through...

Do you have any recommendations on how to proceed?
