Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into pr/THargreaves/105
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 1, 2024
2 parents 82fd465 + 882f2f7 commit 26dbff2
Show file tree
Hide file tree
Showing 31 changed files with 287 additions and 177 deletions.
17 changes: 10 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on:
push:
branches:
- main
tags: ['*']
tags: ["*"]
pull_request:
concurrency:
# Skip intermediate builds: always.
Expand All @@ -12,16 +12,19 @@ concurrency:
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ github.event_name }}
name: Julia ${{ matrix.version }} - ${{ matrix.test_suite }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version:
- '1.9'
- '1'
os:
- ubuntu-latest
- "1.9"
- "1"
test_suite:
- "Standard"
- "HMMBase"
env:
JULIA_HMM_TEST_SUITE: ${{ matrix.test_suite }}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand All @@ -36,4 +39,4 @@ jobs:
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: true
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.5.3"
version = "0.5.4"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
1 change: 1 addition & 0 deletions examples/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Enzyme: Enzyme
using ForwardDiff: ForwardDiff
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using HMMTest #src
using LinearAlgebra
using Random: Random, AbstractRNG
using StableRNGs
Expand Down
5 changes: 3 additions & 2 deletions examples/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ This is important to keep in mind when testing new models.
In many applications, we have access to various observation sequences of different lengths.
=#

nb_seqs = 300
nb_seqs = 1000
long_obs_seqs = [last(rand(rng, hmm, rand(rng, 100:200))) for k in 1:nb_seqs];
typeof(long_obs_seqs)

Expand Down Expand Up @@ -257,7 +257,8 @@ hcat(initialization(hmm_est_concat), initialization(hmm))

# ## Tests #src

@test startswith(string(hmm), "Hidden") #src
@test length.(values(rand(hmm, T))) == (T, T); #src
control_seq = fill(nothing, last(seq_ends)); #src
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
6 changes: 3 additions & 3 deletions examples/controlled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Simulation requires a vector of controls, each being a vector itself with the ri
Let us build several sequences of variable lengths.
=#

control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:100];
control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000];
obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];

obs_seq = reduce(vcat, obs_seqs)
Expand Down Expand Up @@ -94,7 +94,7 @@ function StatsAPI.fit!(
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
N = length(hmm)
Expand Down Expand Up @@ -151,5 +151,5 @@ hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2])

@test hmm_est.dist_coeffs[1] hmm.dist_coeffs[1] atol = 0.05 #src
@test hmm_est.dist_coeffs[2] hmm.dist_coeffs[2] atol = 0.05 #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.08, init=false) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
2 changes: 1 addition & 1 deletion examples/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ function StatsAPI.fit!(
hmm::PriorHMM,
fb_storage::HiddenMarkovModels.ForwardBackwardStorage,
obs_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
)
## initialize to defaults without observations
hmm.init .= 0
Expand Down
6 changes: 3 additions & 3 deletions examples/temporal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function StatsAPI.fit!(
fb_storage::HMMs.ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends,
) where {T}
(; γ, ξ) = fb_storage
L, N = period(hmm), length(hmm)
Expand Down Expand Up @@ -183,6 +183,6 @@ map(mean, hcat(obs_distributions(hmm_est, 2), obs_distributions(hmm, 2)))

# ## Tests #src

