Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Promote TimeDependentSum to Dual when using ForwardDiff #378

Merged
merged 10 commits into from
Jan 2, 2024
8 changes: 7 additions & 1 deletion src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ end

function schroedinger_dynamic(tspan, psi0::T, H::AbstractTimeDependentOperator;
kwargs...) where {B,Bp,T<:Union{AbstractOperator{B,Bp},StateVector{B}}}
schroedinger_dynamic(tspan, psi0, schroedinger_dynamic_function(H); kwargs...)
promoted_tspan, psi0 = _promote_time_and_state(psi0, H, tspan)
if promoted_tspan !== tspan # promote H
promoted_H = TimeDependentSum(H.coefficients, H.static_op.operators; init_time=first(promoted_tspan))
return schroedinger_dynamic(promoted_tspan, psi0, schroedinger_dynamic_function(promoted_H); kwargs...)
else
return schroedinger_dynamic(promoted_tspan, psi0, schroedinger_dynamic_function(H); kwargs...)
end
end

"""
Expand Down
23 changes: 23 additions & 0 deletions test/test_ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,26 @@ for u0 = (psi, psi', psi⊗psi') # test all methods of `rebuild`
end

end # testset

@testset "ForwardDiff with schroedinger using TimeDependentSum" begin

base=SpinBasis(1/2)
ψi = spinup(base)
ψt = spindown(base)
function Ftdop(q)
H = TimeDependentSum([q, abs2∘sinpi], [sigmaz(base), sigmax(base)])
_, ψf = timeevolution.schroedinger_dynamic(range(0,1,2), ψi, H)
abs2(ψt'last(ψf))
end
Ftdop(1.0)
@test ForwardDiff.derivative(Ftdop, 1.0) isa Any

function Ftdop(q)
H = TimeDependentSum([1, abs2∘sinpi], [sigmaz(base), q*sigmax(base)])
_, ψf = timeevolution.schroedinger_dynamic(range(0,1,2), ψi, H)
abs2(ψt'last(ψf))
end
Ftdop(1.0)
@test ForwardDiff.derivative(Ftdop, 1.0) isa Any

end # testset
Loading