From 1656f7d5624717cd90c3047d5da041afc64c157b Mon Sep 17 00:00:00 2001 From: Tim Hargreaves <38204689+THargreaves@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:36:50 +0100 Subject: [PATCH] Decompose forward function into initialize, predict, update (#105) * Decompose forward function into initialize, predict, update * Fixes * Remove unused predict --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- src/HiddenMarkovModels.jl | 1 + src/inference/forward.jl | 64 +++++++++++++++++++++------------------ src/inference/predict.jl | 10 ++++++ src/types/abstract_hmm.jl | 6 ++++ 4 files changed, 51 insertions(+), 30 deletions(-) create mode 100644 src/inference/predict.jl diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index ecfd99b..aa319be 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -38,6 +38,7 @@ include("utils/lightdiagnormal.jl") include("utils/lightcategorical.jl") include("utils/limits.jl") +include("inference/predict.jl") include("inference/forward.jl") include("inference/viterbi.jl") include("inference/forward_backward.jl") diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 27a2200..615c70a 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -63,6 +63,27 @@ function initialize_forward( return ForwardStorage(α, logL, B, c) end +function _forward_digest_observation!( + current_state_marginals::AbstractVector{<:Real}, + current_obs_likelihoods::AbstractVector{<:Real}, + hmm::AbstractHMM, + obs, + control, +) + a, b = current_state_marginals, current_obs_likelihoods + + obs_logdensities!(b, hmm, obs, control) + logm = maximum(b) + b .= exp.(b .- logm) + + a .*= b + c = inv(sum(a)) + lmul!(c, a) + + logL = -log(c) + logm + return c, logL +end + function _forward!( storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, @@ -73,36 +94,19 @@ function _forward!( ) (; α, B, c, logL) = storage t1, t2 = seq_limits(seq_ends, k) - - # Initialization - Bₜ₁ = view(B, :, t1) - obs_logdensities!(Bₜ₁, hmm, obs_seq[t1], control_seq[t1]) - logm = maximum(Bₜ₁) - Bₜ₁ .= exp.(Bₜ₁ .- logm) - - init = initialization(hmm) - αₜ₁ = view(α, :, t1) - αₜ₁ .= init .* Bₜ₁ - c[t1] = inv(sum(αₜ₁)) - lmul!(c[t1], αₜ₁) - - logL[k] = -log(c[t1]) + logm - - # Loop - for t in t1:(t2 - 1) - Bₜ₊₁ = view(B, :, t + 1) - obs_logdensities!(Bₜ₊₁, hmm, obs_seq[t + 1], control_seq[t + 1]) - logm = maximum(Bₜ₊₁) - Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm) - - trans = transition_matrix(hmm, control_seq[t]) - αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1) - mul!(αₜ₊₁, transpose(trans), αₜ) - αₜ₊₁ .*= Bₜ₊₁ - c[t + 1] = inv(sum(αₜ₊₁)) - lmul!(c[t + 1], αₜ₊₁) - - logL[k] += -log(c[t + 1]) + logm + logL[k] = zero(eltype(logL)) + for t in t1:t2 + αₜ = view(α, :, t) + Bₜ = view(B, :, t) + if t == t1 + copyto!(αₜ, initialization(hmm)) + else + αₜ₋₁ = view(α, :, t - 1) + predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1]) + end + cₜ, logLₜ = _forward_digest_observation!(αₜ, Bₜ, hmm, obs_seq[t], control_seq[t]) + c[t] = cₜ + logL[k] += logLₜ end @argcheck isfinite(logL[k]) diff --git a/src/inference/predict.jl b/src/inference/predict.jl new file mode 100644 index 0000000..7240af6 --- /dev/null +++ b/src/inference/predict.jl @@ -0,0 +1,10 @@ +function predict_next_state!( + next_state_marginals::AbstractVector{<:Real}, + hmm::AbstractHMM, + current_state_marginals::AbstractVector{<:Real}, + control=nothing, +) + trans = transition_matrix(hmm, control) + mul!(next_state_marginals, transpose(trans), current_state_marginals) + return next_state_marginals +end diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index c9f0c67..4fe1387 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -72,6 +72,9 @@ log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm)) transition_matrix(hmm, control) Return the matrix of state transition probabilities for `hmm` (possibly when `control` is applied). + +!!! note + When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ function transition_matrix end @@ -82,6 +85,9 @@ function transition_matrix end Return the matrix of state transition log-probabilities for `hmm` (possibly when `control` is applied). Falls back on `transition_matrix`. + +!!! note + When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`). """ function log_transition_matrix(hmm::AbstractHMM, control) return elementwise_log(transition_matrix(hmm, control))