diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e4240b49..975ccb9b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,7 +3,7 @@ on: push: branches: - main - tags: ['*'] + tags: ["*"] pull_request: concurrency: # Skip intermediate builds: always. @@ -12,16 +12,19 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: - name: Julia ${{ matrix.version }} - ${{ github.event_name }} + name: Julia ${{ matrix.version }} - ${{ matrix.test_suite }} runs-on: ubuntu-latest strategy: fail-fast: false matrix: version: - - '1.9' - - '1' - os: - - ubuntu-latest + - "1.9" + - "1" + test_suite: + - "Standard" + - "HMMBase" + env: + JULIA_HMM_TEST_SUITE: ${{ matrix.test_suite }} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -36,4 +39,4 @@ jobs: with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true \ No newline at end of file + fail_ci_if_error: true diff --git a/Project.toml b/Project.toml index 525b5d01..3d5d5c6c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HiddenMarkovModels" uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47" authors = ["Guillaume Dalle"] -version = "0.5.3" +version = "0.5.4" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/examples/autodiff.jl b/examples/autodiff.jl index a1d96d10..fe0e6732 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -11,6 +11,7 @@ using Enzyme: Enzyme using ForwardDiff: ForwardDiff using HiddenMarkovModels import HiddenMarkovModels as HMMs +using HMMTest #src using LinearAlgebra using Random: Random, AbstractRNG using StableRNGs diff --git a/examples/basics.jl b/examples/basics.jl index 3594bdab..d9d4357b 100644 --- a/examples/basics.jl +++ b/examples/basics.jl @@ -189,7 +189,7 @@ This is important to keep in mind when testing new models. In many applications, we have access to various observation sequences of different lengths. =# -nb_seqs = 300 +nb_seqs = 1000 long_obs_seqs = [last(rand(rng, hmm, rand(rng, 100:200))) for k in 1:nb_seqs]; typeof(long_obs_seqs) @@ -257,7 +257,8 @@ hcat(initialization(hmm_est_concat), initialization(hmm)) # ## Tests #src +@test startswith(string(hmm), "Hidden") #src +@test length.(values(rand(hmm, T))) == (T, T); #src control_seq = fill(nothing, last(seq_ends)); #src -test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/controlled.jl b/examples/controlled.jl index 1f2451b9..ffd0ab98 100644 --- a/examples/controlled.jl +++ b/examples/controlled.jl @@ -66,7 +66,7 @@ Simulation requires a vector of controls, each being a vector itself with the ri Let us build several sequences of variable lengths. =# -control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:100]; +control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000]; obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs]; obs_seq = reduce(vcat, obs_seqs) @@ -94,7 +94,7 @@ function StatsAPI.fit!( fb_storage::HMMs.ForwardBackwardStorage, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends, ) where {T} (; γ, ξ) = fb_storage N = length(hmm) @@ -151,5 +151,5 @@ hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2]) @test hmm_est.dist_coeffs[1] ≈ hmm.dist_coeffs[1] atol = 0.05 #src @test hmm_est.dist_coeffs[2] ≈ hmm.dist_coeffs[2] atol = 0.05 #src -test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.08, init=false) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/interfaces.jl b/examples/interfaces.jl index ba306051..4498478f 100644 --- a/examples/interfaces.jl +++ b/examples/interfaces.jl @@ -186,7 +186,7 @@ function StatsAPI.fit!( hmm::PriorHMM, fb_storage::HiddenMarkovModels.ForwardBackwardStorage, obs_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends, ) ## initialize to defaults without observations hmm.init .= 0 diff --git a/examples/temporal.jl b/examples/temporal.jl index 9c9549f4..1cad38f6 100644 --- a/examples/temporal.jl +++ b/examples/temporal.jl @@ -109,7 +109,7 @@ function StatsAPI.fit!( fb_storage::HMMs.ForwardBackwardStorage, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends, ) where {T} (; γ, ξ) = fb_storage L, N = period(hmm), length(hmm) @@ -183,6 +183,6 @@ map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2))) # ## Tests #src -@test mean(obs_seq[1:2:end]) < 0 < mean(obs_seq[2:2:end]) #src -test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.09, init=false) #src +@test mean(obs_seqs[1][1:2:end]) < 0 < mean(obs_seqs[1][2:2:end]) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/examples/types.jl b/examples/types.jl index 0945be63..32901c46 100644 --- a/examples/types.jl +++ b/examples/types.jl @@ -156,10 +156,9 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St @test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src -seq_ends = cumsum(rand(rng, 100:200, 100)); #src +seq_ends = cumsum(rand(rng, 100:200, 1000)); #src control_seq = fill(nothing, last(seq_ends)); #src -test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src -test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src +test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src # https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src diff --git a/ext/HiddenMarkovModelsDistributionsExt.jl b/ext/HiddenMarkovModelsDistributionsExt.jl index 7bb16d01..8581ce84 100644 --- a/ext/HiddenMarkovModelsDistributionsExt.jl +++ b/ext/HiddenMarkovModelsDistributionsExt.jl @@ -27,6 +27,10 @@ function HiddenMarkovModels.fit_in_sequence!( return dists[i] = fit(typeof(dists[i]), reduce(hcat, x_vecs), w) end +#= + +# Matrix distribution fitting not supported by Distributions.jl at the moment + function HiddenMarkovModels.fit_in_sequence!( dists::AbstractVector{<:MatrixDistribution}, i::Integer, @@ -37,5 +41,6 @@ function HiddenMarkovModels.fit_in_sequence!( end dcat(M1, M2) = cat(M1, M2; dims=3) +=# end diff --git a/libs/HMMTest/Project.toml b/libs/HMMTest/Project.toml index 5e576d83..90312674 100644 --- a/libs/HMMTest/Project.toml +++ b/libs/HMMTest/Project.toml @@ -5,9 +5,14 @@ version = "0.1.0" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[weakdeps] +HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7" + +[extensions] +HMMTestHMMBaseExt = "HMMBase" \ No newline at end of file diff --git a/libs/HMMTest/src/hmmbase.jl b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl similarity index 89% rename from libs/HMMTest/src/hmmbase.jl rename to libs/HMMTest/ext/HMMTestHMMBaseExt.jl index 808e2e0c..b13e7f64 100644 --- a/libs/HMMTest/src/hmmbase.jl +++ b/libs/HMMTest/ext/HMMTestHMMBaseExt.jl @@ -1,5 +1,14 @@ +module HMMTestHMMBaseExt -function test_identical_hmmbase( +using HiddenMarkovModels +import HiddenMarkovModels as HMMs +using HMMBase: HMMBase +using HMMTest +using Random: AbstractRNG +using Statistics: mean +using Test: @test, @testset, @test_broken + +function HMMTest.test_identical_hmmbase( rng::AbstractRNG, hmm::AbstractHMM, T::Integer; @@ -54,3 +63,5 @@ function test_identical_hmmbase( end end end + +end diff --git a/libs/HMMTest/src/HMMTest.jl b/libs/HMMTest/src/HMMTest.jl index 26951ceb..f342c238 100644 --- a/libs/HMMTest/src/HMMTest.jl +++ b/libs/HMMTest/src/HMMTest.jl @@ -2,13 +2,15 @@ module HMMTest using BenchmarkTools: @ballocated using HiddenMarkovModels +using HiddenMarkovModels: AbstractVectorOrNTuple import HiddenMarkovModels as HMMs -using HMMBase: HMMBase using JET: @test_opt, @test_call using Random: AbstractRNG using Statistics: mean using Test: @test, @testset, @test_broken +function test_identical_hmmbase end # in extension + export transpose_hmm export test_equal_hmms, test_coherent_algorithms export test_identical_hmmbase @@ -18,7 +20,6 @@ export test_type_stability include("utils.jl") include("coherence.jl") include("allocations.jl") -include("hmmbase.jl") include("jet.jl") end diff --git a/libs/HMMTest/src/allocations.jl b/libs/HMMTest/src/allocations.jl index ea3aeef9..38361975 100644 --- a/libs/HMMTest/src/allocations.jl +++ b/libs/HMMTest/src/allocations.jl @@ -3,7 +3,7 @@ function test_allocations( rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Allocations" begin diff --git a/libs/HMMTest/src/coherence.jl b/libs/HMMTest/src/coherence.jl index 2a39e336..a3384c0f 100644 --- a/libs/HMMTest/src/coherence.jl +++ b/libs/HMMTest/src/coherence.jl @@ -55,7 +55,7 @@ function test_coherent_algorithms( rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, atol::Real=0.05, init::Bool=true, diff --git a/libs/HMMTest/src/jet.jl b/libs/HMMTest/src/jet.jl index 75820193..d6d29f4f 100644 --- a/libs/HMMTest/src/jet.jl +++ b/libs/HMMTest/src/jet.jl @@ -3,7 +3,7 @@ function test_type_stability( rng::AbstractRNG, hmm::AbstractHMM, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, hmm_guess::Union{Nothing,AbstractHMM}=nothing, ) @testset "Type stability" begin diff --git a/src/HiddenMarkovModels.jl b/src/HiddenMarkovModels.jl index 2cfa2029..ecfd99b2 100644 --- a/src/HiddenMarkovModels.jl +++ b/src/HiddenMarkovModels.jl @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof using DocStringExtensions using FillArrays: Fill -using LinearAlgebra: Transpose, dot, ldiv!, lmul!, mul!, parent +using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent using Random: Random, AbstractRNG, default_rng using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals using StatsAPI: StatsAPI, fit, fit! diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index 1bc25665..279e6fa3 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -5,7 +5,7 @@ function baum_welch_has_converged( logL, logL_prev = logL_evolution[end], logL_evolution[end - 1] progress = logL - logL_prev if loglikelihood_increasing && progress < min(0, -atol) - error("Loglikelihood decreased in Baum-Welch") + error("Loglikelihood decreased from $logL_prev to $logL in Baum-Welch") elseif progress < atol return true end @@ -22,7 +22,7 @@ function baum_welch!( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, atol::Real, max_iterations::Integer, loglikelihood_increasing::Bool, @@ -55,7 +55,7 @@ function baum_welch( hmm_guess::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), atol=1e-5, max_iterations=100, loglikelihood_increasing=true, @@ -73,7 +73,7 @@ function baum_welch( seq_ends, atol, max_iterations, - loglikelihood_increasing=false, + loglikelihood_increasing, ) return hmm, logL_evolution end @@ -85,7 +85,7 @@ function StatsAPI.fit!( fb_storage::ForwardBackwardStorage, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) return fit!(hmm, fb_storage, obs_seq; seq_ends) end diff --git a/src/inference/chainrules.jl b/src/inference/chainrules.jl index 424236e1..8816120f 100644 --- a/src/inference/chainrules.jl +++ b/src/inference/chainrules.jl @@ -4,7 +4,7 @@ function _params_and_loglikelihoods( hmm::AbstractHMM, obs_seq::Vector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) init = initialization(hmm) trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t @@ -22,7 +22,7 @@ function ChainRulesCore.rrule( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) _, pullback = rrule_via_ad( rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends diff --git a/src/inference/forward.jl b/src/inference/forward.jl index 3be62097..11327e35 100644 --- a/src/inference/forward.jl +++ b/src/inference/forward.jl @@ -16,7 +16,34 @@ struct ForwardStorage{R} c::Vector{R} end -Base.eltype(::ForwardStorage{R}) where {R} = R +""" +$(TYPEDEF) + +# Fields + +Only the fields with a description are part of the public API. + +$(TYPEDFIELDS) +""" +struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}} + "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" + γ::Matrix{R} + "posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`" + ξ::Vector{M} + "one loglikelihood per observation sequence" + logL::Vector{R} + B::Matrix{R} + α::Matrix{R} + c::Vector{R} + β::Matrix{R} + Bβ::Matrix{R} +end + +Base.eltype(::ForwardBackwardStorage{R}) where {R} = R + +const ForwardOrForwardBackwardStorage{R} = Union{ + ForwardStorage{R},ForwardBackwardStorage{R} +} """ $(SIGNATURES) @@ -25,7 +52,7 @@ function initialize_forward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) @@ -40,7 +67,7 @@ end $(SIGNATURES) """ function forward!( - storage, + storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector, @@ -100,16 +127,23 @@ end $(SIGNATURES) """ function forward!( - storage, + storage::ForwardOrForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) (; logL) = storage - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward!(storage, hmm, obs_seq, control_seq, t1, t2;) + end end return nothing end @@ -125,7 +159,7 @@ function forward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends) forward!(storage, hmm, obs_seq, control_seq; seq_ends) diff --git a/src/inference/forward_backward.jl b/src/inference/forward_backward.jl index 63ad978b..2f64ab0d 100644 --- a/src/inference/forward_backward.jl +++ b/src/inference/forward_backward.jl @@ -1,28 +1,3 @@ -""" -$(TYPEDEF) - -# Fields - -Only the fields with a description are part of the public API. - -$(TYPEDFIELDS) -""" -struct ForwardBackwardStorage{R,M<:AbstractMatrix{R}} - "posterior state marginals `γ[i,t] = ℙ(X[t]=i | Y[1:T])`" - γ::Matrix{R} - "posterior transition marginals `ξ[t][i,j] = ℙ(X[t]=i, X[t+1]=j | Y[1:T])`" - ξ::Vector{M} - "one loglikelihood per observation sequence" - logL::Vector{R} - B::Matrix{R} - α::Matrix{R} - c::Vector{R} - β::Matrix{R} - Bβ::Matrix{R} -end - -Base.eltype(::ForwardBackwardStorage{R}) where {R} = R - """ $(SIGNATURES) """ @@ -30,7 +5,7 @@ function initialize_forward_backward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals=true, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) @@ -100,19 +75,28 @@ end $(SIGNATURES) """ function forward_backward!( - storage::ForwardBackwardStorage{R}, + storage::ForwardBackwardStorage, hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, transition_marginals::Bool=true, -) where {R} +) (; logL) = storage - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = forward_backward!( - storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals - ) + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward_backward!( + storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + ) + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = forward_backward!( + storage, hmm, obs_seq, control_seq, t1, t2; transition_marginals + ) + end end return nothing end @@ -128,7 +112,7 @@ function forward_backward( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) transition_marginals = false storage = initialize_forward_backward( diff --git a/src/inference/logdensity.jl b/src/inference/logdensity.jl index f43fb25c..ce153ff2 100644 --- a/src/inference/logdensity.jl +++ b/src/inference/logdensity.jl @@ -7,7 +7,7 @@ function DensityInterface.logdensityof( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) _, logL = forward(hmm, obs_seq, control_seq; seq_ends) return sum(logL) @@ -23,7 +23,7 @@ function joint_logdensityof( obs_seq::AbstractVector, state_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) R = eltype(hmm, obs_seq[1], control_seq[1]) logL = zero(R) diff --git a/src/inference/viterbi.jl b/src/inference/viterbi.jl index 09e18a26..f1758097 100644 --- a/src/inference/viterbi.jl +++ b/src/inference/viterbi.jl @@ -17,8 +17,6 @@ struct ViterbiStorage{R} ψ::Matrix{Int} end -Base.eltype(::ViterbiStorage{R}) where {R} = R - """ $(SIGNATURES) """ @@ -26,7 +24,7 @@ function initialize_viterbi( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) N, T, K = length(hmm), length(obs_seq), length(seq_ends) R = eltype(hmm, obs_seq[1], control_seq[1]) @@ -85,12 +83,19 @@ function viterbi!( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) where {R} (; logL) = storage - @threads for k in eachindex(seq_ends) - t1, t2 = seq_limits(seq_ends, k) - logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + if seq_ends isa NTuple + for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + end + else + @threads for k in eachindex(seq_ends) + t1, t2 = seq_limits(seq_ends, k) + logL[k] = viterbi!(storage, hmm, obs_seq, control_seq, t1, t2;) + end end return nothing end @@ -106,7 +111,7 @@ function viterbi( hmm::AbstractHMM, obs_seq::AbstractVector, control_seq::AbstractVector=Fill(nothing, length(obs_seq)); - seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1), + seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),), ) storage = initialize_viterbi(hmm, obs_seq, control_seq; seq_ends) viterbi!(storage, hmm, obs_seq, control_seq; seq_ends) diff --git a/src/types/abstract_hmm.jl b/src/types/abstract_hmm.jl index 6f3f8e53..c9f0c675 100644 --- a/src/types/abstract_hmm.jl +++ b/src/types/abstract_hmm.jl @@ -125,7 +125,7 @@ function obs_logdensities!( logb::AbstractVector{T}, hmm::AbstractHMM, obs, control ) where {T} dists = obs_distributions(hmm, control) - @inbounds @simd for i in eachindex(logb, dists) + @simd for i in eachindex(logb, dists) logb[i] = logdensityof(dists[i], obs) end @argcheck maximum(logb) < typemax(T) diff --git a/src/types/hmm.jl b/src/types/hmm.jl index 9de3cf2e..60d7bf12 100644 --- a/src/types/hmm.jl +++ b/src/types/hmm.jl @@ -38,10 +38,6 @@ struct HMM{ end end -function Base.copy(hmm::HMM) - return HMM(copy(hmm.init), copy(hmm.trans), copy(hmm.dists)) -end - function Base.show(io::IO, hmm::HMM) return print( io, @@ -61,7 +57,7 @@ function StatsAPI.fit!( hmm::HMM, fb_storage::ForwardBackwardStorage, obs_seq::AbstractVector; - seq_ends::AbstractVector{Int}, + seq_ends::AbstractVectorOrNTuple{Int}, ) (; γ, ξ) = fb_storage # Fit states @@ -69,13 +65,13 @@ function StatsAPI.fit!( t1, t2 = seq_limits(seq_ends, k) # use ξ[t2] as scratch space since it is zero anyway scratch = ξ[t2] - scratch .= zero(eltype(scratch)) + fill!(scratch, zero(eltype(scratch))) for t in t1:(t2 - 1) scratch .+= ξ[t] end end - hmm.init .= zero(eltype(hmm.init)) - hmm.trans .= zero(eltype(hmm.trans)) + fill!(hmm.init, zero(eltype(hmm.init))) + fill!(hmm.trans, zero(eltype(hmm.trans))) for k in eachindex(seq_ends) t1, t2 = seq_limits(seq_ends, k) hmm.init .+= view(γ, :, t1) diff --git a/src/utils/lightcategorical.jl b/src/utils/lightcategorical.jl index f16f46dc..605f3c28 100644 --- a/src/utils/lightcategorical.jl +++ b/src/utils/lightcategorical.jl @@ -52,8 +52,8 @@ function StatsAPI.fit!( ) where {T1} @argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p) w_tot = sum(w) - dist.p .= zero(T1) - @inbounds @simd for i in eachindex(x, w) + fill!(dist.p, zero(T1)) + @simd for i in eachindex(x, w) dist.p[x[i]] += w[i] end dist.p ./= w_tot diff --git a/src/utils/lightdiagnormal.jl b/src/utils/lightdiagnormal.jl index 8b0748d6..05851672 100644 --- a/src/utils/lightdiagnormal.jl +++ b/src/utils/lightdiagnormal.jl @@ -46,7 +46,7 @@ function DensityInterface.logdensityof( ) where {T1,T2,T3} l = zero(promote_type(T1, T2, T3, eltype(x))) l -= sum(dist.logσ) + log2π * length(x) / 2 - @inbounds @simd for i in eachindex(x, dist.μ, dist.σ) + @simd for i in eachindex(x, dist.μ, dist.σ) l -= abs2(x[i] - dist.μ[i]) / (2 * abs2(dist.σ[i])) end return l @@ -56,13 +56,13 @@ function StatsAPI.fit!( dist::LightDiagNormal{T1,T2}, x::AbstractVector{<:AbstractVector}, w::AbstractVector ) where {T1,T2} w_tot = sum(w) - dist.μ .= zero(T1) - dist.σ .= zero(T2) - @inbounds @simd for i in eachindex(x, w) - dist.μ .+= x[i] .* w[i] + fill!(dist.μ, zero(T1)) + fill!(dist.σ, zero(T2)) + @simd for i in eachindex(x, w) + axpy!(w[i], x[i], dist.μ) end dist.μ ./= w_tot - @inbounds @simd for i in eachindex(x, w) + @simd for i in eachindex(x, w) dist.σ .+= abs2.(x[i] .- dist.μ) .* w[i] end dist.σ .= sqrt.(dist.σ ./ w_tot) diff --git a/src/utils/limits.jl b/src/utils/limits.jl index cbd40e50..f06c7b0f 100644 --- a/src/utils/limits.jl +++ b/src/utils/limits.jl @@ -3,7 +3,7 @@ $(SIGNATURES) Return a tuple `(t1, t2)` giving the begin and end indices of subsequence `k` within a set of sequences ending at `seq_ends`. """ -function seq_limits(seq_ends::AbstractVector{Int}, k::Integer) +function seq_limits(seq_ends::AbstractVectorOrNTuple{Int}, k::Integer) if k == 1 return 1, seq_ends[k] else diff --git a/src/utils/linalg.jl b/src/utils/linalg.jl index 506b29e4..9ed23157 100644 --- a/src/utils/linalg.jl +++ b/src/utils/linalg.jl @@ -1,3 +1,5 @@ +const AbstractVectorOrNTuple{T} = Union{AbstractVector{T},NTuple{N,T}} where {N} + sum_to_one!(x) = ldiv!(sum(x), x) mynonzeros(x::AbstractArray) = x @@ -33,9 +35,9 @@ function mul_rows_cols!( Brv = rowvals(B) Bnz = nonzeros(B) Anz = nonzeros(A) - for j in axes(B, 2) + @simd for j in axes(B, 2) @argcheck nzrange(B, j) == nzrange(A, j) - for k in nzrange(B, j) + @simd for k in nzrange(B, j) i = Brv[k] Bnz[k] = l[i] * Anz[k] * r[j] end @@ -56,10 +58,10 @@ function argmaxplus_transmul!( ) where {R} @argcheck axes(A, 1) == eachindex(x) @argcheck axes(A, 2) == eachindex(y) - y .= typemin(R) - ind .= 0 - for j in axes(A, 2) - for i in axes(A, 1) + fill!(y, typemin(R)) + fill!(ind, 0) + @simd for j in axes(A, 2) + @simd for i in axes(A, 1) z = A[i, j] + x[i] if z > y[j] y[j] = z @@ -80,10 +82,10 @@ function argmaxplus_transmul!( @argcheck axes(A, 2) == eachindex(y) Anz = nonzeros(A) Arv = rowvals(A) - y .= typemin(R) - ind .= 0 - for j in axes(A, 2) - for k in nzrange(A, j) + fill!(y, typemin(R)) + fill!(ind, 0) + @simd for j in axes(A, 2) + @simd for k in nzrange(A, j) i = Arv[k] z = Anz[k] + x[i] if z > y[j] diff --git a/test/correctness.jl b/test/correctness.jl index cb139823..716e8552 100644 --- a/test/correctness.jl +++ b/test/correctness.jl @@ -9,11 +9,11 @@ using SparseArrays using StableRNGs using Test -rng = StableRNG(63) +TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard") ## Settings -T, K = 50, 200 +T, K = 100, 200 init = [0.4, 0.6] init_guess = [0.5, 0.5] @@ -29,26 +29,31 @@ p_guess = [[0.7, 0.3], [0.3, 0.7]] σ = ones(2) +rng = StableRNG(63) control_seqs = [fill(nothing, rand(rng, T:(2T))) for k in 1:K]; control_seq = reduce(vcat, control_seqs); seq_ends = cumsum(length.(control_seqs)); ## Uncontrolled -@testset "Normal" begin +@testset verbose = true "Normal" begin dists = [Normal(μ[1][1]), Normal(μ[2][1])] dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + rng = StableRNG(63) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end -@testset "DiagNormal" begin +@testset verbose = true "DiagNormal" begin dists = [MvNormal(μ[1], Diagonal(abs2.(σ))), MvNormal(μ[2], Diagonal(abs2.(σ)))] dists_guess = [ MvNormal(μ_guess[1], Diagonal(abs2.(σ))), MvNormal(μ_guess[2], Diagonal(abs2.(σ))) @@ -57,68 +62,90 @@ end hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + rng = StableRNG(63) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + end end -@testset "LightCategorical" begin +@testset verbose = true "LightCategorical" begin dists = [LightCategorical(p[1]), LightCategorical(p[2])] dists_guess = [LightCategorical(p_guess[1]), LightCategorical(p_guess[2])] hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + rng = StableRNG(63) + if TEST_SUITE != "HMMBase" + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end -@testset "LightDiagNormal" begin +@testset verbose = true "LightDiagNormal" begin dists = [LightDiagNormal(μ[1], σ), LightDiagNormal(μ[2], σ)] dists_guess = [LightDiagNormal(μ_guess[1], σ), LightDiagNormal(μ_guess[2], σ)] hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + rng = StableRNG(63) + if TEST_SUITE != "HMMBase" + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end -@testset "Normal (sparse)" begin +@testset verbose = true "Normal (sparse)" begin dists = [Normal(μ[1][1]), Normal(μ[2][1])] dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] hmm = HMM(init, sparse(trans), dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + rng = StableRNG(63) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + @test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end -@testset "Normal transposed" begin # issue 99 +@testset verbose = true "Normal transposed" begin # issue 99 dists = [Normal(μ[1][1]), Normal(μ[2][1])] dists_guess = [Normal(μ_guess[1][1]), Normal(μ_guess[2][1])] hmm = transpose_hmm(HMM(init, trans, dists)) hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess)) - test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) - test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) - test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + rng = StableRNG(63) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) + test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) + end end -@testset "Normal and Exponential" begin # issue 101 +@testset verbose = true "Normal and Exponential" begin # issue 101 dists = [Normal(μ[1][1]), Exponential(1.0)] dists_guess = [Normal(μ_guess[1][1]), Exponential(0.8)] hmm = HMM(init, trans, dists) hmm_guess = HMM(init_guess, trans_guess, dists_guess) - test_identical_hmmbase(rng, hmm, T; hmm_guess) - test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + rng = StableRNG(63) + if TEST_SUITE == "HMMBase" + test_identical_hmmbase(rng, hmm, T; hmm_guess) + else + test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) + end end diff --git a/test/distributions.jl b/test/distributions.jl index 544ba063..3c660c0d 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,5 +1,6 @@ using Distributions -using HiddenMarkovModels: LightCategorical, LightDiagNormal, logdensityof, rand_prob_vec +using HiddenMarkovModels: + LightCategorical, LightDiagNormal, logdensityof, rand_prob_vec, rand_trans_mat using LinearAlgebra using Statistics using StatsAPI: fit! @@ -8,6 +9,29 @@ using Test rng = StableRNG(63) +function test_randprobvec(p) + @test all(>=(0), p) + @test sum(p) ≈ 1 +end + +function test_randtransmat(A) + foreach(eachrow(A)) do p + test_randprobvec(p) + end +end + +@testset "Rand prob" begin + n = 10 + test_randprobvec(rand_prob_vec(n)) + test_randprobvec(rand_prob_vec(rng, n)) + test_randprobvec(rand_prob_vec(Float32, n)) + test_randprobvec(rand_prob_vec(rng, Float32, n)) + test_randtransmat(rand_trans_mat(n)) + test_randtransmat(rand_trans_mat(rng, n)) + test_randtransmat(rand_trans_mat(Float32, n)) + test_randtransmat(rand_trans_mat(rng, Float32, n)) +end + function test_fit_allocs(dist, x, w) dist_copy = deepcopy(dist) allocs = @allocated fit!(dist_copy, x, w) @@ -17,6 +41,7 @@ end @testset "LightCategorical" begin p = rand_prob_vec(rng, 10) dist = LightCategorical(p) + @test startswith(string(dist), "LightCategorical") x = [(@inferred rand(rng, dist)) for _ in 1:100_000] # Simulation val_count = zeros(Int, length(p)) @@ -25,7 +50,7 @@ end end @test val_count ./ length(x) ≈ p atol = 2e-2 # Fitting - dist_est = deepcopy(dist) + dist_est = LightCategorical(rand_prob_vec(rng, 10)) w = ones(length(x)) fit!(dist_est, x, w) @test dist_est.p ≈ p atol = 2e-2 @@ -38,12 +63,13 @@ end μ = randn(rng, 10) σ = rand(rng, 10) dist = LightDiagNormal(μ, σ) + @test startswith(string(dist), "LightDiagNormal") x = [(@inferred rand(rng, dist)) for _ in 1:100_000] # Simulation @test mean(x) ≈ μ atol = 2e-2 @test std(x) ≈ σ atol = 2e-2 # Fitting - dist_est = deepcopy(dist) + dist_est = LightDiagNormal(randn(rng, 10), rand(rng, 10)) w = ones(length(x)) fit!(dist_est, x, w) @test dist_est.μ ≈ μ atol = 2e-2 diff --git a/test/runtests.jl b/test/runtests.jl index 96b2f0fa..39a06cbd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,41 +6,51 @@ using JuliaFormatter: JuliaFormatter using Pkg using Test +TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard") +if TEST_SUITE == "HMMBase" + Pkg.add("HMMBase") + using HMMBase: HMMBase +end + Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest")) @testset verbose = true "HiddenMarkovModels.jl" begin - @testset "Code formatting" begin - @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) - end + if TEST_SUITE == "Standard" + @testset "Code formatting" begin + @test JuliaFormatter.format(HiddenMarkovModels; verbose=false, overwrite=false) + end - @testset "Code quality" begin - Aqua.test_all( - HiddenMarkovModels; ambiguities=false, deps_compat=(check_extras=false,) - ) - end + @testset "Code quality" begin + Aqua.test_all( + HiddenMarkovModels; ambiguities=false, deps_compat=(check_extras=false,) + ) + end - @testset "Code linting" begin - using Distributions - using Zygote - JET.test_package(HiddenMarkovModels; target_defined_modules=true) - end + @testset "Code linting" begin + using Distributions + using Zygote + if VERSION >= v"1.10" + JET.test_package(HiddenMarkovModels; target_defined_modules=true) + end + end - @testset "Distributions" begin - include("distributions.jl") - end + @testset "Distributions" begin + include("distributions.jl") + end - @testset "Correctness" begin - include("correctness.jl") - end + examples_path = joinpath(dirname(@__DIR__), "examples") + for file in readdir(examples_path) + @testset "Example - $file" begin + include(joinpath(examples_path, file)) + end + end - examples_path = joinpath(dirname(@__DIR__), "examples") - for file in readdir(examples_path) - @testset "Example - $file" begin - include(joinpath(examples_path, file)) + @testset "Doctests" begin + Documenter.doctest(HiddenMarkovModels) end end - @testset "Doctests" begin - Documenter.doctest(HiddenMarkovModels) + @testset verbose = true "Correctness - $TEST_SUITE" begin + include("correctness.jl") end end