@test mean(obs_seq[1:2:end]) < 0 < mean(obs_seq[2:2:end]) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, atol=0.09, init=false) #src
@test mean(obs_seqs[1][1:2:end]) < 0 < mean(obs_seqs[1][2:2:end]) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
5 changes: 2 additions & 3 deletions examples/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,9 @@ Another useful array type is [StaticArrays.jl](https://github.com/JuliaArrays/St

@test nnz(log_transition_matrix(hmm)) == nnz(transition_matrix(hmm)) #src

seq_ends = cumsum(rand(rng, 100:200, 100)); #src
seq_ends = cumsum(rand(rng, 100:200, 1000)); #src
control_seq = fill(nothing, last(seq_ends)); #src
test_identical_hmmbase(rng, hmm, 100; hmm_guess) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false, atol=0.08) #src
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false) #src
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess) #src
# https://github.com/JuliaSparse/SparseArrays.jl/issues/469 #src
@test_skip test_allocations(rng, hmm, control_seq; seq_ends, hmm_guess) #src
5 changes: 5 additions & 0 deletions ext/HiddenMarkovModelsDistributionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ function HiddenMarkovModels.fit_in_sequence!(
return dists[i] = fit(typeof(dists[i]), reduce(hcat, x_vecs), w)
end

#=
# Matrix distribution fitting not supported by Distributions.jl at the moment
function HiddenMarkovModels.fit_in_sequence!(
dists::AbstractVector{<:MatrixDistribution},
i::Integer,
Expand All @@ -37,5 +41,6 @@ function HiddenMarkovModels.fit_in_sequence!(
end
dcat(M1, M2) = cat(M1, M2; dims=3)
=#

end
7 changes: 6 additions & 1 deletion libs/HMMTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ version = "0.1.0"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
HMMBase = "b2b3ca75-8444-5ffa-85e6-af70e2b64fe7"

[extensions]
HMMTestHMMBaseExt = "HMMBase"
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
module HMMTestHMMBaseExt

function test_identical_hmmbase(
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using HMMBase: HMMBase
using HMMTest
using Random: AbstractRNG
using Statistics: mean
using Test: @test, @testset, @test_broken

function HMMTest.test_identical_hmmbase(
rng::AbstractRNG,
hmm::AbstractHMM,
T::Integer;
Expand Down Expand Up @@ -54,3 +63,5 @@ function test_identical_hmmbase(
end
end
end

end
5 changes: 3 additions & 2 deletions libs/HMMTest/src/HMMTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ module HMMTest

using BenchmarkTools: @ballocated
using HiddenMarkovModels
using HiddenMarkovModels: AbstractVectorOrNTuple
import HiddenMarkovModels as HMMs
using HMMBase: HMMBase
using JET: @test_opt, @test_call
using Random: AbstractRNG
using Statistics: mean
using Test: @test, @testset, @test_broken

function test_identical_hmmbase end # in extension

export transpose_hmm
export test_equal_hmms, test_coherent_algorithms
export test_identical_hmmbase
Expand All @@ -18,7 +20,6 @@ export test_type_stability
include("utils.jl")
include("coherence.jl")
include("allocations.jl")
include("hmmbase.jl")
include("jet.jl")

end
2 changes: 1 addition & 1 deletion libs/HMMTest/src/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function test_allocations(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Allocations" begin
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/coherence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function test_coherent_algorithms(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
atol::Real=0.05,
init::Bool=true,
Expand Down
2 changes: 1 addition & 1 deletion libs/HMMTest/src/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function test_type_stability(
rng::AbstractRNG,
hmm::AbstractHMM,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
hmm_guess::Union{Nothing,AbstractHMM}=nothing,
)
@testset "Type stability" begin
Expand Down
2 changes: 1 addition & 1 deletion src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof
using DocStringExtensions
using FillArrays: Fill
using LinearAlgebra: Transpose, dot, ldiv!, lmul!, mul!, parent
using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent
using Random: Random, AbstractRNG, default_rng
using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals
using StatsAPI: StatsAPI, fit, fit!
Expand Down
10 changes: 5 additions & 5 deletions src/inference/baum_welch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function baum_welch_has_converged(
logL, logL_prev = logL_evolution[end], logL_evolution[end - 1]
progress = logL - logL_prev
if loglikelihood_increasing && progress < min(0, -atol)
error("Loglikelihood decreased in Baum-Welch")
error("Loglikelihood decreased from $logL_prev to $logL in Baum-Welch")
elseif progress < atol
return true
end
Expand All @@ -22,7 +22,7 @@ function baum_welch!(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
atol::Real,
max_iterations::Integer,
loglikelihood_increasing::Bool,
Expand Down Expand Up @@ -55,7 +55,7 @@ function baum_welch(
hmm_guess::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
atol=1e-5,
max_iterations=100,
loglikelihood_increasing=true,
Expand All @@ -73,7 +73,7 @@ function baum_welch(
seq_ends,
atol,
max_iterations,
loglikelihood_increasing=false,
loglikelihood_increasing,
)
return hmm, logL_evolution
end
Expand All @@ -85,7 +85,7 @@ function StatsAPI.fit!(
fb_storage::ForwardBackwardStorage,
obs_seq::AbstractVector,
control_seq::AbstractVector;
seq_ends::AbstractVector{Int},
seq_ends::AbstractVectorOrNTuple{Int},
)
return fit!(hmm, fb_storage, obs_seq; seq_ends)
end
4 changes: 2 additions & 2 deletions src/inference/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function _params_and_loglikelihoods(
hmm::AbstractHMM,
obs_seq::Vector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
init = initialization(hmm)
trans_by_time = mapreduce(_dcat, eachindex(control_seq)) do t
Expand All @@ -22,7 +22,7 @@ function ChainRulesCore.rrule(
hmm::AbstractHMM,
obs_seq::AbstractVector,
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
seq_ends::AbstractVector{Int}=Fill(length(obs_seq), 1),
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
)
_, pullback = rrule_via_ad(
rc, _params_and_loglikelihoods, hmm, obs_seq, control_seq; seq_ends
Expand Down
Loading

0 comments on commit 26dbff2

Please sign in to comment.