
gsdmm's Introduction

GSDMM: Short text clustering

This project implements the Gibbs sampling algorithm for the Dirichlet Multinomial Mixture model of Yin and Wang (2014) for clustering short text documents. Some advantages of this algorithm:

  • It requires only an upper bound K on the number of clusters
  • With good parameter selection, the model converges quickly
  • Space efficient and scalable

This project is an easy-to-read reference implementation of GSDMM -- I don't plan to maintain it unless there is demand. I am, however, actively maintaining the much faster Rust version of GSDMM here.

The Movie Group Process

In their paper, the authors introduce a simple conceptual model for explaining the GSDMM called the Movie Group Process.

Imagine a professor is leading a film class. At the start of the class, the students are randomly assigned to K tables. Before class begins, the students make lists of their favorite films. The professor repeatedly reads the class roll. Each time a student's name is called, the student must select a new table satisfying one or both of the following conditions:

  • The new table has more students than the current table.
  • The new table has students with similar lists of favorite movies.

By following these steps consistently, we might expect that the students eventually arrive at an "optimal" table configuration.
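
The two rules correspond to the two factors of the sampling distribution in Yin and Wang (2014): the first favors large tables, the second favors tables whose word counts overlap with the document. The following is only an illustrative sketch of one reassignment step; the function and data-structure names are hypothetical and not taken from this repository.

import random

def choose_table(doc, tables, alpha, beta, vocab_size, total_docs):
    # Hypothetical sketch of a single reassignment step, following the
    # conditional distribution in Yin and Wang (2014). Each table is a dict
    # with 'n_students' (documents at the table), 'n_tokens' (total tokens
    # at the table) and 'token_counts' (per-token counts at the table).
    weights = []
    for table in tables:
        # Rule 1: tables with more students get more weight.
        weight = (table['n_students'] + alpha) / (total_docs - 1 + len(tables) * alpha)
        # Rule 2: tables whose lists share this student's movies get more weight.
        for i, token in enumerate(doc):
            weight *= (table['token_counts'].get(token, 0) + beta) / (table['n_tokens'] + vocab_size * beta + i)
        weights.append(weight)
    # Sample a new table in proportion to these weights.
    return random.choices(range(len(tables)), weights=weights)[0]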

Usage

To use a Movie Group Process to cluster short texts, first initialize a MovieGroupProcess:

from gsdmm import MovieGroupProcess
mgp = MovieGroupProcess(K=8, alpha=0.1, beta=0.1, n_iters=30)

It's important to always choose K to be larger than the number of clusters you expect to exist in your data, as the algorithm can never return more than K clusters.

To fit the model:

y = mgp.fit(docs)

Each doc in docs must be a list of the unique tokens found in your short text document. This implementation does not support counting tokens with multiplicity (which generally adds little value for short text documents).
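
For reference, here is a minimal end-to-end sketch using made-up tokens. Note that, as several of the issues below mention, some versions of fit also expect the vocabulary size as a second argument; check the signature of the version you installed.

from gsdmm import MovieGroupProcess

# Toy corpus: each document is a list of its unique tokens.
docs = [
    ['house', 'burning', 'fire', 'truck'],
    ['hindu', 'christian', 'conversion'],
    ['fire', 'truck', 'response'],
]

mgp = MovieGroupProcess(K=8, alpha=0.1, beta=0.1, n_iters=30)

# Some versions of fit take the vocabulary size (number of distinct
# tokens) as a second argument -- see the issues below.
vocab_size = len(set(token for doc in docs for token in doc))
y = mgp.fit(docs, vocab_size)

# y[i] is the cluster label assigned to docs[i]; at most K distinct
# labels can appear, and some of the K clusters may remain empty.
print(y)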

gsdmm's People

Contributors

mgallaspy, rwalk


gsdmm's Issues

Slow clustering for a large number of documents

How can I increase the speed of clustering, or run the code on a GPU to speed it up? I am clustering 500k documents into 1000 clusters and the speed is extremely slow.

Clarification on Output Meaning

@rwalk I have read the 'A Dirichlet Multinomial Mixture Model-based Approach for Short Text Clustering' paper, and I'm not sure what the 'transferred 914' and 'transferred 85' clusters below mean. Any feedback would be appreciated.

In stage 0: transferred 914 clusters with 2 clusters populated
In stage 1: transferred 85 clusters with 2 clusters populated

Viz Error with Empty Clusters (pyLDAvis, Int64Index)

I have spent a couple of days trying to resolve this on my own and am at a loss, so I wanted to get your thoughts.

Issue:
In trying to visualize a model run that results in multiple empty clusters, I get the following error. I assume (perhaps incorrectly) that it is being caused by duplicate entries that are empty, given that the error is thrown at points where pandas is trying to return unique entries. The error seems to be associated with doc_topic_dists2, but I'm not 100% sure of that either.

Alternatively, I thought maybe the issue is that many of the rows in doc_topic_dists2 don't sum to 1. In my case, 774 of the entries don't sum to exactly 1. I'm really not sure how to fix that without pulling every entry that doesn't sum to 1 out of the dataset, which would mean filtering out about 1/3 of all my data.
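
(One generic way to make each row sum to 1 is a simple renormalization. The sketch below assumes doc_topic_dists2 is a list of per-document lists of non-negative floats; it is only an illustration, not a tested fix for the error that follows.)

import numpy as np

# Sketch: renormalize each document's topic distribution so it sums to 1.
dists = np.asarray(doc_topic_dists2, dtype=float)
row_sums = dists.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1.0  # leave all-zero rows unchanged instead of dividing by zero
doc_topic_dists2 = (dists / row_sums).tolist()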

#THE ERROR

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 51, in <module>
  File "<string>", line 44, in prepare_visualization_data
  File "/Users/[user]/.local/share/virtualenvs/lasocial-d_HGkGzp/lib/python3.8/site-packages/pyLDAvis/_prepare.py", line 439, in prepare
    topic_info = _topic_info(topic_term_dists, topic_proportion,
  File "/Users/[user]/.local/share/virtualenvs/lasocial-d_HGkGzp/lib/python3.8/site-packages/pyLDAvis/_prepare.py", line 280, in _topic_info
    return pd.concat([default_term_info] + list(topic_dfs))
  File "/Users/[user]/.local/share/virtualenvs/lasocial-d_HGkGzp/lib/python3.8/site-packages/pyLDAvis/_prepare.py", line 264, in topic_top_term_df
    term_ix = topic_terms.unique()
  File "/Users/[user]/.local/share/virtualenvs/lasocial-d_HGkGzp/lib/python3.8/site-packages/pandas/core/series.py", line 1872, in unique
    result = super().unique()
  File "/Users/[user]/.local/share/virtualenvs/lasocial-d_HGkGzp/lib/python3.8/site-packages/pandas/core/base.py", line 1047, in unique
    result = unique1d(values)
  File "/Users/[user]/.local/share/virtualenvs/lasocial-d_HGkGzp/lib/python3.8/site-packages/pandas/core/algorithms.py", line 407, in unique
    uniques = table.unique(values)
  File "pandas/_libs/hashtable_class_helper.pxi", line 4719, in pandas._libs.hashtable.PyObjectHashTable.unique
  File "pandas/_libs/hashtable_class_helper.pxi", line 4666, in pandas._libs.hashtable.PyObjectHashTable._unique
  File "/Users/[user]/.local/share/virtualenvs/lasocial-d_HGkGzp/lib/python3.8/site-packages/pandas/core/indexes/base.py", line 4273, in __hash__
    raise TypeError(f"unhashable type: {repr(type(self).__name__)}")
TypeError: unhashable type: 'Int64Index'

Background:
I tried resolving the issue using code from a previous/related issue (Issue #5), changing the code to reflect that I'm not using IPython/a notebook. The changes are below the point where the script throws the error, so I don't think they're the problem, but I've included the changed lines below to make sure.

def prepare_visualization_data(mgp):
    vis_data = pyLDAvis.prepare(*prepare_data(mgp), sort_topics=False)
    # with open(f"gsdmm-pyldavis-{mgp.K}-{mgp.alpha}-{mgp.beta}-{mgp.n_iters}.html", "w") as f:
    #     pyLDAvis.save_html(vis_data, f)
    return vis_data

...

vis_data = prepare_visualization_data(mgp)
pyLDAvis.save_html(vis_data, 'sttmChart.html')  # path to output

What I've tried:
I sadly didn't catalog everything I tried in detail, but my main two efforts were as follows:

  1. I first tried converting doc_topic_dists2 to a DataFrame, as shown below. However, this just resulted in a different error (TypeError: '<' not supported between instances of 'list' and 'float') that I spent some time trying to work around but couldn't find a solution for.
#how I tried to convert to dataframe

...

PDX = []
for i in doc_topic_dists2:
    PDX.append({'data':i})

docPD = pd.DataFrame(PDX)

#then... 
vis_data = pyLDAvis.prepare(matrix, docPD.data, doc_lengths, vocabulary, term_counts)

...
  2. I then tried pulling the duplicates out of doc_topic_dists2 and filtering those duplicates out of doc_lengths as well, so the two lists would still match up with each other. I still ran into the same Int64Index error, however.
...
d3 = pd.DataFrame(doc_topic_dists2)
hmm = d3[d3.duplicated()]
hmmList = hmm.index.values.tolist()
for i in hmmList:
    doc_lengths.pop(i)
    doc_topic_dists2.pop(i)

#then run matrix part of the code... 

So I'm pretty much at a loss here. I'd appreciate anyone's thoughts or ideas. I'd also welcome suggestions for other ways to visualize the clusters interactively, as I really just need this to help my team get their heads around the clustering.

Thanks for your help!

Saving gsdmm model

Once the model is trained, how do I save it for future classification?
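
(gsdmm does not appear to ship its own save method, but the trained MovieGroupProcess is a plain Python object, so standard-library pickling is one option. A minimal sketch, assuming the fitted object mgp is picklable:)

import pickle

# Save the fitted MovieGroupProcess (pickling is not an API provided by
# gsdmm itself, just ordinary Python serialization).
with open('mgp_model.pkl', 'wb') as f:
    pickle.dump(mgp, f)

# Later, load it back and score new, tokenized documents.
with open('mgp_model.pkl', 'rb') as f:
    mgp = pickle.load(f)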

Predict new documents

Hi.
I am interested in using your library, but so far it only seems to be able to assign topics to the sentences in its training set. Is there a way to predict topics for a set of documents different from the training set?

Thank you.
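
(Since the score method used in a later issue below returns one probability per cluster for a given document, one workaround is to tokenize the new document the same way as the training data and take the most probable cluster. A minimal sketch, assuming mgp is a fitted MovieGroupProcess and new_doc holds hypothetical tokens:)

# Sketch: assign an unseen document to its most probable cluster.
# Assumes new_doc uses the same preprocessing (unique tokens) as the
# training documents.
new_doc = ['fire', 'truck', 'response']
probabilities = mgp.score(new_doc)   # one probability per cluster
best_cluster = max(range(len(probabilities)), key=lambda k: probabilities[k])
print(best_cluster, probabilities[best_cluster])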

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

from gsdmm.gsdmm import MovieGroupProcess

docs = [['house', 'burning', 'need', 'fire', 'truck', 'ml', 'hindu', 'response', 'christian', 'conversion', 'alm']]

mgp = MovieGroupProcess(K=10, alpha=0.1, beta=0.1, n_iters=30)

vocab = set(x for doc in docs for x in doc)
n_terms = len(vocab)
n_docs = len(docs)

# Fit the model on the data given the chosen seeds

y = mgp.fit(docs, n_terms)

The above code was working well 7 days ago. Suddenly we are encountering this error. Kindly advise.

Missing Packages

Where do we find the packages that provide:
from topic_allocation import top_words, topic_attribution

Score function assigns wrong topic to texts

Hello,
I'm using GSDMM to model topics in a short-text corpus. After the model is trained, I observe quite strange behavior: when I run the score function on some text, it assigns the text to a topic that is irrelevant to it, most of the time.
For example, list(enumerate(mgp.score("exception"))) gives:

[(0, 5.408052732645677e-10),
 (1, 2.2084257143830565e-05),
 (2, 0.022683564559983955),
 (3, 0.0009818820143112778),
 (4, 0.5482598416951635),         // highest probability is assigned to the 4th topic
 (5, 0.1022196373411148),
 (6, 0.049452600413189024),
 (7, 1.710484812050781e-06),
 (8, 0.004729435817210187),
 (9, 8.314540183290251e-06),
 (10, 0.13690067779136778),
 (11, 0.00018238304274319198),
 (12, 5.3021228234619835e-05),
 (13, 0.010612516369262397),
 (14, 7.432649699440398e-05),
 (15, 0.004651235422841561),
 (16, 0.001587346363887785),
 (17, 0.005119983861692937),
 (18, 6.710854911425749e-05),
 (19, 0.11239232920994387)]

while the word's occurrence counts are quite different. The command list(enumerate([mgp.cluster_word_distribution[i].get('exception', 0) for i in range(K)])) gives me:

[(0, 108),
 (1, 0),
 (2, 13),
 (3, 0),
 (4, 13),
 (5, 0),
 (6, 5),
 (7, 3),
 (8, 0),
 (9, 11),
 (10, 14),
 (11, 0),
 (12, 0),
 (13, 180), // Actual maximum is here, and it's much more relevant topic for this word
 (14, 5),
 (15, 3),
 (16, 7),
 (17, 29),
 (18, 0),
 (19, 4)]

@rwalk do you have any idea about this behavior?

Unexpected argument for the fit method.

Hello!
Thanks for the awesome work.

I've tried to run the model on a list of very short documents (6 tokens on average per document), but I've been asked for a "vocab_size" argument when trying to fit it.

What's the role of that argument? Is it the vocabulary size as in most document-term matrices?

Thank you very much!

Guido
