
maxmouchet commented on May 24, 2024:

Hi, sorry for the delay.

A few comments:

  1. In the fit_mle call, you pass rand(model, 10, 241) as the data. This is a 2D array (10×241), but hmm is a univariate HMM; you probably want to pass data = rand(model, 10000) instead. (On a side note, I should raise a warning or an error in fit_mle when the data size is inconsistent.)

  2. In the mixture model, you have a Gaussian component with a zero variance (Normal(0.0, 0.0)). This can cause a problem in the ML estimation, since the density of the normal distribution is undefined when σ = 0 (in particular in [1]). One solution is to set robust = true to truncate non-finite values to small or large, but representable, floating-point numbers (see [2], and the sketch right after this list). Another, possibly cleaner, solution is to use a custom estimation function with a prior on the variance that prevents it from going to 0 (see [3]).
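
A minimal sketch of the first option, assuming the hmm and data variables defined in the complete example further down; robust is simply a keyword of fit_mle:

# robust = true truncates non-finite values to small or large, but
# representable, floating-point numbers during the EM iterations (see [2]).
fit_mle(hmm, data, display = :iter, robust = true)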

Aside from the two comments above, the main issue seems to be that log.(x) in Distributions.fit_mle(D::Type{LogNormal{T}}, x, w) is called with potentially negative values in x.

The w vector contains the probabilities of belonging (to the current component) for each observation. If all goes well, the values outside of the support of the log-normal distribution (R+) should be assigned a probability of belonging of zero.
A quick fix is to replace the observations that have a zero probability of belonging with some arbitrary positive number (they will not be used in the computation of the mean/std. deviation anyway):

function Distributions.fit_mle(D::Type{LogNormal{T}}, x::AbstractMatrix, w::AbstractVector) where {T}
    # Assign some dummy value to observations with a zero-probability of belonging to the current component.
    # This will prevent log.(x) from raising an error.
    x[w .== 0] .= 1.0
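    # Note: this mutates x in place; pass copy(x) first if the caller's data must be preserved.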
    # There is no constructor of LogNormal from Normal.
    # Let's do it by hand.
    d = fit_mle(Normal{T}, log.(x), w)
    LogNormal(d.μ, d.σ)
end

All in all:

using HMMBase #v1.0.6
using Distributions
using Random

# adapted from https://github.com/maxmouchet/HMMBase.jl/issues/25
# to fix suffstats error
function Distributions.fit_mle(D::Type{LogNormal{T}}, x::AbstractMatrix, w::AbstractVector) where {T}
    # Assign some dummy value to observations with a zero-probability of belonging to the current component.
    # This will prevent log.(x) from raising an error.
    x[w .== 0] .= 1.0
    # There is no constructor of LogNormal from Normal.
    # Let's do it by hand.
    d = fit_mle(Normal{T}, log.(x), w)
    LogNormal(d.μ, d.σ)
end

# Use a non-zero standard deviation for the first Normal component.
model = MixtureModel([Normal(0.0, 1.0), LogNormal(7.9, 0.49)], [0.65, 0.35])

Random.seed!(123)
data = rand(model, 10000)
# Sanity check, we use the same parameters as for the mixture model.
# We should find back similar parameters.
hmm = HMM([0.65 0.35; 0.35 0.65], [
        Normal(0.0, 1.0), 
        LogNormal(7.9, 0.49)
    ])

fit_mle(hmm, data, display=:iter)
# [...]
# Normal{Float64}(μ=0.9023259685700078, σ=0.4416361098258359)
# LogNormal{Float64}(μ=7.89683177821269, σ=0.4916590830787966)
# => Looks OK.
# Try some other parameters
hmm = HMM([0.65 0.35; 0.35 0.65], [
        Normal(2.0, 0.5), 
        LogNormal(1.0, 1.0)
    ])

fit_mle(hmm, data, display=:iter)
# [...]
# Normal{Float64}(μ=0.9023259685700078, σ=0.4416361098258359)
# LogNormal{Float64}(μ=7.89683177821269, σ=0.4916590830787966)
# => Looks OK.
# Some other parameters
hmm = HMM([0.65 0.35; 0.35 0.65], [
        Normal(10.0, 0.5), 
        LogNormal(1.0, 5.0)
    ])

fit_mle(hmm, data, display=:iter)
# CheckError: isprobvec(hmm.a) must hold. Got
# hmm.a => [NaN, NaN]

# Let's use a custom estimator with a prior on the variance for the Normal distribution.
# (See https://maxmouchet.github.io/HMMBase.jl/stable/examples/fit_map/)
import ConjugatePriors: InverseGamma, NormalKnownMu, posterior_canon
import StatsBase: Weights

function fit_map(::Type{<:Normal}, x, w)
    # Empirical mean
    μ = mean(x, Weights(w))

    # Prior, posterior, and mode of the variance
    ss = suffstats(NormalKnownMu(μ), x, w)
    prior = InverseGamma(2, 1)
    posterior = posterior_canon(prior, ss)
    σ2 = mode(posterior)

    Normal(μ, sqrt(σ2))
end
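
# Why the prior prevents σ → 0: the mode of InverseGamma(α, β) is β / (α + 1),
# which is strictly positive, so the MAP variance estimate cannot collapse to
# zero even when all the weight concentrates on a single observation.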

function fit_map(::Type{<:LogNormal}, x, w)
    # Assign some dummy value to observations with a zero-probability of belonging to the current component.
    # This will prevent log.(x) from raising an error.
    x[w .== 0] .= 1.0
    x = log.(x)

    # Empirical mean
    μ = mean(x, Weights(w))

    # Prior, posterior, and mode of the variance
    ss = suffstats(NormalKnownMu(μ), x, w)
    prior = InverseGamma(2, 1)
    posterior = posterior_canon(prior, ss)
    σ2 = mode(posterior)

    LogNormal(μ, sqrt(σ2))
end

fit_mle(hmm, data, display=:iter, estimator=fit_map)
# [...]
# Normal{Float64}(μ=0.9998807992982841, σ=0.027344693249784436)
# LogNormal{Float64}(μ=3.9545626227049397, σ=4.3562804894699365)
# => Converges, but we do not recover the original parameters.

I hope this helps! Let me know if something is not clear or doesn't work.

In the end, the ML estimation algorithm for the HMM (the Baum-Welch algorithm) is very sensitive to the initial parameters (since it only finds a local maximum of the likelihood function), so your best chance is to have reasonable estimates for these parameters; a multi-start sketch follows below. Otherwise you can resort to Bayesian inference (see https://github.com/TuringLang/Turing.jl), which is not affected by local maxima issues (but is much slower).
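
For reference, a minimal multi-start sketch: run the EM from several random initializations and keep the run with the best final log-likelihood. The fit_multistart helper is hypothetical, and it assumes that fit_mle returns the fitted model together with an EM history whose logtots field holds the per-iteration log-likelihoods (as in HMMBase v1.0):

# Hypothetical helper, not part of HMMBase: multi-start Baum-Welch.
function fit_multistart(data; n_starts = 10)
    best_hmm, best_ll = nothing, -Inf
    for _ in 1:n_starts
        # Crude random initialization; in practice, draw around empirical
        # quantiles of the data.
        init = HMM([0.5 0.5; 0.5 0.5], [
            Normal(randn(), 1.0),
            LogNormal(randn(), 1.0),
        ])
        try
            h, hist = fit_mle(init, data, estimator = fit_map)
            if hist.logtots[end] > best_ll
                best_hmm, best_ll = h, hist.logtots[end]
            end
        catch
            # Some starts may fail to converge (e.g. NaN transition
            # probabilities, as above); skip them.
        end
    end
    best_hmm
end

fit_multistart(data)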

[1] https://github.com/maxmouchet/HMMBase.jl/blob/master/src/mle.jl#L128
[2] https://maxmouchet.github.io/HMMBase.jl/stable/examples/numerical_stability/
[3] https://maxmouchet.github.io/HMMBase.jl/stable/examples/fit_map/
