Skip to content

Commit

Permalink
Indexing fix (#200)
Browse files Browse the repository at this point in the history
* fix test

* fix issue 198

* change_index(SingleSum)

* fix issue 189 - sum simplification

---------

Co-authored-by: christoph <[email protected]>
  • Loading branch information
ChristophHotter and christoph authored Mar 21, 2024
1 parent 05d8536 commit 302d849
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 7 deletions.
20 changes: 19 additions & 1 deletion src/index_average.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,12 @@ function insert_index(term::BasicSymbolic{<:CNumber},ind::Index,value::Int64)
elseif op === +
return sum(insert_index(arg,ind,value) for arg in arguments(term))
elseif op === ^
return insert_index(arguments(term)[1],ind,value)^(arguments(term)[2])
return insert_index(arguments(term)[1],ind,value)^insert_index(arguments(term)[2],ind,value)
# issue 198
elseif op === /
return insert_index(arguments(term)[1],ind,value)/insert_index(arguments(term)[2],ind,value)
elseif length(arguments(term)) == 1 # exp, sin, cos, ln, ... #TODO: write tests
return op(insert_index(arguments(term)[1],ind,value))
end
end
return term
Expand Down Expand Up @@ -745,6 +750,19 @@ function eval_term(term::BasicSymbolic{<:CNumber};kwargs...)
if op === *
return prod(eval_term(arg;kwargs...) for arg in arguments(term))
end
# issue 198 #TODO: tests
if op === ^
args = arguments(term)
return eval_term(args[1];kwargs...)^eval_term(args[2];kwargs...)
end
if op === /
args = arguments(term)
return eval_term(args[1];kwargs...)/eval_term(args[2];kwargs...)
end

if length(arguments(term)) == 1 # exp, sin, cos, ln, ...
return op(eval_term(arguments(term)[1];kwargs...))
end
end
return term
end
Expand Down
23 changes: 21 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ struct IndexedOperator <: QSym
end
end

const Summable = Union{<:QNumber,<:CNumber,<:BasicSymbolic{IndexedVariable},<:BasicSymbolic{DoubleIndexedVariable}}
# const Summable = Union{<:QNumber,<:CNumber,<:BasicSymbolic{IndexedVariable},<:BasicSymbolic{DoubleIndexedVariable}}
const Summable = Union{<:QNumber,<:CNumber,<:BasicSymbolic{IndexedVariable},<:BasicSymbolic{DoubleIndexedVariable},<:BasicSymbolic{CNumber}}

"""
SingleSum <: QTerm
Expand Down Expand Up @@ -185,6 +186,8 @@ function SingleSum(term::IndexedAdd, sum_index, non_equal_indices;metadata=NO_ME
args = arguments(term)
if op === +
return sum([SingleSum(arg,sum_index,non_equal_indices;metadata=NO_METADATA) for arg in args])
elseif (op === *) && (sum_index get_indices(term)) #issue 188
return SingleSum(term,sum_index,non_equal_indices,metadata)
else
return (sum_index.range - length(non_equal_indices))*term
end
Expand Down Expand Up @@ -661,11 +664,27 @@ function change_index(term::BasicSymbolic{<:CNumber},from::Index,to::Index)
end
if op === ^
args = arguments(term)
return change_index(args[1],from,to)^args[2]
return change_index(args[1],from,to)^change_index(args[2],from,to)
end
# issue 198
if op === /
args = arguments(term)
return change_index(args[1],from,to)/change_index(args[2],from,to)
end
if length(arguments(term)) == 1 # exp, sin, cos, ln, ...
return op(change_index(arguments(term)[1],from,to))
end
end
return term
end
# issue 196: TODO:test
function change_index(S::SingleSum, i::Index, j::Index)
(j S.non_equal_indices) && error("Index $(j) is in the non-equal index list.")
if S.sum_index == i
return SingleSum(change_index(S.term,i,j), j, replace(S.non_equal_indices, i=>j), S.metadata)
end
return S
end
change_index(x,from::Index,to::Index) = x

ismergeable(a::IndexedOperator,b::IndexedOperator) = isequal(a.ind,b.ind) ? ismergeable(a.op,b.op) : false
Expand Down
9 changes: 8 additions & 1 deletion test/test_index_basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ k_ind = indT(:k)
@test(isequal(change_index(Γij,j_ind,k_ind), DoubleIndexedVariable(,i_ind,k_ind)))
@test(isequal(change_index(σ(1,2,j_ind)*σ(1,2,i_ind),j_ind,i_ind),0))
@test(isequal(change_index(g(k_ind),k_ind,j_ind),g(j_ind)))
@test isequal(change_index((2g(i_ind),i_ind), i_ind, j_ind), (2g(j_ind),j_ind))

@test(isequal(
order_by_index(σ(1,2,k_ind)*σ(1,2,j_ind)*σ(1,2,i_ind),[i_ind]), σ(1,2,i_ind)*σ(1,2,k_ind)*σ(1,2,j_ind)
Expand Down Expand Up @@ -267,7 +268,7 @@ i.hilb
hc = NLevelSpace(:cavity, 3)
ha = NLevelSpace(:atom,2)
h = hc ha
@cnumbers N
@cnumbers N α
i = Index(h,:i,N,ha)
S(x,y) = Transition(h,:S,x,y,1)
σ(x,y,k) = IndexedOperator(Transition(h,,x,y,2),k)
Expand All @@ -292,5 +293,11 @@ arr = qc.create_index_arrays([i],[1:10])
@test isequal(qc._inconj(average(σ(2,1,1)*σ(2,2,2)*σ(1,2,1))),(average(σ(2,2,1)*σ(2,2,2))))
@test qc.ismergeable(σ(2,1,5),σ(1,2,5))

# issue 188
gi = IndexedVariable(:g, i)
@test isa((5gi,i), SingleSum)
@test isa((gi*α,i), SingleSum)
@test isequal((α,i), N*α)
@test isequal((5α,i), 5*N*α)

end
2 changes: 1 addition & 1 deletion test/test_indexed_meanfield.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ sol_ss = solve(prob, Tsit5(), save_everystep=false, save_on=false, save_start=fa

@test length(eqs_4) == length(eqs)

@test get_solution(sol_ss, 2a + σ(1,1,1))[1] == (2*sol_ss[a][1] + 1- sol_ss[σ(2,2,1)][1]) == get_solution(sol_ss, average(2a + σ(1,1,1)))[1]
@test get_solution(sol_ss, 2a + σ(1,1,1))[1] (2*sol_ss[a][1] + 1- sol_ss[σ(2,2,1)][1]) get_solution(sol_ss, average(2a + σ(1,1,1)))[1]
@test isequal(σ(1,1,2), 1-σ(2,2,2))

order = 1
Expand Down
4 changes: 2 additions & 2 deletions test/test_spin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ sx(1) == σx(1)
@test isequal(σz(2)*σz(1), σz(1)*σz(2))

@cnumbers J
Δ(i) = cnumber(Symbol(:Δ_,i))
H = Δ(1)*σz(1) + Δ(2)*σz(2) + J*σx(1)*σx(2)
Δi(i) = cnumber(Symbol(:Δ_,i))
H = Δi(1)*σz(1) + Δi(2)*σz(2) + J*σx(1)*σx(2)

ops = [σz(1)]
eqs = meanfield(ops, H)
Expand Down

0 comments on commit 302d849

Please sign in to comment.