Hi, sorry for the delay.
A few comments:
- In the `fit_mle` call, you pass `rand(model, 10, 241)` as the data. This is a 2D array (10x241), but `hmm` is a univariate HMM. You probably want to pass `data = rand(model, 10000)` instead. (On a side note, I should raise a warning/error in `fit_mle` if the data size is inconsistent.)
- In the mixture model, you have a Gaussian component with zero variance (`Normal(0.0, 0.0)`). This can cause a problem in the ML estimation, since the density of the normal distribution is undefined when σ = 0 (in particular in [1]). One solution is to set `robust = true` to truncate non-finite values to small or large, but representable, floating-point numbers (see [2]). Another, possibly cleaner, solution is to use a custom estimation function with a prior on the variance that prevents it from going to 0 (see [3]).
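Regarding the first comment, the shape mismatch is easy to check directly (a minimal sketch using only Distributions; the mixture parameters are the ones from the issue):

```julia
using Distributions

model = MixtureModel([Normal(0.0, 1.0), LogNormal(7.9, 0.49)], [0.65, 0.35])

# rand(model, n) draws a length-n Vector: a single observation sequence,
# which is the shape a univariate HMM expects.
data = rand(model, 10000)
println(size(data))  # (10000,)

# rand(model, 10, 241) draws a 10x241 Matrix instead, which cannot be
# interpreted as one univariate observation sequence.
bad = rand(model, 10, 241)
println(size(bad))  # (10, 241)
```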
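And to see why the second comment matters: with `Normal(0.0, 0.0)` the log-density is non-finite everywhere, so any likelihood computation involving that component breaks down. A small check (whether the non-finite values come out as ±Inf or NaN may depend on the Distributions version):

```julia
using Distributions

# Degenerate component from the original model: zero standard deviation.
d = Normal(0.0, 0.0)

# The log-density is non-finite both at the mean and away from it,
# so the EM responsibilities computed from it blow up.
println(logpdf(d, 0.0))  # non-finite
println(logpdf(d, 1.0))  # non-finite
```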
Aside from the two comments above, the main issue seems to be that `log.(x)` in `Distributions.fit_mle(D::Type{LogNormal{T}}, x, w)` is called with potentially negative values in `x`.
The `w` vector contains the probabilities of belonging (to the current component) for each observation. If all goes well, the values outside the support of the log-normal distribution (R+) should be assigned a probability of belonging of zero.
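Concretely, the failure comes from Julia's `log`, which throws a `DomainError` on negative reals instead of returning NaN (a minimal reproduction):

```julia
# log of a negative real throws rather than returning NaN, which is
# what aborts the estimation when x contains negative observations.
# (log(0.0) returns -Inf, which is also problematic, but does not throw.)
err = try
    log(-1.0)
catch e
    e
end
println(typeof(err))  # DomainError
```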
A quick fix is to replace the observations that have a zero probability of belonging with some arbitrary positive number (they will not be used in the computation of the mean/std. deviation anyway):
```julia
function Distributions.fit_mle(D::Type{LogNormal{T}}, x::AbstractMatrix, w::AbstractVector) where {T}
    # Assign a dummy value to observations with a zero probability of
    # belonging to the current component.
    # This will prevent log.(x) from raising an error.
    x[w .== 0] .= 1.0
    # There is no constructor of LogNormal from Normal,
    # so let's do it by hand.
    d = fit_mle(Normal{T}, log.(x), w)
    LogNormal(d.μ, d.σ)
end
```
All in all:
```julia
using HMMBase # v1.0.6
using Distributions
using Random

# Adapted from https://github.com/maxmouchet/HMMBase.jl/issues/25
# to fix the suffstats error.
function Distributions.fit_mle(D::Type{LogNormal{T}}, x::AbstractMatrix, w::AbstractVector) where {T}
    # Assign a dummy value to observations with a zero probability of
    # belonging to the current component.
    # This will prevent log.(x) from raising an error.
    x[w .== 0] .= 1.0
    # There is no constructor of LogNormal from Normal,
    # so let's do it by hand.
    d = fit_mle(Normal{T}, log.(x), w)
    LogNormal(d.μ, d.σ)
end

# Use a non-zero standard deviation for the first Normal component.
model = MixtureModel([Normal(0.0, 1.0), LogNormal(7.9, 0.49)], [0.65, 0.35])

Random.seed!(123)
data = rand(model, 10000)

# Sanity check: we use the same parameters as for the mixture model,
# so we should find back similar parameters.
hmm = HMM([0.65 0.35; 0.35 0.65], [
    Normal(0.0, 1.0),
    LogNormal(7.9, 0.49)
])
fit_mle(hmm, data, display=:iter)
# [...]
# Normal{Float64}(μ=0.9023259685700078, σ=0.4416361098258359)
# LogNormal{Float64}(μ=7.89683177821269, σ=0.4916590830787966)
# => Looks OK.

# Try some other parameters.
hmm = HMM([0.65 0.35; 0.35 0.65], [
    Normal(2.0, 0.5),
    LogNormal(1.0, 1.0)
])
fit_mle(hmm, data, display=:iter)
# [...]
# Normal{Float64}(μ=0.9023259685700078, σ=0.4416361098258359)
# LogNormal{Float64}(μ=7.89683177821269, σ=0.4916590830787966)
# => Looks OK.

# Some other parameters.
hmm = HMM([0.65 0.35; 0.35 0.65], [
    Normal(10.0, 0.5),
    LogNormal(1.0, 5.0)
])
fit_mle(hmm, data, display=:iter)
# CheckError: isprobvec(hmm.a) must hold. Got
# hmm.a => [NaN, NaN]

# Let's use a custom estimator with a prior on the variance for the
# Normal distribution.
# (See https://maxmouchet.github.io/HMMBase.jl/stable/examples/fit_map/)
import ConjugatePriors: InverseGamma, NormalKnownMu, posterior_canon
import StatsBase: Weights

function fit_map(::Type{<:Normal}, x, w)
    # Empirical mean
    μ = mean(x, Weights(w))
    # Prior, posterior, and mode of the variance
    ss = suffstats(NormalKnownMu(μ), x, w)
    prior = InverseGamma(2, 1)
    posterior = posterior_canon(prior, ss)
    σ2 = mode(posterior)
    Normal(μ, sqrt(σ2))
end

function fit_map(::Type{<:LogNormal}, x, w)
    # Assign a dummy value to observations with a zero probability of
    # belonging to the current component, so that log.(x) does not raise.
    x[w .== 0] .= 1.0
    x = log.(x)
    # Empirical mean
    μ = mean(x, Weights(w))
    # Prior, posterior, and mode of the variance
    ss = suffstats(NormalKnownMu(μ), x, w)
    prior = InverseGamma(2, 1)
    posterior = posterior_canon(prior, ss)
    σ2 = mode(posterior)
    LogNormal(μ, sqrt(σ2))
end

fit_mle(hmm, data, display=:iter, estimator=fit_map)
# [...]
# Normal{Float64}(μ=0.9998807992982841, σ=0.027344693249784436)
# LogNormal{Float64}(μ=3.9545626227049397, σ=4.3562804894699365)
# => Converges, but we do not find back the original parameters.
```
I hope this helps! Let me know if something is not clear, or doesn't work.
In the end, the ML estimation algorithm for the HMM (the Baum-Welch algorithm) is very sensitive to the initial parameters (since it only finds a local maximum of the likelihood function), so your best chance is to have reasonable estimates for these parameters. Otherwise you can resort to Bayesian inference (see https://github.com/TuringLang/Turing.jl), which is not affected by local-maxima issues (but is much slower).
[1] https://github.com/maxmouchet/HMMBase.jl/blob/master/src/mle.jl#L128
[2] https://maxmouchet.github.io/HMMBase.jl/stable/examples/numerical_stability/
[3] https://maxmouchet.github.io/HMMBase.jl/stable/examples/fit_map/
from hmmbase.jl.