dmetivie / expectationmaximization.jl

A simple but generic implementation of Expectation Maximization algorithms to fit mixture models.

Home Page: https://dmetivie.github.io/ExpectationMaximization.jl/

License: MIT License

Language: Julia 100.00%
Topics: julia, expectation-maximization, mixture-models, gaussian-mixture-models, clustering

expectationmaximization.jl's Introduction

ExpectationMaximization

This package provides a simple implementation of the Expectation Maximization (EM) algorithm used to fit mixture models. Thanks to Julia's amazing dispatch system, its spirit of generic and reusable code, and the Distributions.jl package, the code, while very generic, is both expressive and fast! (Have a look at the Benchmark section.)

What type of mixtures?

In particular, it works with many kinds of mixtures:

  • Mixture of Univariate continuous distributions
  • Mixture of Univariate discrete distributions
  • Mixture of Multivariate distributions (continuous or discrete)
  • Mixture of mixtures (univariate or multivariate and continuous or discrete)
  • More?

What EM algorithm?

So far, the classic EM algorithm and the Stochastic EM algorithm are implemented. Look at the Bibliography section for references.

How?

Just define a mix::MixtureModel and do fit_mle(mix, y), where y is your observation array (vector or matrix). That's it! For the Stochastic EM, just do fit_mle(mix, y, method = StochasticEM()). Have a look at the Examples section.

For this to work, the only requirements are that the components of the mixture dist ∈ dists = components(mix) (custom or coming from an existing package)

  1. Are a subtype of Distribution, i.e. dist <: Distribution.
  2. Have logpdf(dist, y) defined (it is used in the E-step).
  3. Have fit_mle(dist, y, weights) returning the distribution with parameters equal to the MLE. This is used in the M-step of the ClassicEM algorithm. For the StochasticEM version, only fit_mle(dist, y) is needed. Type or instance versions of fit_mle for your dist are accepted thanks to this conversion line. A minimal sketch of such a custom component follows this list.
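
For illustration, here is a minimal sketch of a custom component satisfying these three requirements. The MyExponential type is hypothetical, invented here only to show which methods the EM needs; any distribution from Distributions.jl that already has a weighted fit_mle works out of the box.

using Distributions
import Distributions: logpdf, fit_mle

# 1. A subtype of Distribution (hypothetical example type).
struct MyExponential{T<:Real} <: ContinuousUnivariateDistribution
    θ::T  # scale parameter
end

# 2. logpdf is what the E-step evaluates.
logpdf(d::MyExponential, x::Real) = x < 0 ? -Inf : -log(d.θ) - x / d.θ

# 3. The weighted fit_mle is what the M-step of ClassicEM calls:
#    for an exponential, the weighted MLE of the scale θ is the weighted mean.
function fit_mle(::Type{<:MyExponential}, x::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    return MyExponential(sum(w .* x) / sum(w))
end

A mixture built from such components, e.g. MixtureModel([MyExponential(1.0), MyExponential(10.0)], [0.5, 0.5]), can then be passed to fit_mle exactly as in the example below.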

TODO (feel free to contribute)

[ ] Add more variants of the EM algorithm (so far there are the classic and stochastic versions).

[ ] Better benchmarks against other EM implementations.

[ ] Speed up the code (always!). So far, I have focused on readable code.

Example

Also have a look at the Examples section.

using Distributions
using ExpectationMaximization

Model

N = 50_000
θ₁ = 10
θ₂ = 5
α = 0.2
β = 0.3
# Mixture model: here one can put any classical distributions
mix_true = MixtureModel([Exponential(θ₁), Gamma(α, θ₂)], [β, 1 - β]) 

# Generate N samples from the mixture
y = rand(mix_true, N) 

Inference

# Initial guess
mix_guess = MixtureModel([Exponential(1), Gamma(0.5, 1)], [0.5, 1 - 0.5])

# Fit the MLE with the EM algorithm
mix_mle = fit_mle(mix_guess, y; display = :iter, atol = 1e-3, robust = false, infos = false)

Verify results

rtol = 5e-2
p = params(mix_mle)[1] # (θ₁, (α, θ₂))
isapprox(β, probs(mix_mle)[1]; rtol = rtol)
isapprox(θ₁, p[1]...; rtol = rtol)
isapprox(α, p[2][1]; rtol = rtol)
isapprox(θ₂, p[2][2]; rtol = rtol)

[Figure: EM_mixture_example.svg]

expectationmaximization.jl's People

Contributors

dmetivie, timholy


expectationmaximization.jl's Issues

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!

Interface compatibility with Distributions.jl ?

Hey,

Thanks for this great addition to the ecosystem!

From Distributions.jl, it seems like the first argument to the fit_mle function should be the distribution's type and not an instance of the type:

julia> fit_mle(Gamma,rand(1000))
Gamma{Float64}(α=1.604973623956157, θ=0.3097630701718876)

julia> fit_mle(Gamma,rand(100))
Gamma{Float64}(α=1.6985042802071995, θ=0.31065614192888746)

julia> fit_mle(Gamma(1,1),rand(100))
ERROR: MethodError: no method matching fit_mle(::Gamma{Float64}, ::Vector{Float64})
Closest candidates are:
  fit_mle(::Type{<:LogNormal}, ::AbstractArray{T}) where T<:Real at C:\Users\lrnv\.julia\packages\Distributions\bQ6Gj\src\univariate\continuous\lognormal.jl:163
  fit_mle(::Type{<:Weibull}, ::AbstractArray{<:Real}; alpha0, maxiter, tol) at C:\Users\lrnv\.julia\packages\Distributions\bQ6Gj\src\univariate\continuous\weibull.jl:145
  fit_mle(::Type{<:Beta}, ::AbstractArray{T}; maxiter, tol) where T<:Real at C:\Users\lrnv\.julia\packages\Distributions\bQ6Gj\src\univariate\continuous\beta.jl:217
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[14]:1

julia> 

This is not really a problem for you, as you are free to overload this function as you want, and your interface actually makes a lot of sense since you exploit the guesses in your algorithm. But would it be possible to add methods following this convention, maybe with automatic guesses? I have fit_mle bindings in Copulas.jl that assume this convention, and thus do not work directly with your package :(

Edit: I was trying to make a code example of what I would like, but I saw that mixture types do not include component types... More specifically, I would like to be able to type:

fit_mle(MixtureModel{Gamma,Gamma,Normal},data)

instead of

fit_mle(MixtureModel([Gamma(),Gamma(),Normal()],[1/3, 1/3, 1/3]),data)

Would that be possible?

It would allow composability, as I am currently using:

using Copulas, Distributions, ExpectationMaximization, Random
X₁ = MixtureModel([Gamma(2,3), LogNormal(1,1)],[1/2,1/2])
X₂ = Pareto()
X₃ = LogNormal(0,1)
C = ClaytonCopula(3,0.7) # A 3-variate Clayton copula with θ = 0.7
D = SklarDist(C,(X₁,X₂,X₃)) # The final distribution

# This generates a (3,1000)-sized dataset from the multivariate distribution D
simu = rand(D,1000)

D̂ = fit(SklarDist{FrankCopula,Tuple{Gamma,Normal,LogNormal}}, simu) # works
# But how can I specify that I want a mixture for one of the variables?

which, under the hood, calls fit_mle(Marginal_Type, marginal_data) on each marginal.
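
As a rough sketch of what the requested convention could look like (fit_mle_auto and its naive automatic guess are hypothetical, written for this thread only, and not part of either package):

using Distributions, ExpectationMaximization

# Hypothetical helper: take component *types*, build an automatic initial
# guess, then call the instance-based fit_mle that the package provides.
function fit_mle_auto(component_types, y; kwargs...)
    K = length(component_types)
    # Naive guess: fit each component type on the whole dataset.
    # Caveat: identical initial components stay identical under EM, so a real
    # implementation should break symmetry (e.g. quantile splits or k-means).
    guesses = [fit_mle(T, y) for T in component_types]
    mix_guess = MixtureModel(guesses, fill(1 / K, K))
    return fit_mle(mix_guess, y; kwargs...)
end

# Usage sketch: fit_mle_auto((Gamma, Gamma, Normal), data)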

update doc/readme with common pitfalls & errors

  • Bad initialization
  • Convergence failure #11 and #12
  • Undefined weighted method, like #9:
  [2] suffstats(::Type{Beta{Float64}}, ::Vector{Float64}, ::Vector{Float64})
    @ Distributions C:\Users\metivier\.julia\packages\Distributions\SUTV1\src\genericfit.jl:5
  [3] fit_mle(dt::Type{Beta{Float64}}, x::Vector{Float64}, w::Vector{Float64})

Fitting a beta mixture fails with no method found

Hi @dmetivie ,

first of all, thanks for this Julia package, and apologies if I make some very naive mistakes here, since this is one of my first times meddling with Julia ever...

I am trying to fit a mixture of two beta distributions, but it does not find suffstats in this case:

using Distributions
using ExpectationMaximization
using StatsPlots  # provides histogram

# Try to fit a beta mixture.
N = 50_000
α₁ = 10
β₁ = 5
α₂ = 5
β₂ = 10
π = 0.3

# Mixture Model of two betas.
mix_true = MixtureModel([Beta(α₁, β₁), Beta(α₂, β₂)], [π, 1 - π]) 

# Generate N samples from the mixture.
y = rand(mix_true, N)
histogram(y)

# Initial guess.
mix_guess = MixtureModel([Beta(1, 1), Beta(1, 1)], [0.5, 1 - 0.5])
test = rand(mix_guess, N)

# Fit the MLE with the EM algorithm:
mix_mle = fit_mle(mix_guess, y)
# ERROR: suffstats is not implemented for (Beta{Float64}, Vector{Float64}, Vector{Float64}).

My status:

(@v1.9) pkg> status
Status `~/.julia/environments/v1.9/Project.toml`
  [336ed68f] CSV v0.10.11
  [a93c6f00] DataFrames v1.6.1
  [31c24e10] Distributions v0.25.103
  [e1fe09cc] ExpectationMaximization v0.2.2
  [f3b207a7] StatsPlots v0.15.6
  [fce5fe82] Turing v0.29.3
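
The error arises because Distributions.jl does not define weighted suffstats (and hence no weighted fit_mle) for Beta, which the M-step needs. As a hedged workaround sketch, not an official fix from either package, one can supply a weighted estimator by hand, here using the method of moments as a stand-in for the true weighted MLE (note this commits mild type piracy, acceptable for experimentation):

using Distributions
import Distributions: fit_mle

# Weighted method-of-moments estimator for Beta, standing in for the
# unavailable weighted MLE. With weighted mean μ and weighted variance v:
# α = μ * t and β = (1 - μ) * t, where t = μ * (1 - μ) / v - 1.
function fit_mle(::Type{<:Beta}, x::AbstractVector{<:Real}, w::AbstractVector{<:Real})
    μ = sum(w .* x) / sum(w)
    v = sum(w .* (x .- μ) .^ 2) / sum(w)
    t = μ * (1 - μ) / v - 1
    return Beta(μ * t, (1 - μ) * t)
end

With this method defined, fit_mle(mix_guess, y) runs; since the moment estimator only approximates the MLE, each M-step is approximate as well.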

Handling dropouts

In cases of poor initialization, some components of the mixture may drop out. For example, let's create a 2-component mixture that is very poorly initialized:

julia> X = randn(10);

julia> mix = MixtureModel([Normal(100, 0.001), Normal(200, 0.001)], [0.5, 0.5]);

julia> logpdf.(components(mix), X')
2×10 Matrix{Float64}:
 -4.92479e9   -4.97741e9   -5.02964e9   -5.15501e9   -5.05792e9   …  -5.16391e9   -4.88617e9   -4.93348e9   -5.09162e9
 -1.98493e10  -1.99548e10  -2.00592e10  -2.03088e10  -2.01157e10  …  -2.03265e10  -1.97717e10  -1.98667e10  -2.01828e10

You can see that both have poor likelihood, but one of the two always loses by a very large margin. Then when we go to optimize,

julia> fit_mle(mix, X)
ERROR: DomainError with NaN:
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/dev/Distributions/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:36 [inlined]
  [5] fit_mle
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:229 [inlined]
  [6] fit_mle(::Type{Normal{Float64}}, x::Vector{Float64}, w::Vector{Float64}; mu::Float64, sigma::Float64)
    @ Distributions ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:256
  [7] fit_mle
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:253 [inlined]
  [8] fit_mle
    @ ~/.julia/dev/ExpectationMaximization/src/that_should_be_in_Distributions.jl:17 [inlined]
  [9] (::ExpectationMaximization.var"#2#3"{Vector{Normal{Float64}}, Vector{Float64}, Matrix{Float64}})(k::Int64)
    @ ExpectationMaximization ./none:0
 [10] iterate(::Base.Generator{Vector{Any}, DualNumbers.var"#1#3"})
    @ Base ./generator.jl:47 [inlined]
 [11] collect_to!(dest::AbstractArray{T}, itr::Any, offs::Any, st::Any) where T
    @ Base ./array.jl:890 [inlined]
 [12] collect_to_with_first!(dest::AbstractArray, v1::Any, itr::Any, st::Any)
    @ Base ./array.jl:868 [inlined]
 [13] collect(itr::Base.Generator{UnitRange{Int64}, ExpectationMaximization.var"#2#3"{Vector{…}, Vector{…}, Matrix{…}}})
    @ Base ./array.jl:842
 [14] fit_mle!(α::Vector{…}, dists::Vector{…}, y::Vector{…}, method::ClassicEM; display::Symbol, maxiter::Int64, atol::Float64, robust::Bool)
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/classic_em.jl:48
 [15] fit_mle!
    @ ~/.julia/dev/ExpectationMaximization/src/classic_em.jl:14 [inlined]
 [16] fit_mle(::MixtureModel{…}, ::Vector{…}; method::ClassicEM, display::Symbol, maxiter::Int64, atol::Float64, robust::Bool,
 infos::Bool)
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/fit_em.jl:30
 [17] fit_mle(::MixtureModel{Univariate, Continuous, Normal{Float64}, Categorical{Float64, Vector{Float64}}}, ::Vector{Float64})
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/fit_em.jl:12
 [18] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

This arises because α[:] = mean(γ, dims = 1) returns α = [1.0, 0.0]. In other words, component 2 of the mixture "drops out."

I've found errors like these, as well as positive-definiteness errors in a multivariate context, to be pretty ubiquitous when fitting complicated distributions and point-clouds. To me it seems we'd need to come up with some kind of guard against this behavior? But I'm not sure what the state-of-the-art approach is, or I'd implement it.
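
For reference, one common guard, sketched here as a generic pattern rather than the package's current behavior, is to compute the responsibilities with the log-sum-exp trick and floor them so that no component's total weight can collapse to exactly zero:

using Distributions
using LogExpFunctions  # provides logsumexp

# Hypothetical guarded E-step: dists are the components, α the mixing
# weights, x the observations. Each row of γ (one per observation) is
# floored and renormalized so no column can sum to exactly zero.
function guarded_responsibilities(dists, α, x::AbstractVector; γmin = 1e-10)
    N, K = length(x), length(dists)
    logγ = [log(α[k]) + logpdf(dists[k], x[n]) for n in 1:N, k in 1:K]
    γ = exp.(logγ .- logsumexp(logγ; dims = 2))  # numerically stable row normalization
    γ .= max.(γ, γmin)                           # guard against dropout
    return γ ./ sum(γ; dims = 2)
end

Other standard remedies are re-initializing a dropped component from the data, or placing a small Dirichlet (MAP) prior on the mixture weights.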
