From d2eacac9454d81bc17ed2c2953bced57c807ec30 Mon Sep 17 00:00:00 2001 From: stecrotti Date: Sun, 22 Dec 2024 17:55:11 +0100 Subject: [PATCH] fix small bugs, add docs and notebook --- examples/ksat.ipynb | 28 ++++++++++++++-------------- src/Models/ksat.jl | 9 +++++---- test/Models/ksat.jl | 2 ++ 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/ksat.ipynb b/examples/ksat.ipynb index dd304a4..ddc3b27 100644 --- a/examples/ksat.ipynb +++ b/examples/ksat.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ "k = 3\n", "g = rand_regular_factor_graph(rng, n, m, k)\n", "ψ = [KSATClause(bitrand(rng, length(neighbors(g, factor(a))))) for a in factors(g)]\n", - "bp = BP(g, ψ, fill(2, nvariables(g)));" + "bp = fast_ksat_bp(g, ψ);" ] }, { @@ -51,13 +51,13 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "40" + "43" ] }, "metadata": {}, @@ -65,7 +65,7 @@ } ], "source": [ - "iters = iterate!(bp; maxiter=1000, tol=1e-12, rein=5e-2)" + "iters = iterate!(bp; maxiter=1000, tol=1e-14, rein=5e-2)" ] }, { @@ -78,14 +78,14 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "1×100 adjoint(::Vector{Int64}) with eltype Int64:\n", - " 0 0 0 1 0 0 1 0 1 0 0 0 1 … 0 0 1 0 0 0 0 1 1 0 0 1" + "1×1000 adjoint(::Vector{Int64}) with eltype Int64:\n", + " 0 0 0 1 0 0 1 0 1 1 0 0 1 … 0 0 0 0 0 0 0 1 0 0 0 1" ] }, "metadata": {}, @@ -106,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -118,7 +118,7 @@ } ], "source": [ - "nunsat = sum(1 - Int(bp.ψ[a](xstar[i] for i in neighbors(bp.g, factor(a)))) \n", + "nunsat = sum(!(Bool(bp.ψ[a](xstar[i]+1 for i in neighbors(bp.g, factor(a))))) \n", " for a in factors(bp.g))\n", "println(\"Number of unsatisfied clauses: $nunsat\")" ] @@ -126,15 +126,15 @@ ], "metadata": { "kernelspec": { - "display_name": "Julia 1.10.5", + "display_name": "Julia 1.11.2", "language": "julia", - "name": "julia-1.10" + "name": "julia-1.11" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", - "version": "1.10.5" + "version": "1.11.2" } }, "nbformat": 4, diff --git a/src/Models/ksat.jl b/src/Models/ksat.jl index c514a6b..255ae48 100644 --- a/src/Models/ksat.jl +++ b/src/Models/ksat.jl @@ -35,6 +35,7 @@ const BPKSAT = BP{<:KSATClause, <:BPFactor, <:NTuple{2,<:Real}, <:NTuple{2,<:Rea @doc raw""" fast_ksat_bp(g::AbstractFactorGraph, ψ::Vector{<:KSATClause}, [ϕ]) +Return a specialized BP instance with `KSATClause` and messages encoded as tuples of two reals instead of vectors. ``` """ function fast_ksat_bp(g::AbstractFactorGraph, ψ::Vector{<:KSATClause}, @@ -55,9 +56,9 @@ Base.eltype(bp::BPKSAT) = eltype(eltype(eltype(bp.b))) function BeliefPropagation.reset!(bp::BPKSAT) (; u, h, b) = bp T = eltype(bp) - u .= (T(0.5), T(0.5)) - h .= (T(0.5), T(0.5)) - b .= (T(0.5), T(0.5)) + u .= ((T(0.5), T(0.5)),) + h .= ((T(0.5), T(0.5)),) + b .= ((T(0.5), T(0.5)),) return nothing end function BeliefPropagation.randomize!(rng::AbstractRNG, bp::BPKSAT) @@ -68,7 +69,7 @@ function BeliefPropagation.randomize!(rng::AbstractRNG, bp::BPKSAT) u[ia] = (ru, 1-ru) h[ia] = (rh, 1-rh) end - b .= (T(0.5), T(0.5)) + b .= ((T(0.5), T(0.5)),) return nothing end diff --git a/test/Models/ksat.jl b/test/Models/ksat.jl index f5d2de0..a77bf22 100644 --- a/test/Models/ksat.jl +++ b/test/Models/ksat.jl @@ -39,6 +39,8 @@ end g = rand_regular_factor_graph(rng, n, m, k) ψ = [KSATClause(bitrand(rng, degree(g, factor(a)))) for a in factors(g)] bp = fast_ksat_bp(g, ψ) + reset!(bp) + randomize!(bp) iterate!(bp; maxiter=50, tol=1e-10) b = beliefs(bp) fb = factor_beliefs(bp)