From 04e8d12b33f72122e4f924633ebfc6fad8d2f6da Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:26:21 +0530 Subject: [PATCH 01/13] feat: allow build_explicit_function to generate param-only observed --- src/systems/diffeqs/odesystem.jl | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 5d1bae95ec..e28f1ece3b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -387,6 +387,7 @@ function build_explicit_observed_function(sys, ts; drop_expr = drop_expr, ps = full_parameters(sys), return_inplace = false, + param_only = false, op = Operator, throw = true) if (isscalar = symbolic_type(ts) !== NotSymbolic()) @@ -399,7 +400,16 @@ function build_explicit_observed_function(sys, ts; ivs = independent_variables(sys) dep_vars = scalarize(setdiff(vars, ivs)) - obs = observed(sys) + obs = param_only ? Equation[] : observed(sys) + if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing + # each subsystem is topologically sorted independently. We can append the + # equations to override the `lhs ~ 0` equations in `observed(sys)` + syss, _, continuous_id, _... = dss + for (i, subsys) in enumerate(syss) + i == continuous_id && continue + append!(obs, observed(subsys)) + end + end cs = collect_constants(obs) if !isempty(cs) > 0 @@ -407,8 +417,9 @@ function build_explicit_observed_function(sys, ts; obs = map(x -> x.lhs ~ substitute(x.rhs, cmap), obs) end - sts = Set(unknowns(sys)) - sts = union(sts, + sts = param_only ? Set() : Set(unknowns(sys)) + sts = param_only ? Set() : + union(sts, Set(arguments(st)[1] for st in sts if iscall(st) && operation(st) === getindex)) observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs)) @@ -420,7 +431,8 @@ function build_explicit_observed_function(sys, ts; Set(arguments(p)[1] for p in param_set_ns if iscall(p) && operation(p) === getindex)) namespaced_to_obs = Dict(unknowns(sys, x.lhs) => x.lhs for x in obs) - namespaced_to_sts = Dict(unknowns(sys, x) => x for x in unknowns(sys)) + namespaced_to_sts = param_only ? Dict() : + Dict(unknowns(sys, x) => x for x in unknowns(sys)) # FIXME: This is a rather rough estimate of dependencies. We assume # the expression depends on everything before the `maxidx`. @@ -485,11 +497,11 @@ function build_explicit_observed_function(sys, ts; end dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds) if inputs === nothing - args = [dvs, ps..., ivs...] + args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...] else inputs = unwrap.(inputs) ipts = DestructuredArgs(inputs, inbounds = !checkbounds) - args = [dvs, ipts, ps..., ivs...] + args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...] end pre = get_postprocess_fbody(sys) From fc1a4bddf5a3cca8438b7ecf3c21568bb65f731c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:27:00 +0530 Subject: [PATCH 02/13] fix: check hasname before using getname --- src/systems/abstractsystem.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 97d5dfc970..e02f6ead0a 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -461,11 +461,12 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return is_parameter(ic, sym) end - return any(isequal(sym), getname.(parameter_symbols(sys))) || + + named_parameters = [getname(sym) for sym in parameter_symbols(sys) if hasname(sym)] + return any(isequal(sym), named_parameters) || count(NAMESPACE_SEPARATOR, string(sym)) == 1 && count(isequal(sym), - Symbol.(nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, getname.(parameter_symbols(sys)))) == - 1 + Symbol.(nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, named_parameters)) == 1 end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) From 4d7f9748bea7fde94c940f2f742a46aaf554a111 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:27:52 +0530 Subject: [PATCH 03/13] refactor: continuous system is always last discrete subsystem --- src/systems/clock_inference.jl | 42 +++++++++++++++------------------- test/clock.jl | 7 +++--- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index dc1d612d73..7d9b3bc6ad 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -186,6 +186,13 @@ function split_system(ci::ClockInference{S}) where {S} end tss[id] = ts_i end + if continuous_id != 0 + tss[continuous_id], tss[end] = tss[end], tss[continuous_id] + inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] + id_to_clock[continuous_id], id_to_clock[end] = id_to_clock[end], + id_to_clock[continuous_id] + continuous_id = lastindex(tss) + end return tss, inputs, continuous_id, id_to_clock end @@ -270,25 +277,9 @@ function generate_discrete_affect( ], [], let_block) |> toexpr - if use_index_cache - cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] - disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] - else - cont_to_disc_idxs = (offset + 1):(offset += ni) - input_offset = offset - disc_range = (offset + 1):(offset += ns) - end - save_vec = Expr(:ref, :Float64) - if use_index_cache - for unk in unknowns(sys) - idx = parameter_index(osys, unk) - push!(save_vec.args, :($(parameter_values)(p, $idx))) - end - else - for i in 1:ns - push!(save_vec.args, :(p[$(input_offset + i)])) - end - end + cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input] + disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] + save_expr = :($(SciMLBase.save_discretes!)(integrator, $i)) empty_disc = isempty(disc_range) disc_init = if use_index_cache :(function (u, p, t) @@ -351,11 +342,14 @@ function generate_discrete_affect( # d2c comes last # @show t # @show "incoming", p - $( - if use_index_cache + result = c2d_obs(u, p..., t) + for (val, i) in zip(result, $cont_to_disc_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) + end + $(if !empty_disc quote - result = c2d_obs(integrator.u, p..., t) - for (val, i) in zip(result, $cont_to_disc_idxs) + disc(disc_unknowns, u, p..., t) + for (val, i) in zip(disc_unknowns, $disc_range) $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end end @@ -406,7 +400,7 @@ function generate_discrete_affect( end ) end) - sv = SavedValues(Float64, Vector{Float64}) + push!(affect_funs, affect!) push!(init_funs, disc_init) push!(svs, sv) diff --git a/test/clock.jl b/test/clock.jl index 86967365ad..69b7c30c50 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -64,10 +64,11 @@ By inference: ci, varmap = infer_clocks(sys) eqmap = ci.eq_domain -tss, inputs = ModelingToolkit.split_system(deepcopy(ci)) -sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) +tss, inputs, continuous_id = ModelingToolkit.split_system(deepcopy(ci)) +sss, = ModelingToolkit._structural_simplify!( + deepcopy(tss[continuous_id]), (inputs[continuous_id], ())) @test equations(sss) == [D(x) ~ u - x] -sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[2]), (inputs[2], ())) +sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) @test isempty(equations(sss)) d = Clock(t, dt) k = ShiftIndex(d) From bc6a1867ac3ac9726e7b7ec335edf999773a847e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:28:14 +0530 Subject: [PATCH 04/13] fix: unwrap in `vars` --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 5d6af76b77..acd7a8686d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -370,7 +370,9 @@ function vars(exprs::Symbolic; op = Differential) end vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op) vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op) -vars(exprs; op = Differential) = foldl((x, y) -> vars!(x, y; op = op), exprs; init = Set()) +function vars(exprs; op = Differential) + foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set()) +end vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) From bfc73a3085bb81bd4377c4d7aa68926eb1f4a986 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:30:17 +0530 Subject: [PATCH 05/13] refactor: store parameters from different clock partitions separately --- src/systems/abstractsystem.jl | 88 +++++++++++++-- src/systems/index_cache.jl | 184 +++++++++++++++++++++++++++----- src/systems/parameter_buffer.jl | 155 ++++++++++++++++++--------- 3 files changed, 341 insertions(+), 86 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index e02f6ead0a..368cab4c54 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -446,7 +446,7 @@ end function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return is_parameter(ic, sym) || + return sym isa ParameterIndex || is_parameter(ic, sym) || iscall(sym) && operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end @@ -472,11 +472,21 @@ end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return if (idx = parameter_index(ic, sym)) !== nothing - idx + return if sym isa ParameterIndex + sym + elseif (idx = parameter_index(ic, sym)) !== nothing + if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 + return nothing + else + idx + end elseif iscall(sym) && operation(sym) === getindex && (idx = parameter_index(ic, first(arguments(sym)))) !== nothing - ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) + if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == nothing + return nothing + else + ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) + end else nothing end @@ -494,7 +504,12 @@ end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - return parameter_index(ic, sym) + idx = parameter_index(ic, sym) + if idx === nothing || idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 + return nothing + else + return idx + end end idx = findfirst(isequal(sym), getname.(parameter_symbols(sys))) if idx !== nothing @@ -507,6 +522,67 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym return nothing end +function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym) + has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false + is_timeseries_parameter(ic, sym) +end + +function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym) + has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing + timeseries_parameter_index(ic, sym) +end + +function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + allvars = vars(sym; op = Symbolics.Operator) + ts_idxs = Set{Int}() + for var in allvars + var = unwrap(var) + # FIXME: Shouldn't have to shift systems + if istree(var) && (op = operation(var)) isa Shift && op.steps == 1 + var = only(arguments(var)) + end + ts_idx = check_index_map(ic.discrete_idx, unwrap(var)) + ts_idx === nothing && continue + push!(ts_idxs, ts_idx[1]) + end + if length(ts_idxs) == 1 + ts_idx = only(ts_idxs) + else + ts_idx = nothing + end + rawobs = build_explicit_observed_function( + sys, sym; param_only = true, return_inplace = true) + if rawobs isa Tuple + if is_time_dependent(sys) + obsfn = let oop = rawobs[1], iip = rawobs[2] + f1a(p::MTKParameters, t) = oop(p..., t) + f1a(out, p::MTKParameters, t) = iip(out, p..., t) + end + else + obsfn = let oop = rawobs[1], iip = rawobs[2] + f1b(p::MTKParameters) = oop(p...) + f1b(out, p::MTKParameters) = iip(out, p...) + end + end + else + if is_time_dependent(sys) + obsfn = let rawobs = rawobs + f2a(p::MTKParameters, t) = rawobs(p..., t) + end + else + obsfn = let rawobs = rawobs + f2b(p::MTKParameters) = rawobs(p...) + end + end + end + else + ts_idx = nothing + obsfn = build_explicit_observed_function(sys, sym; param_only = true) + end + return ParameterObservedFunction(ts_idx, obsfn) +end + function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem) return full_parameters(sys) end @@ -524,7 +600,7 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSys end function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym) - return !is_variable(sys, sym) && !is_parameter(sys, sym) && + return !is_variable(sys, sym) && parameter_index(sys, sym) === nothing && !is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic() end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 75c8a7e235..b1063f214e 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -27,12 +27,12 @@ const UnknownIndexMap = Dict{ struct IndexCache unknown_idx::UnknownIndexMap - discrete_idx::ParamIndexMap + discrete_idx::Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int, Int}} tunable_idx::ParamIndexMap constant_idx::ParamIndexMap dependent_idx::ParamIndexMap nonnumeric_idx::ParamIndexMap - discrete_buffer_sizes::Vector{BufferTemplate} + discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_sizes::Vector{BufferTemplate} constant_buffer_sizes::Vector{BufferTemplate} dependent_buffer_sizes::Vector{BufferTemplate} @@ -86,7 +86,8 @@ function IndexCache(sys::AbstractSystem) end end - disc_buffers = Dict{Any, Set{BasicSymbolic}}() + disc_buffers = Dict{Int, Dict{Any, Set{BasicSymbolic}}}() + disc_clocks = Dict{Union{Symbol, BasicSymbolic}, Int}() tunable_buffers = Dict{Any, Set{BasicSymbolic}}() constant_buffers = Dict{Any, Set{BasicSymbolic}}() dependent_buffers = Dict{Any, Set{BasicSymbolic}}() @@ -99,27 +100,106 @@ function IndexCache(sys::AbstractSystem) push!(buf, sym) end + if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing + syss, inputs, continuous_id, _ = get_discrete_subsystems(sys) + + for (i, (inps, disc_sys)) in enumerate(zip(inputs, syss)) + i == continuous_id && continue + disc_buffers[i] = Dict{Any, Set{BasicSymbolic}}() + + for inp in inps + inp = unwrap(inp) + is_parameter(sys, inp) || + error("Discrete subsystem $i input $inp is not a parameter") + disc_clocks[inp] = i + disc_clocks[default_toterm(inp)] = i + if hasname(inp) && (!istree(inp) || operation(inp) !== getindex) + disc_clocks[getname(inp)] = i + disc_clocks[default_toterm(inp)] = i + end + insert_by_type!(disc_buffers[i], inp) + end + + for sym in unknowns(disc_sys) + sym = unwrap(sym) + is_parameter(sys, sym) || + error("Discrete subsystem $i unknown $sym is not a parameter") + disc_clocks[sym] = i + disc_clocks[default_toterm(sym)] = i + if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + disc_clocks[getname(sym)] = i + disc_clocks[getname(default_toterm(sym))] = i + end + insert_by_type!(disc_buffers[i], sym) + end + t = get_iv(sys) + for eq in observed(disc_sys) + # TODO: Is this a valid check + # FIXME: This shouldn't be necessary + eq.rhs === -0.0 && continue + sym = eq.lhs + if istree(sym) && operation(sym) == Shift(t, 1) + sym = only(arguments(sym)) + end + disc_clocks[sym] = i + disc_clocks[sym] = i + disc_clocks[default_toterm(sym)] = i + if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + disc_clocks[getname(sym)] = i + disc_clocks[getname(default_toterm(sym))] = i + end + end + end + + for par in inputs[continuous_id] + is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") + istree(par) && operation(par) isa Hold || + error("Continuous subsystem input is not a Hold") + if haskey(disc_clocks, par) + sym = par + else + sym = first(arguments(par)) + end + haskey(disc_clocks, sym) || + error("Variable $par not part of a discrete subsystem") + disc_clocks[par] = disc_clocks[sym] + insert_by_type!(disc_buffers[disc_clocks[sym]], par) + end + end + affs = vcat(affects(continuous_events(sys)), affects(discrete_events(sys))) + user_affect_clock = maximum(values(disc_clocks); init = 0) + 1 for affect in affs if affect isa Equation is_parameter(sys, affect.lhs) || continue - insert_by_type!(disc_buffers, affect.lhs) + + disc_clocks[affect.lhs] = user_affect_clock + disc_clocks[default_toterm(affect.lhs)] = user_affect_clock + if hasname(affect.lhs) && + (!istree(affect.lhs) || operation(affect.lhs) !== getindex) + disc_clocks[getname(affect.lhs)] = user_affect_clock + disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock + end + buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) + insert_by_type!(buffer, affect.lhs) else discs = discretes(affect) for disc in discs is_parameter(sys, disc) || error("Expected discrete variable $disc in callback to be a parameter") - insert_by_type!(disc_buffers, disc) + disc = unwrap(disc) + disc_clocks[disc] = user_affect_clock + disc_clocks[default_toterm(disc)] = user_affect_clock + if hasname(disc) && (!istree(disc) || operation(disc) !== getindex) + disc_clocks[getname(disc)] = user_affect_clock + disc_clocks[getname(default_toterm(disc))] = user_affect_clock + end + buffer = get!( + disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) + insert_by_type!(buffer, disc) end end end - if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing - _, inputs, continuous_id, _ = get_discrete_subsystems(sys) - for par in inputs[continuous_id] - is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") - insert_by_type!(disc_buffers, par) - end - end if has_parameter_dependencies(sys) pdeps = parameter_dependencies(sys) @@ -132,13 +212,11 @@ function IndexCache(sys::AbstractSystem) for p in parameters(sys) p = unwrap(p) ctype = symtype(p) - haskey(disc_buffers, ctype) && p in disc_buffers[ctype] && continue + haskey(disc_clocks, p) && continue haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} - if is_discrete_domain(p) - disc_buffers - elseif istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown() + if istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown() tunable_buffers else constant_buffers @@ -150,6 +228,40 @@ function IndexCache(sys::AbstractSystem) ) end + disc_idxs = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int, Int}}() + disc_buffer_sizes = [BufferTemplate[] for _ in 1:length(disc_buffers)] + disc_buffer_types = Set() + for buffer in values(disc_buffers) + union!(disc_buffer_types, keys(buffer)) + end + + for (clockidx, buffer) in disc_buffers + for (i, btype) in enumerate(disc_buffer_types) + if !haskey(buffer, btype) + push!(disc_buffer_sizes[clockidx], BufferTemplate(btype, 0)) + continue + end + push!(disc_buffer_sizes[clockidx], BufferTemplate(btype, length(buffer[btype]))) + for (j, sym) in enumerate(buffer[btype]) + disc_idxs[sym] = (clockidx, i, j) + disc_idxs[default_toterm(sym)] = (clockidx, i, j) + if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + disc_idxs[getname(sym)] = (clockidx, i, j) + disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j) + end + end + end + end + for (sym, clockid) in disc_clocks + haskey(disc_idxs, sym) && continue + disc_idxs[sym] = (clockid, 0, 0) + disc_idxs[default_toterm(sym)] = (clockid, 0, 0) + if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + disc_idxs[getname(sym)] = (clockid, 0, 0) + disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0) + end + end + function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}}) idxs = ParamIndexMap() buffer_sizes = BufferTemplate[] @@ -168,7 +280,7 @@ function IndexCache(sys::AbstractSystem) end return idxs, buffer_sizes end - disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs(disc_buffers) + tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers) const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers) dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers) @@ -181,7 +293,7 @@ function IndexCache(sys::AbstractSystem) const_idxs, dependent_idxs, nonnumeric_idxs, - discrete_buffer_sizes, + disc_buffer_sizes, tunable_buffer_sizes, const_buffer_sizes, dependent_buffer_sizes, @@ -227,6 +339,17 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) end end +function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym) + return check_index_map(ic.discrete_idx, sym) !== nothing +end + +function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym) + idx = check_index_map(ic.discrete_idx, sym) + idx === nothing && return nothing + clockid, partitionid... = idx + return ParameterTimeseriesIndex(clockid, partitionid) +end + function check_index_map(idxmap, sym) if (idx = get(idxmap, sym, nothing)) !== nothing return idx @@ -249,10 +372,14 @@ end function discrete_linear_index(ic::IndexCache, idx::ParameterIndex) idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected") ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0) + for clockbuftemps in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1) + ind += sum(temp.length for temp in clockbuftemps; init = 0) + end ind += sum( - temp.length for temp in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1); + temp.length + for temp in Iterators.take(ic.discrete_buffer_sizes[idx.idx[1]], idx.idx[2] - 1); init = 0) - ind += idx.idx[2] + ind += idx.idx[3] return ind end @@ -271,30 +398,31 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false) param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.tunable_buffer_sizes) disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] - for temp in ic.discrete_buffer_sizes) + for temp in Iterators.flatten(ic.discrete_buffer_sizes)) const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.constant_buffer_sizes) dep_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.dependent_buffer_sizes) nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.nonnumeric_buffer_sizes) - for p in ps + p = unwrap(p) if haskey(ic.discrete_idx, p) - i, j = ic.discrete_idx[p] - disc_buf[i][j] = unwrap(p) + disc_offset = length(first(ic.discrete_buffer_sizes)) + i, j, k = ic.discrete_idx[p] + disc_buf[(i - 1) * disc_offset + j][k] = p elseif haskey(ic.tunable_idx, p) i, j = ic.tunable_idx[p] - param_buf[i][j] = unwrap(p) + param_buf[i][j] = p elseif haskey(ic.constant_idx, p) i, j = ic.constant_idx[p] - const_buf[i][j] = unwrap(p) + const_buf[i][j] = p elseif haskey(ic.dependent_idx, p) i, j = ic.dependent_idx[p] - dep_buf[i][j] = unwrap(p) + dep_buf[i][j] = p elseif haskey(ic.nonnumeric_idx, p) i, j = ic.nonnumeric_idx[p] - nonnumeric_buf[i][j] = unwrap(p) + nonnumeric_buf[i][j] = p else error("Invalid parameter $p") end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 1cc944e1d9..70c571f0dc 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -100,8 +100,11 @@ function MTKParameters( tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.tunable_buffer_sizes) - disc_buffer = Tuple(Vector{temp.type}(undef, temp.length) - for temp in ic.discrete_buffer_sizes) + disc_buffer = SizedArray{Tuple{length(ic.discrete_buffer_sizes)}}([Tuple(Vector{temp.type}( + undef, + temp.length) + for temp in subbuffer_sizes) + for subbuffer_sizes in ic.discrete_buffer_sizes]) const_buffer = Tuple(Vector{temp.type}(undef, temp.length) for temp in ic.constant_buffer_sizes) dep_buffer = Tuple(Vector{temp.type}(undef, temp.length) @@ -114,8 +117,8 @@ function MTKParameters( i, j = ic.tunable_idx[sym] tunable_buffer[i][j] = val elseif haskey(ic.discrete_idx, sym) - i, j = ic.discrete_idx[sym] - disc_buffer[i][j] = val + i, j, k = ic.discrete_idx[sym] + disc_buffer[i][j][k] = val elseif haskey(ic.constant_idx, sym) i, j = ic.constant_idx[sym] const_buffer[i][j] = val @@ -132,7 +135,6 @@ function MTKParameters( end return done end - for (sym, val) in p sym = unwrap(sym) val = unwrap(val) @@ -220,11 +222,16 @@ function _split_helper(buf_v::T, recurse, raw, idx) where {T} _split_helper(eltype(T), buf_v, recurse, raw, idx) end -function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{true}, raw, idx) - map(b -> _split_helper(eltype(b), b, Val(false), raw, idx), buf_v) +function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{N}, raw, idx) where {N} + map(b -> _split_helper(eltype(b), b, Val(N - 1), raw, idx), buf_v) +end + +function _split_helper(::Type{<:AbstractArray}, buf_v::Tuple, ::Val{N}, raw, idx) where {N} + ntuple(i -> _split_helper(eltype(buf_v[i]), buf_v[i], Val(N - 1), raw, idx), + Val(length(buf_v))) end -function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{false}, raw, idx) +function _split_helper(::Type{<:AbstractArray}, buf_v, ::Val{0}, raw, idx) _split_helper((), buf_v, (), raw, idx) end @@ -234,7 +241,7 @@ function _split_helper(_, buf_v, _, raw, idx) return res end -function split_into_buffers(raw::AbstractArray, buf, recurse = Val(true)) +function split_into_buffers(raw::AbstractArray, buf, recurse = Val(1)) idx = Ref(1) ntuple(i -> _split_helper(buf[i], recurse, raw, idx), Val(length(buf))) end @@ -262,10 +269,10 @@ SciMLStructures.isscimlstructure(::MTKParameters) = true SciMLStructures.ismutablescimlstructure(::MTKParameters) = true -for (Portion, field) in [(SciMLStructures.Tunable, :tunable) - (SciMLStructures.Discrete, :discrete) - (SciMLStructures.Constants, :constant) - (Nonnumeric, :nonnumeric)] +for (Portion, field, recurse) in [(SciMLStructures.Tunable, :tunable, 1) + (SciMLStructures.Discrete, :discrete, 2) + (SciMLStructures.Constants, :constant, 1) + (Nonnumeric, :nonnumeric, 1)] @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) as_vector = buffer_to_arraypartition(p.$field) repack = let as_vector = as_vector, p = p @@ -283,7 +290,7 @@ for (Portion, field) in [(SciMLStructures.Tunable, :tunable) end @eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals) - @set! p.$field = split_into_buffers(newvals, p.$field) + @set! p.$field = split_into_buffers(newvals, p.$field, Val($recurse)) if p.dependent_update_oop !== nothing raw = p.dependent_update_oop(p...) @set! p.dependent = split_into_buffers(raw, p.dependent, Val(false)) @@ -302,7 +309,8 @@ end function Base.copy(p::MTKParameters) tunable = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.tunable) - discrete = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.discrete) + discrete = typeof(p.discrete)([Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) + for buf in clockbuf) for clockbuf in p.discrete]) constant = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.constant) dependent = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.dependent) nonnumeric = copy.(p.nonnumeric) @@ -323,7 +331,8 @@ function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::Para if portion isa SciMLStructures.Tunable return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...] elseif portion isa SciMLStructures.Discrete - return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...] + k, l... = k + return isempty(l) ? p.discrete[i][j][k] : p.discrete[i][j][k][l...] elseif portion isa SciMLStructures.Constants return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...] elseif portion === DEPENDENT_PORTION @@ -349,13 +358,14 @@ function SymbolicIndexingInterface.set_parameter!( p.tunable[i][j][k...] = val end elseif portion isa SciMLStructures.Discrete - if isempty(k) - if validate_size && size(val) !== size(p.discrete[i][j]) - throw(InvalidParameterSizeException(size(p.discrete[i][j]), size(val))) + k, l... = k + if isempty(l) + if validate_size && size(val) !== size(p.discrete[i][j][k]) + throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val))) end - p.discrete[i][j] = val + p.discrete[i][j][k][l...] = val else - p.discrete[i][j][k...] = val + p.discrete[i][j][k][l...] = val end elseif portion isa SciMLStructures.Constants if isempty(k) @@ -393,10 +403,11 @@ function _set_parameter_unchecked!( p.tunable[i][j][k...] = val end elseif portion isa SciMLStructures.Discrete - if isempty(k) - p.discrete[i][j] = val + k, l... = k + if isempty(l) + p.discrete[i][j][k] = val else - p.discrete[i][j][k...] = val + p.discrete[i][j][k][l...] = val end elseif portion isa SciMLStructures.Constants if isempty(k) @@ -499,8 +510,10 @@ end function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, vals::Dict) newbuf = @set oldbuf.tunable = Tuple(Vector{Any}(undef, length(buf)) for buf in oldbuf.tunable) - @set! newbuf.discrete = Tuple(Vector{Any}(undef, length(buf)) - for buf in newbuf.discrete) + @set! newbuf.discrete = SizedVector{length(newbuf.discrete)}([Tuple(Vector{Any}(undef, + length(buf)) + for buf in clockbuf) + for clockbuf in newbuf.discrete]) @set! newbuf.constant = Tuple(Vector{Any}(undef, length(buf)) for buf in newbuf.constant) @set! newbuf.nonnumeric = Tuple(Vector{Any}(undef, length(buf)) @@ -542,8 +555,11 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va @set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs.( oldbuf.tunable, newbuf.tunable) - @set! newbuf.discrete = narrow_buffer_type_and_fallback_undefs.( - oldbuf.discrete, newbuf.discrete) + @set! newbuf.discrete = SizedVector{length(newbuf.discrete)}([narrow_buffer_type_and_fallback_undefs.( + oldclockbuf, + newclockbuf) + for (oldclockbuf, newclockbuf) in zip( + oldbuf.discrete, newbuf.discrete)]) @set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.( oldbuf.constant, newbuf.constant) @set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.( @@ -557,6 +573,56 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va return newbuf end +struct NestedGetIndex{T} + x::T +end + +function Base.getindex(ngi::NestedGetIndex, idx::Tuple) + i, j, k... = idx + return ngi.x[i][j][k...] +end + +# Required for DiffEqArray constructor to work during interpolation +Base.size(::NestedGetIndex) = () + +function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( + ps::MTKParameters, args::Pair{A, B}...) where {A, B <: NestedGetIndex} + for (i, val) in args + ps.discrete[i] = val.x + end + return ps +end + +function SciMLBase.create_parameter_timeseries_collection( + sys::AbstractSystem, ps::MTKParameters, tspan) + ic = get_index_cache(sys) # this exists because the parameters are `MTKParameters` + has_discrete_subsystems(sys) || return nothing + (dss = get_discrete_subsystems(sys)) === nothing && return nothing + _, _, _, id_to_clock = dss + buffers = [] + + for (i, partition) in enumerate(ps.discrete) + clock = id_to_clock[i] + if clock isa Clock + ts = tspan[1]:(clock.dt):tspan[2] + push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1))) + elseif clock isa SolverStepClock + push!(buffers, + DiffEqArray(NestedGetIndex{typeof(partition)}[], eltype(tspan)[], (1, 1))) + elseif clock isa Continuous + continue + else + error("Unhandled clock $clock") + end + end + + return ParameterTimeseriesCollection(Tuple(buffers), copy(ps)) +end + +function SciMLBase.get_saveable_values(ps::MTKParameters, timeseries_idx) + return NestedGetIndex(deepcopy(ps.discrete[timeseries_idx])) +end + function DiffEqBase.anyeltypedual( p::MTKParameters, ::Type{Val{counter}} = Val{0}) where {counter} DiffEqBase.anyeltypedual(p.tunable) @@ -582,8 +648,10 @@ function Base.getindex(buf::MTKParameters, i) i -= _num_subarrays(buf.tunable) end if !isempty(buf.discrete) - i <= _num_subarrays(buf.discrete) && return _subarrays(buf.discrete)[i] - i -= _num_subarrays(buf.discrete) + for clockbuf in buf.discrete + i <= _num_subarrays(clockbuf) && return _subarrays(clockbuf)[i] + i -= _num_subarrays(clockbuf) + end end if !isempty(buf.constant) i <= _num_subarrays(buf.constant) && return _subarrays(buf.constant)[i] @@ -612,7 +680,7 @@ function Base.setindex!(p::MTKParameters, val, i) end done end - _helper(p.tunable) || _helper(p.discrete) || _helper(p.constant) || + _helper(p.tunable) || _helper(Iterators.flatten(p.discrete)) || _helper(p.constant) || _helper(p.nonnumeric) || throw(BoundsError(p, i)) if p.dependent_update_iip !== nothing p.dependent_update_iip(ArrayPartition(p.dependent), p...) @@ -620,26 +688,7 @@ function Base.setindex!(p::MTKParameters, val, i) end function Base.getindex(p::MTKParameters, pind::ParameterIndex) - (; portion, idx) = pind - i, j, k... = idx - if isempty(k) - indexer = (v) -> v[i][j] - else - indexer = (v) -> v[i][j][k...] - end - if portion isa SciMLStructures.Tunable - indexer(p.tunable) - elseif portion isa SciMLStructures.Discrete - indexer(p.discrete) - elseif portion isa SciMLStructures.Constants - indexer(p.constant) - elseif portion === DEPENDENT_PORTION - indexer(p.dependent) - elseif portion === NONNUMERIC_PORTION - indexer(p.nonnumeric) - else - error("Unhandled portion ", portion) - end + parameter_values(p, pind) end function Base.setindex!(p::MTKParameters, val, pind::ParameterIndex) @@ -649,7 +698,9 @@ end function Base.iterate(buf::MTKParameters, state = 1) total_len = 0 total_len += _num_subarrays(buf.tunable) - total_len += _num_subarrays(buf.discrete) + for clockbuf in buf.discrete + total_len += _num_subarrays(clockbuf) + end total_len += _num_subarrays(buf.constant) total_len += _num_subarrays(buf.nonnumeric) total_len += _num_subarrays(buf.dependent) From ef728cc588e2591aa3a1a8984a3328982af6328f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 11:31:17 +0530 Subject: [PATCH 06/13] feat: use new discrete saving, only allow `split=true` hybrid systems --- src/systems/clock_inference.jl | 122 ++++------------------- src/systems/diffeqs/abstractodesystem.jl | 20 +--- test/parameter_dependencies.jl | 15 ++- 3 files changed, 27 insertions(+), 130 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 7d9b3bc6ad..c6a8464536 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -203,19 +203,14 @@ function generate_discrete_affect( @static if VERSION < v"1.7" error("The `generate_discrete_affect` function requires at least Julia 1.7") end - use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing + has_index_cache(osys) && get_index_cache(osys) !== nothing || + error("Hybrid systems require `split = true`") out = Sym{Any}(:out) appended_parameters = full_parameters(syss[continuous_id]) offset = length(appended_parameters) - param_to_idx = if use_index_cache - Dict{Any, ParameterIndex}(p => parameter_index(osys, p) - for p in appended_parameters) - else - Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters)) - end + param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p) + for p in appended_parameters) affect_funs = [] - init_funs = [] - svs = [] clocks = TimeDomain[] for (i, (sys, input)) in enumerate(zip(syss, inputs)) i == continuous_id && continue @@ -231,11 +226,7 @@ function generate_discrete_affect( push!(fullvars, s) end needed_disc_to_cont_obs = [] - if use_index_cache - disc_to_cont_idxs = ParameterIndex[] - else - disc_to_cont_idxs = Int[] - end + disc_to_cont_idxs = ParameterIndex[] for v in inputs[continuous_id] _v = arguments(v)[1] if _v in fullvars @@ -255,7 +246,7 @@ function generate_discrete_affect( end append!(appended_parameters, input) cont_to_disc_obs = build_explicit_observed_function( - use_index_cache ? osys : syss[continuous_id], + osys, needed_cont_to_disc_obs, throw = false, expression = true, @@ -281,56 +272,16 @@ function generate_discrete_affect( disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)] save_expr = :($(SciMLBase.save_discretes!)(integrator, $i)) empty_disc = isempty(disc_range) - disc_init = if use_index_cache - :(function (u, p, t) - c2d_obs = $cont_to_disc_obs - d2c_obs = $disc_to_cont_obs - result = c2d_obs(u, p..., t) - for (val, i) in zip(result, $cont_to_disc_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end - - disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range) - result = d2c_obs(disc_state, p..., t) - for (val, i) in zip(result, $disc_to_cont_idxs) - # prevent multiple updates to dependents - _set_parameter_unchecked!(p, val, i; update_dependent = false) - end - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - repack(discretes) # to force recalculation of dependents - end) - else - :(function (u, p, t) - c2d_obs = $cont_to_disc_obs - d2c_obs = $disc_to_cont_obs - c2d_view = view(p, $cont_to_disc_idxs) - d2c_view = view(p, $disc_to_cont_idxs) - disc_unknowns = view(p, $disc_range) - copyto!(c2d_view, c2d_obs(u, p, t)) - copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)) - end) - end # @show disc_to_cont_idxs # @show cont_to_disc_idxs # @show disc_range - affect! = :(function (integrator, saved_values) + affect! = :(function (integrator) @unpack u, p, t = integrator c2d_obs = $cont_to_disc_obs d2c_obs = $disc_to_cont_obs - $( - if use_index_cache - :(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]) - else - quote - c2d_view = view(p, $cont_to_disc_idxs) - d2c_view = view(p, $disc_to_cont_idxs) - disc_unknowns = view(p, $disc_range) - end - end - ) # TODO: find a way to do this without allocating + disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range] disc = $disc # Write continuous into to discrete: handles `Sample` @@ -353,71 +304,32 @@ function generate_discrete_affect( $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end end - else - :(copyto!(c2d_view, c2d_obs(integrator.u, p, t))) - end - ) + end) # @show "after c2d", p - $( - if use_index_cache - quote - if !$empty_disc - disc(disc_unknowns, integrator.u, p..., t) - for (val, i) in zip(disc_unknowns, $disc_range) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end - end - end - else - :($empty_disc || disc(disc_unknowns, disc_unknowns, p, t)) - end - ) # @show "after state update", p - $( - if use_index_cache - quote - result = d2c_obs(disc_unknowns, p..., t) - for (val, i) in zip(result, $disc_to_cont_idxs) - $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) - end - end - else - :(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))) + result = d2c_obs(disc_unknowns, p..., t) + for (val, i) in zip(result, $disc_to_cont_idxs) + $(_set_parameter_unchecked!)(p, val, i; update_dependent = false) end - ) - push!(saved_values.t, t) - push!(saved_values.saveval, $save_vec) + $save_expr # @show "after d2c", p - $( - if use_index_cache - quote - discretes, repack, _ = $(SciMLStructures.canonicalize)( - $(SciMLStructures.Discrete()), p) - repack(discretes) - end - end - ) + discretes, repack, _ = $(SciMLStructures.canonicalize)( + $(SciMLStructures.Discrete()), p) + repack(discretes) end) push!(affect_funs, affect!) - push!(init_funs, disc_init) - push!(svs, sv) end if eval_expression affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs) - inits = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), init_funs) else affects = map(affect_funs) do a drop_expr(RuntimeGeneratedFunction( eval_module, eval_module, toexpr(LiteralExpr(a)))) end - inits = map(init_funs) do a - drop_expr(RuntimeGeneratedFunction( - eval_module, eval_module, toexpr(LiteralExpr(a)))) - end end defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs)) - return affects, inits, clocks, svs, appended_parameters, defaults + return affects, clocks, appended_parameters, defaults end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 11abadad5f..c3d8fdf7b3 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1008,14 +1008,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) inits = [] if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv + discrete_cbs = map(affects, clocks) do affect, clock if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + PeriodicCallback(affect, clock.dt; final_affect = true, initial_affect = true) elseif clock isa SolverStepClock - affect = DiscreteSaveAffect(affect, sv) DiscreteCallback(Returns(true), affect, initialize = (c, u, t, integrator) -> affect(integrator)) else @@ -1031,8 +1030,6 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = else cbs = CallbackSet(cbs, discrete_cbs...) end - else - svs = nothing end kwargs = filter_kwargs(kwargs) pt = something(get_metadata(sys), StandardODEProblem()) @@ -1041,17 +1038,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = if cbs !== nothing kwargs1 = merge(kwargs1, (callback = cbs,)) end - if svs !== nothing - kwargs1 = merge(kwargs1, (disc_saved_values = svs,)) - end - prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) - if !isempty(inits) - for init in inits - # init(prob.u0, prob.p, tspan[1]) - end - end - prob + return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index ef446e9630..815c63cb59 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -173,22 +173,19 @@ end @test_skip begin Tf = 1.0 prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; - yd(k - 2) => 2.0]) - @test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + @test_nowarn solve(prob, Tsit5()) @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp], discrete_events = [[0.5] => [kp ~ 2.0]]) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; - yd(k - 2) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) @test prob.ps[kp] == 1.0 @test prob.ps[kq] == 2.0 - @test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent) + @test_nowarn solve(prob, Tsit5()) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; - yd(k - 2) => 2.0]) - integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + integ = init(prob, Tsit5()) @test integ.ps[kp] == 1.0 @test integ.ps[kq] == 2.0 step!(integ, 0.6) From f1344b2c58987791c298c7a2510b9a2359aef5c4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Jun 2024 11:51:48 +0530 Subject: [PATCH 07/13] test: test implementation of SII/SciMLBase discrete saving interface --- test/mtkparameters.jl | 26 +++ test/split_parameters.jl | 11 +- test/symbolic_indexing_interface.jl | 243 ++++++++++++++++++---------- 3 files changed, 190 insertions(+), 90 deletions(-) diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 30bbb27ede..4af570db59 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -307,3 +307,29 @@ end newoprob = remake(oprob_scal_scal; p = ps_vec) @test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0] end + +# Parameter timeseries +ps = MTKParameters(([1.0, 1.0],), SizedArray{2}([([0.0, 0.0],), ([0.0, 0.0],)]), (), (), (), nothing, nothing) +with_updated_parameter_timeseries_values( + ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],))) +@test ps.discrete[1][1] == [5.0, 10.0] +with_updated_parameter_timeseries_values( + ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)), + 2 => ModelingToolkit.NestedGetIndex(([4.0, 40.0],))) +@test ps.discrete[1][1] == [3.0, 30.0] +@test ps.discrete[2][1] == [4.0, 40.0] +@test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1] + +# With multiple types and clocks +ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], [false]), ([4.0, 5.0, 6.0], Bool[])]), (), (), (), nothing, nothing) +@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}} +tsidx1 = 1 +tsidx2 = 2 +@test length(ps.discrete[tsidx1][1]) == 3 +@test length(ps.discrete[tsidx1][2]) == 1 +@test length(ps.discrete[tsidx2][1]) == 3 +@test length(ps.discrete[tsidx2][2]) == 0 +with_updated_parameter_timeseries_values( + ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) +@test ps.discrete[tsidx1][1] == [10.0, 11.0, 12.0] +@test ps.discrete[tsidx1][2][] == false diff --git a/test/split_parameters.jl b/test/split_parameters.jl index f707959135..01011828ab 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -4,6 +4,7 @@ using OrdinaryDiffEq using ModelingToolkit: t_nounits as t, D_nounits as D using ModelingToolkit: MTKParameters, ParameterIndex, DEPENDENT_PORTION, NONNUMERIC_PORTION using SciMLStructures: Tunable, Discrete, Constants +using StaticArrays: SizedVector x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)] @@ -194,7 +195,7 @@ S = get_sensitivity(closed_loop, :u) @testset "Indexing MTKParameters with ParameterIndex" begin ps = MTKParameters(([1.0, 2.0], [3, 4]), - ([true, false], [[1 2; 3 4]]), + SizedVector{2}([([true, false], [[1 2; 3 4]]), ([false, true], [[2 4; 6 8]])]), ([5, 6],), ([7.0, 8.0],), (["hi", "bye"], [:lie, :die]), @@ -202,14 +203,14 @@ S = get_sensitivity(closed_loop, :u) nothing) @test ps[ParameterIndex(Tunable(), (1, 2))] === 2.0 @test ps[ParameterIndex(Tunable(), (2, 2))] === 4 - @test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] === 4 - @test ps[ParameterIndex(Discrete(), (2, 1))] == [1 2; 3 4] + @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] === 4 + @test ps[ParameterIndex(Discrete(), (2, 2, 1))] == [2 4; 6 8] @test ps[ParameterIndex(Constants(), (1, 1))] === 5 @test ps[ParameterIndex(DEPENDENT_PORTION, (1, 1))] === 7.0 @test ps[ParameterIndex(NONNUMERIC_PORTION, (2, 2))] === :die ps[ParameterIndex(Tunable(), (1, 2))] = 3.0 - ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] = 5 + ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] = 5 @test ps[ParameterIndex(Tunable(), (1, 2))] === 3.0 - @test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] === 5 + @test ps[ParameterIndex(Discrete(), (1, 2, 1, 2, 2))] === 5 end diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index 7fd57c0474..7b60936852 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -1,90 +1,155 @@ using ModelingToolkit, SymbolicIndexingInterface, SciMLBase -using ModelingToolkit: t_nounits as t, D_nounits as D +using ModelingToolkit: t_nounits as t, D_nounits as D, ParameterIndex +using SciMLStructures: Tunable + +@testset "ODESystem" begin + @parameters a b + @variables x(t)=1.0 y(t)=2.0 xy(t) + eqs = [D(x) ~ a * y + t, D(y) ~ b * t] + @named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y]) + odesys = complete(odesys) + @test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y])) + @test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b])) + @test variable_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == + [1, 2, nothing, nothing, nothing, 1, 2, 1, 2, nothing, nothing] + @test isequal(variable_symbols(odesys), [x, y]) + @test all(is_parameter.((odesys,), [a, b, ParameterIndex(Tunable(), (1, 1)), :a, :b])) + @test all(.!is_parameter.((odesys,), [x, y, t, 3, 0, :x, :y])) + @test parameter_index(odesys, a) == parameter_index(odesys, :a) + @test parameter_index(odesys, a) isa ParameterIndex{Tunable, Tuple{Int, Int}} + @test parameter_index(odesys, b) == parameter_index(odesys, :b) + @test parameter_index(odesys, b) isa ParameterIndex{Tunable, Tuple{Int, Int}} + @test parameter_index.((odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y,]) == + [nothing, nothing, nothing, ParameterIndex(Tunable(), (1, 1)), nothing, nothing] + @test isequal(parameter_symbols(odesys), [a, b]) + @test all(is_independent_variable.((odesys,), [t, :t])) + @test all(.!is_independent_variable.((odesys,), [x, y, a, :x, :y, :a])) + @test isequal(independent_variable_symbols(odesys), [t]) + @test is_time_dependent(odesys) + @test constant_structure(odesys) + @test !isempty(default_values(odesys)) + @test default_values(odesys)[x] == 1.0 + @test default_values(odesys)[y] == 2.0 + @test isequal(default_values(odesys)[xy], x + y) + + @named odesys = ODESystem( + eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y]) + odesys = complete(odesys) + @test default_values(odesys)[xy] == 3.0 + pobs = parameter_observed(odesys, a + b) + @test pobs.timeseries_idx === nothing + @test pobs.observed_fn( + ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ 3.0 + pobs = parameter_observed(odesys, [a + b, a - b]) + @test pobs.timeseries_idx === nothing + @test pobs.observed_fn( + ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ [3.0, -1.0] +end -@parameters a b -@variables x(t)=1.0 y(t)=2.0 xy(t) -eqs = [D(x) ~ a * y + t, D(y) ~ b * t] -@named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y]) - -@test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y])) -@test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b])) -@test variable_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == - [1, 2, nothing, nothing, nothing, 1, 2, 1, 2, nothing, nothing] -@test isequal(variable_symbols(odesys), [x, y]) -@test all(is_parameter.((odesys,), [a, b, 1, 2, :a, :b])) -@test all(.!is_parameter.((odesys,), [x, y, t, 3, 0, :x, :y])) -@test parameter_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) == - [nothing, nothing, 1, 2, nothing, 1, 2, nothing, nothing, 1, 2] -@test isequal(parameter_symbols(odesys), [a, b]) -@test all(is_independent_variable.((odesys,), [t, :t])) -@test all(.!is_independent_variable.((odesys,), [x, y, a, :x, :y, :a])) -@test isequal(independent_variable_symbols(odesys), [t]) -@test is_time_dependent(odesys) -@test constant_structure(odesys) -@test !isempty(default_values(odesys)) -@test default_values(odesys)[x] == 1.0 -@test default_values(odesys)[y] == 2.0 -@test isequal(default_values(odesys)[xy], x + y) - -@named odesys = ODESystem( - eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y]) -@test default_values(odesys)[xy] == 3.0 - -@variables x y z -@parameters σ ρ β - -eqs = [0 ~ σ * (y - x), - 0 ~ x * (ρ - z) - y, - 0 ~ x * y - β * z] -@named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β]) - -@test !is_time_dependent(ns) - -@parameters x -@variables u(..) -Dxx = Differential(x)^2 -Dtt = Differential(t)^2 -Dt = D - -#2D PDE -C = 1 -eq = Dtt(u(t, x)) ~ C^2 * Dxx(u(t, x)) - -# Initial and boundary conditions -bcs = [u(t, 0) ~ 0.0,# for all t > 0 - u(t, 1) ~ 0.0,# for all t > 0 - u(0, x) ~ x * (1.0 - x), #for all 0 < x < 1 - Dt(u(0, x)) ~ 0.0] #for all 0 < x < 1] - -# Space and time domains -domains = [t ∈ (0.0, 1.0), - x ∈ (0.0, 1.0)] - -@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u]) - -@test pde_system.ps == SciMLBase.NullParameters() -@test parameter_symbols(pde_system) == [] - -@parameters x -@constants h = 1 -@variables u(..) -Dt = D -Dxx = Differential(x)^2 -eq = Dt(u(t, x)) ~ h * Dxx(u(t, x)) -bcs = [u(0, x) ~ -h * x * (x - 1) * sin(x), - u(t, 0) ~ 0, u(t, 1) ~ 0] - -domains = [t ∈ (0.0, 1.0), - x ∈ (0.0, 1.0)] - -analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] -analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) - -@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) - -@test isequal(pdesys.ps, [h]) -@test isequal(parameter_symbols(pdesys), [h]) -@test isequal(parameters(pdesys), [h]) +# @testset "Clock system" begin +# dt = 0.1 +# dt2 = 0.2 +# @variables x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0 +# @parameters kp=1 r=1 + +# eqs = [ +# # controller (time discrete part `dt=0.1`) +# yd1 ~ Sample(t, dt)(y) +# ud1 ~ kp * (r - yd1) +# # controller (time discrete part `dt=0.2`) +# yd2 ~ Sample(t, dt2)(y) +# ud2 ~ kp * (r - yd2) + +# # plant (time continuous part) +# u ~ Hold(ud1) + Hold(ud2) +# D(x) ~ -x + u +# y ~ x] + +# @mtkbuild cl = ODESystem(eqs, t) +# partition1_params = [Hold(ud1), Sample(t, dt)(y), ud1, yd1] +# partition2_params = [Hold(ud2), Sample(t, dt2)(y), ud2, yd2] +# @test all( +# Base.Fix1(is_timeseries_parameter, cl), vcat(partition1_params, partition2_params)) +# @test allequal(timeseries_parameter_index(cl, p).timeseries_idx +# for p in partition1_params) +# @test allequal(timeseries_parameter_index(cl, p).timeseries_idx +# for p in partition2_params) +# tsidx1 = timeseries_parameter_index(cl, partition1_params[1]).timeseries_idx +# tsidx2 = timeseries_parameter_index(cl, partition2_params[1]).timeseries_idx +# @test tsidx1 != tsidx2 +# ps = ModelingToolkit.MTKParameters(cl, [kp => 1.0, Sample(t, dt)(y) => 1.0]) +# pobs = parameter_observed(cl, Shift(t, 1)(yd1)) +# @test pobs.timeseries_idx == tsidx1 +# @test pobs.observed_fn(ps, 0.0) == 1.0 +# pobs = parameter_observed(cl, [Shift(t, 1)(yd1), Shift(t, 1)(ud1)]) +# @test pobs.timeseries_idx == tsidx1 +# @test pobs.observed_fn(ps, 0.0) == [1.0, 0.0] +# pobs = parameter_observed(cl, [Shift(t, 1)(yd1), Shift(t, 1)(ud2)]) +# @test pobs.timeseries_idx === nothing +# @test pobs.observed_fn(ps, 0.0) == [1.0, 1.0] +# end + +@testset "Nonlinear system" begin + @variables x y z + @parameters σ ρ β + + eqs = [0 ~ σ * (y - x), + 0 ~ x * (ρ - z) - y, + 0 ~ x * y - β * z] + @named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β]) + ns = complete(ns) + @test !is_time_dependent(ns) + ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0]) + pobs = parameter_observed(ns, σ + ρ) + @test pobs.timeseries_idx === nothing + @test pobs.observed_fn(ps) == 3.0 + pobs = parameter_observed(ns, [σ + ρ, ρ + β]) + @test pobs.timeseries_idx === nothing + @test pobs.observed_fn(ps) == [3.0, 5.0] +end + +@testset "PDESystem" begin + @parameters x + @variables u(..) + Dxx = Differential(x)^2 + Dtt = Differential(t)^2 + Dt = D + + #2D PDE + C = 1 + eq = Dtt(u(t, x)) ~ C^2 * Dxx(u(t, x)) + + # Initial and boundary conditions + bcs = [u(t, 0) ~ 0.0,# for all t > 0 + u(t, 1) ~ 0.0,# for all t > 0 + u(0, x) ~ x * (1.0 - x), #for all 0 < x < 1 + Dt(u(0, x)) ~ 0.0] #for all 0 < x < 1] + + # Space and time domains + domains = [t ∈ (0.0, 1.0), + x ∈ (0.0, 1.0)] + + @named pde_system = PDESystem(eq, bcs, domains, [t, x], [u]) + + @test pde_system.ps == SciMLBase.NullParameters() + @test parameter_symbols(pde_system) == [] + + @parameters x + @constants h = 1 + @variables u(..) + Dt = D + Dxx = Differential(x)^2 + eq = Dt(u(t, x)) ~ h * Dxx(u(t, x)) + bcs = [u(0, x) ~ -h * x * (x - 1) * sin(x), + u(t, 0) ~ 0, u(t, 1) ~ 0] + + domains = [t ∈ (0.0, 1.0), + x ∈ (0.0, 1.0)] + + @test isequal(pdesys.ps, [h]) + @test isequal(parameter_symbols(pdesys), [h]) + @test isequal(parameters(pdesys), [h]) +end # Issue#2767 using ModelingToolkit @@ -113,4 +178,12 @@ get_dep = @test_nowarn getu(prob, 2p1) @test getu(prob, z)(prob) == getu(prob, :z)(prob) @test getu(prob, p1)(prob) == getu(prob, :p1)(prob) @test getu(prob, p2)(prob) == getu(prob, :p2)(prob) + analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] + analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) + + @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) + + @test isequal(pdesys.ps, [h]) + @test isequal(parameter_symbols(pdesys), [h]) + @test isequal(parameters(pdesys), [h]) end From 192758be188e4eefeb77f29ecb8da5b3b00f57a2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 24 May 2024 18:40:27 +0530 Subject: [PATCH 08/13] fix: make `SII.observed` support time-independent systems --- src/systems/abstractsystem.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 368cab4c54..d2b70dfbd6 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -639,6 +639,8 @@ function SymbolicIndexingInterface.observed( return let _fn = _fn fn2(u, p) = _fn(u, p) fn2(u, p::MTKParameters) = _fn(u, p...) + fn2(::Nothing, p) = _fn([], p) + fn2(::Nothing, p::MTKParameters) = _fn([], p...) fn2 end end From 51c5446fdf070a9338f64ed7ce405b5831f071e9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 31 May 2024 22:49:03 +0530 Subject: [PATCH 09/13] fix: various bug and test fixes --- src/systems/abstractsystem.jl | 9 ++++++--- src/systems/parameter_buffer.jl | 4 ++-- test/mtkparameters.jl | 4 ++-- test/symbolic_indexing_interface.jl | 3 ++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d2b70dfbd6..be186a1f09 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -482,10 +482,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) end elseif iscall(sym) && operation(sym) === getindex && (idx = parameter_index(ic, first(arguments(sym)))) !== nothing - if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == nothing + if idx.portion isa SciMLStructures.Discrete && + idx.idx[2] == idx.idx[3] == nothing return nothing else - ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) + ParameterIndex( + idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...)) end else nothing @@ -505,7 +507,8 @@ end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing idx = parameter_index(ic, sym) - if idx === nothing || idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 + if idx === nothing || + idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0 return nothing else return idx diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 70c571f0dc..7f1da082bd 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -158,7 +158,7 @@ function MTKParameters( end end tunable_buffer = narrow_buffer_type.(tunable_buffer) - disc_buffer = narrow_buffer_type.(disc_buffer) + disc_buffer = broadcast.(narrow_buffer_type, disc_buffer) const_buffer = narrow_buffer_type.(const_buffer) # Don't narrow nonnumeric types nonnumeric_buffer = nonnumeric_buffer @@ -568,7 +568,7 @@ function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, va @set! newbuf.dependent = narrow_buffer_type_and_fallback_undefs.( oldbuf.dependent, split_into_buffers( - newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(false))) + newbuf.dependent_update_oop(newbuf...), oldbuf.dependent, Val(0))) end return newbuf end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 4af570db59..622acbddfb 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -321,8 +321,8 @@ with_updated_parameter_timeseries_values( @test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1] # With multiple types and clocks -ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], [false]), ([4.0, 5.0, 6.0], Bool[])]), (), (), (), nothing, nothing) -@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}} +ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), (), (), (), nothing, nothing) +@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, BitVector} tsidx1 = 1 tsidx2 = 2 @test length(ps.discrete[tsidx1][1]) == 3 diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index 7b60936852..70963e3371 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -19,7 +19,8 @@ using SciMLStructures: Tunable @test parameter_index(odesys, a) isa ParameterIndex{Tunable, Tuple{Int, Int}} @test parameter_index(odesys, b) == parameter_index(odesys, :b) @test parameter_index(odesys, b) isa ParameterIndex{Tunable, Tuple{Int, Int}} - @test parameter_index.((odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y,]) == + @test parameter_index.( + (odesys,), [x, y, t, ParameterIndex(Tunable(), (1, 1)), :x, :y]) == [nothing, nothing, nothing, ParameterIndex(Tunable(), (1, 1)), nothing, nothing] @test isequal(parameter_symbols(odesys), [a, b]) @test all(is_independent_variable.((odesys,), [t, :t])) From e7917b8a2cdefea887e1c89c8330f78ba2aa6a2f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Jun 2024 14:40:14 +0530 Subject: [PATCH 10/13] refactor: update implementation of discrete save interface --- src/systems/abstractsystem.jl | 60 +++++++++++++++++++---------- src/systems/index_cache.jl | 18 ++++----- src/systems/parameter_buffer.jl | 5 ++- test/mtkparameters.jl | 14 ++++--- test/parameter_dependencies.jl | 9 +++-- test/symbolic_indexing_interface.jl | 29 +++++++------- 6 files changed, 80 insertions(+), 55 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index be186a1f09..51fd0cc206 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -447,7 +447,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return sym isa ParameterIndex || is_parameter(ic, sym) || - iscall(sym) && operation(sym) === getindex && + iscall(sym) && + operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end if unwrap(sym) isa Int @@ -526,34 +527,19 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym end function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym) + is_time_dependent(sys) || return false has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false is_timeseries_parameter(ic, sym) end function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym) + is_time_dependent(sys) || return nothing has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing timeseries_parameter_index(ic, sym) end function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - allvars = vars(sym; op = Symbolics.Operator) - ts_idxs = Set{Int}() - for var in allvars - var = unwrap(var) - # FIXME: Shouldn't have to shift systems - if istree(var) && (op = operation(var)) isa Shift && op.steps == 1 - var = only(arguments(var)) - end - ts_idx = check_index_map(ic.discrete_idx, unwrap(var)) - ts_idx === nothing && continue - push!(ts_idxs, ts_idx[1]) - end - if length(ts_idxs) == 1 - ts_idx = only(ts_idxs) - else - ts_idx = nothing - end rawobs = build_explicit_observed_function( sys, sym; param_only = true, return_inplace = true) if rawobs isa Tuple @@ -580,10 +566,44 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) end end else - ts_idx = nothing obsfn = build_explicit_observed_function(sys, sym; param_only = true) end - return ParameterObservedFunction(ts_idx, obsfn) + return obsfn +end + +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym) + if is_variable(sys, sym) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, sym) + push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx) + end +end +# Need this to avoid ambiguity with the array case +for traitT in [ + ScalarSymbolic, + ArraySymbolic +] + @eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym) + allsyms = vars(sym; op = Symbolics.Operator) + foreach(allsyms) do s + _all_ts_idxs!(ts_idxs, sys, s) + end + end +end +function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray) + foreach(sym) do s + _all_ts_idxs!(ts_idxs, sys, s) + end +end +_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, sym) + +function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym) + if !is_time_dependent(sys) + return Set() + end + ts_idxs = Set() + _all_ts_idxs!(ts_idxs, sys, sym) + return ts_idxs end function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index b1063f214e..13fb7adef2 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -113,7 +113,7 @@ function IndexCache(sys::AbstractSystem) error("Discrete subsystem $i input $inp is not a parameter") disc_clocks[inp] = i disc_clocks[default_toterm(inp)] = i - if hasname(inp) && (!istree(inp) || operation(inp) !== getindex) + if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex) disc_clocks[getname(inp)] = i disc_clocks[default_toterm(inp)] = i end @@ -126,7 +126,7 @@ function IndexCache(sys::AbstractSystem) error("Discrete subsystem $i unknown $sym is not a parameter") disc_clocks[sym] = i disc_clocks[default_toterm(sym)] = i - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i disc_clocks[getname(default_toterm(sym))] = i end @@ -138,13 +138,13 @@ function IndexCache(sys::AbstractSystem) # FIXME: This shouldn't be necessary eq.rhs === -0.0 && continue sym = eq.lhs - if istree(sym) && operation(sym) == Shift(t, 1) + if iscall(sym) && operation(sym) == Shift(t, 1) sym = only(arguments(sym)) end disc_clocks[sym] = i disc_clocks[sym] = i disc_clocks[default_toterm(sym)] = i - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i disc_clocks[getname(default_toterm(sym))] = i end @@ -153,7 +153,7 @@ function IndexCache(sys::AbstractSystem) for par in inputs[continuous_id] is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") - istree(par) && operation(par) isa Hold || + iscall(par) && operation(par) isa Hold || error("Continuous subsystem input is not a Hold") if haskey(disc_clocks, par) sym = par @@ -176,7 +176,7 @@ function IndexCache(sys::AbstractSystem) disc_clocks[affect.lhs] = user_affect_clock disc_clocks[default_toterm(affect.lhs)] = user_affect_clock if hasname(affect.lhs) && - (!istree(affect.lhs) || operation(affect.lhs) !== getindex) + (!iscall(affect.lhs) || operation(affect.lhs) !== getindex) disc_clocks[getname(affect.lhs)] = user_affect_clock disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock end @@ -190,7 +190,7 @@ function IndexCache(sys::AbstractSystem) disc = unwrap(disc) disc_clocks[disc] = user_affect_clock disc_clocks[default_toterm(disc)] = user_affect_clock - if hasname(disc) && (!istree(disc) || operation(disc) !== getindex) + if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex) disc_clocks[getname(disc)] = user_affect_clock disc_clocks[getname(default_toterm(disc))] = user_affect_clock end @@ -245,7 +245,7 @@ function IndexCache(sys::AbstractSystem) for (j, sym) in enumerate(buffer[btype]) disc_idxs[sym] = (clockidx, i, j) disc_idxs[default_toterm(sym)] = (clockidx, i, j) - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_idxs[getname(sym)] = (clockidx, i, j) disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j) end @@ -256,7 +256,7 @@ function IndexCache(sys::AbstractSystem) haskey(disc_idxs, sym) && continue disc_idxs[sym] = (clockid, 0, 0) disc_idxs[default_toterm(sym)] = (clockid, 0, 0) - if hasname(sym) && (!istree(sym) || operation(sym) !== getindex) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_idxs[getname(sym)] = (clockid, 0, 0) disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0) end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 7f1da082bd..cd9123f0bf 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -363,7 +363,7 @@ function SymbolicIndexingInterface.set_parameter!( if validate_size && size(val) !== size(p.discrete[i][j][k]) throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val))) end - p.discrete[i][j][k][l...] = val + p.discrete[i][j][k] = val else p.discrete[i][j][k][l...] = val end @@ -586,7 +586,8 @@ end Base.size(::NestedGetIndex) = () function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( - ps::MTKParameters, args::Pair{A, B}...) where {A, B <: NestedGetIndex} + ::AbstractSystem, ps::MTKParameters, args::Pair{A, B}...) where { + A, B <: NestedGetIndex} for (i, val) in args ps.discrete[i] = val.x end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 622acbddfb..b3b170df18 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -2,6 +2,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters using SymbolicIndexingInterface using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants +using StaticArrays: SizedVector using OrdinaryDiffEq using ForwardDiff using JET @@ -309,19 +310,22 @@ end end # Parameter timeseries -ps = MTKParameters(([1.0, 1.0],), SizedArray{2}([([0.0, 0.0],), ([0.0, 0.0],)]), (), (), (), nothing, nothing) +ps = MTKParameters(([1.0, 1.0],), SizedVector{2}([([0.0, 0.0],), ([0.0, 0.0],)]), + (), (), (), nothing, nothing) with_updated_parameter_timeseries_values( - ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],))) + sys, ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],))) @test ps.discrete[1][1] == [5.0, 10.0] with_updated_parameter_timeseries_values( - ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)), + sys, ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)), 2 => ModelingToolkit.NestedGetIndex(([4.0, 40.0],))) @test ps.discrete[1][1] == [3.0, 30.0] @test ps.discrete[2][1] == [4.0, 40.0] @test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1] # With multiple types and clocks -ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), (), (), (), nothing, nothing) +ps = MTKParameters( + (), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), + (), (), (), nothing, nothing) @test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, BitVector} tsidx1 = 1 tsidx2 = 2 @@ -330,6 +334,6 @@ tsidx2 = 2 @test length(ps.discrete[tsidx2][1]) == 3 @test length(ps.discrete[tsidx2][2]) == 0 with_updated_parameter_timeseries_values( - ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) + sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) @test ps.discrete[tsidx1][1] == [10.0, 11.0, 12.0] @test ps.discrete[tsidx1][2][] == false diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index 815c63cb59..242be8f1d7 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -173,18 +173,21 @@ end @test_skip begin Tf = 1.0 prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; + yd(k - 2) => 2.0]) @test_nowarn solve(prob, Tsit5()) @mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp], discrete_events = [[0.5] => [kp ~ 2.0]]) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; + yd(k - 2) => 2.0]) @test prob.ps[kp] == 1.0 @test prob.ps[kq] == 2.0 @test_nowarn solve(prob, Tsit5()) prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf), - [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0]) + [kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; + yd(k - 2) => 2.0]) integ = init(prob, Tsit5()) @test integ.ps[kp] == 1.0 @test integ.ps[kq] == 2.0 diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index 70963e3371..10d24fd6f2 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -38,12 +38,12 @@ using SciMLStructures: Tunable odesys = complete(odesys) @test default_values(odesys)[xy] == 3.0 pobs = parameter_observed(odesys, a + b) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn( + @test isempty(get_all_timeseries_indexes(odesys, a + b)) + @test pobs( ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ 3.0 pobs = parameter_observed(odesys, [a + b, a - b]) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn( + @test isempty(get_all_timeseries_indexes(odesys, [a + b, a - b])) + @test pobs( ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) ≈ [3.0, -1.0] end @@ -102,11 +102,11 @@ end @test !is_time_dependent(ns) ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0]) pobs = parameter_observed(ns, σ + ρ) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn(ps) == 3.0 + @test isempty(get_all_timeseries_indexes(ns, σ + ρ)) + @test pobs(ps) == 3.0 pobs = parameter_observed(ns, [σ + ρ, ρ + β]) - @test pobs.timeseries_idx === nothing - @test pobs.observed_fn(ps) == [3.0, 5.0] + @test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β])) + @test pobs(ps) == [3.0, 5.0] end @testset "PDESystem" begin @@ -147,6 +147,11 @@ end domains = [t ∈ (0.0, 1.0), x ∈ (0.0, 1.0)] + analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] + analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) + + @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) + @test isequal(pdesys.ps, [h]) @test isequal(parameter_symbols(pdesys), [h]) @test isequal(parameters(pdesys), [h]) @@ -179,12 +184,4 @@ get_dep = @test_nowarn getu(prob, 2p1) @test getu(prob, z)(prob) == getu(prob, :z)(prob) @test getu(prob, p1)(prob) == getu(prob, :p1)(prob) @test getu(prob, p2)(prob) == getu(prob, :p2)(prob) - analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)] - analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t) - - @named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic) - - @test isequal(pdesys.ps, [h]) - @test isequal(parameter_symbols(pdesys), [h]) - @test isequal(parameters(pdesys), [h]) end From 33d4ecc4e5685e9fd04160e3bcc7cd3a0a8dbbec Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Jun 2024 18:32:35 +0530 Subject: [PATCH 11/13] refactor: use clock from SciMLBase, fix tests --- Project.toml | 2 + docs/src/tutorials/SampledData.md | 16 +-- src/ModelingToolkit.jl | 5 +- src/clock.jl | 101 ++++++------------ src/discretedomain.jl | 53 +++++----- src/systems/abstractsystem.jl | 35 ++++++- src/systems/clock_inference.jl | 6 +- src/systems/diffeqs/abstractodesystem.jl | 38 +++---- src/systems/index_cache.jl | 127 +++++++++++++++++++---- src/systems/parameter_buffer.jl | 17 ++- src/systems/systemstructure.jl | 5 +- test/clock.jl | 82 +++++++-------- test/parameter_dependencies.jl | 4 +- 13 files changed, 287 insertions(+), 204 deletions(-) diff --git a/Project.toml b/Project.toml index e6fd39d9e4..c9c7a47811 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" @@ -81,6 +82,7 @@ DocStringExtensions = "0.7, 0.8, 0.9" DomainSets = "0.6, 0.7" DynamicQuantities = "^0.11.2, 0.12, 0.13" ExprTools = "0.1.10" +Expronicon = "0.8" FindFirstFunctions = "1" ForwardDiff = "0.10.3" FunctionWrappersWrappers = "0.1" diff --git a/docs/src/tutorials/SampledData.md b/docs/src/tutorials/SampledData.md index 614e8b65c7..a72fd1698b 100644 --- a/docs/src/tutorials/SampledData.md +++ b/docs/src/tutorials/SampledData.md @@ -16,7 +16,7 @@ A clock can be seen as an *event source*, i.e., when the clock ticks, an event i - [`Hold`](@ref) - [`ShiftIndex`](@ref) -When a continuous-time variable `x` is sampled using `xd = Sample(x, dt)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc. +When a continuous-time variable `x` is sampled using `xd = Sample(dt)(x)`, the result is a discrete-time variable `xd` that is defined and updated whenever the clock ticks. `xd` is *only defined when the clock ticks*, which it does with an interval of `dt`. If `dt` is unspecified, the tick rate of the clock associated with `xd` is inferred from the context in which `xd` appears. Any variable taking part in the same equation as `xd` is inferred to belong to the same *discrete partition* as `xd`, i.e., belonging to the same clock. A system may contain multiple different discrete-time partitions, each with a unique clock. This allows for modeling of multi-rate systems and discrete-time processes located on different computers etc. To make a discrete-time variable available to the continuous partition, the [`Hold`](@ref) operator is used. `xc = Hold(xd)` creates a continuous-time variable `xc` that is updated whenever the clock associated with `xd` ticks, and holds its value constant between ticks. @@ -34,7 +34,7 @@ using ModelingToolkit using ModelingToolkit: t_nounits as t @variables x(t) y(t) u(t) dt = 0.1 # Sample interval -clock = Clock(t, dt) # A periodic clock with tick rate dt +clock = Clock(dt) # A periodic clock with tick rate dt k = ShiftIndex(clock) eqs = [ @@ -99,7 +99,7 @@ may thus be modeled as ```julia t = ModelingToolkit.t_nounits @variables y(t) [description = "Output"] u(t) [description = "Input"] -k = ShiftIndex(Clock(t, dt)) +k = ShiftIndex(Clock(dt)) eqs = [ a2 * y(k) + a1 * y(k - 1) + a0 * y(k - 2) ~ b2 * u(k) + b1 * u(k - 1) + b0 * u(k - 2) ] @@ -128,10 +128,10 @@ requires specification of the initial condition for both `x(k-1)` and `x(k-2)`. Multi-rate systems are easy to model using multiple different clocks. The following set of equations is valid, and defines *two different discrete-time partitions*, each with its own clock: ```julia -yd1 ~ Sample(t, dt1)(y) -ud1 ~ kp * (Sample(t, dt1)(r) - yd1) -yd2 ~ Sample(t, dt2)(y) -ud2 ~ kp * (Sample(t, dt2)(r) - yd2) +yd1 ~ Sample(dt1)(y) +ud1 ~ kp * (Sample(dt1)(r) - yd1) +yd2 ~ Sample(dt2)(y) +ud2 ~ kp * (Sample(dt2)(r) - yd2) ``` `yd1` and `ud1` belong to the same clock which ticks with an interval of `dt1`, while `yd2` and `ud2` belong to a different clock which ticks with an interval of `dt2`. The two clocks are *not synchronized*, i.e., they are not *guaranteed* to tick at the same point in time, even if one tick interval is a rational multiple of the other. Mechanisms for synchronization of clocks are not yet implemented. @@ -148,7 +148,7 @@ using ModelingToolkit: t_nounits as t using ModelingToolkit: D_nounits as D dt = 0.5 # Sample interval @variables r(t) -clock = Clock(t, dt) +clock = Clock(dt) k = ShiftIndex(clock) function plant(; name) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index e70991ad3e..211c184130 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -43,7 +43,8 @@ using SciMLStructures using Compat using AbstractTrees using DiffEqBase, SciMLBase, ForwardDiff -using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap +using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain, + PeriodicClock, Clock, SolverStepClock, Continuous using Distributed import JuliaFormatter using MLStyle @@ -272,6 +273,6 @@ export debug_system #export has_discrete_domain, has_continuous_domain #export is_discrete_domain, is_continuous_domain, is_hybrid_domain export Sample, Hold, Shift, ShiftIndex, sampletime, SampleTime -export Clock #, InferredDiscrete, +export Clock, SolverStepClock, TimeDomain end # module diff --git a/src/clock.jl b/src/clock.jl index 5df6cfb022..26ea5832da 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,13 +1,26 @@ -abstract type TimeDomain end -abstract type AbstractDiscrete <: TimeDomain end +module InferredClock -Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) +export InferredTimeDomain -struct Inferred <: TimeDomain end -struct InferredDiscrete <: AbstractDiscrete end -struct Continuous <: TimeDomain end +using Expronicon.ADT: @adt, @match +using SciMLBase: TimeDomain -Symbolics.option_to_metadata_type(::Val{:timedomain}) = TimeDomain +@adt InferredTimeDomain begin + Inferred + InferredDiscrete +end + +Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x) + +end + +using .InferredClock + +struct VariableTimeDomain end +Symbolics.option_to_metadata_type(::Val{:timedomain}) = VariableTimeDomain + +is_concrete_time_domain(::TimeDomain) = true +is_concrete_time_domain(_) = false """ is_continuous_domain(x) @@ -16,7 +29,7 @@ true if `x` contains only continuous-domain signals. See also [`has_continuous_domain`](@ref) """ function is_continuous_domain(x) - issym(x) && return getmetadata(x, TimeDomain, false) isa Continuous + issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous !has_discrete_domain(x) && has_continuous_domain(x) end @@ -24,7 +37,7 @@ function get_time_domain(x) if iscall(x) && operation(x) isa Operator output_timedomain(x) else - getmetadata(x, TimeDomain, nothing) + getmetadata(x, VariableTimeDomain, nothing) end end get_time_domain(x::Num) = get_time_domain(value(x)) @@ -37,14 +50,14 @@ Determine if variable `x` has a time-domain attributed to it. function has_time_domain(x::Symbolic) # getmetadata(x, Continuous, nothing) !== nothing || # getmetadata(x, Discrete, nothing) !== nothing - getmetadata(x, TimeDomain, nothing) !== nothing + getmetadata(x, VariableTimeDomain, nothing) !== nothing end has_time_domain(x::Num) = has_time_domain(value(x)) has_time_domain(x) = false for op in [Differential] - @eval input_timedomain(::$op, arg = nothing) = Continuous() - @eval output_timedomain(::$op, arg = nothing) = Continuous() + @eval input_timedomain(::$op, arg = nothing) = Continuous + @eval output_timedomain(::$op, arg = nothing) = Continuous end """ @@ -83,12 +96,17 @@ true if `x` contains only discrete-domain signals. See also [`has_discrete_domain`](@ref) """ function is_discrete_domain(x) - if hasmetadata(x, TimeDomain) || issym(x) - return getmetadata(x, TimeDomain, false) isa AbstractDiscrete + if hasmetadata(x, VariableTimeDomain) || issym(x) + return is_discrete_time_domain(getmetadata(x, VariableTimeDomain, false)) end !has_discrete_domain(x) && has_continuous_domain(x) end +sampletime(c) = @match c begin + PeriodicClock(dt, _...) => dt + _ => nothing +end + struct ClockInferenceException <: Exception msg::Any end @@ -97,57 +115,4 @@ function Base.showerror(io::IO, cie::ClockInferenceException) print(io, "ClockInferenceException: ", cie.msg) end -abstract type AbstractClock <: AbstractDiscrete end - -""" - Clock <: AbstractClock - Clock([t]; dt) - -The default periodic clock with independent variables `t` and tick interval `dt`. -If `dt` is left unspecified, it will be inferred (if possible). -""" -struct Clock <: AbstractClock - "Independent variable" - t::Union{Nothing, Symbolic} - "Period" - dt::Union{Nothing, Float64} - Clock(t::Union{Num, Symbolic}, dt = nothing) = new(value(t), dt) - Clock(t::Nothing, dt = nothing) = new(t, dt) -end -Clock(dt::Real) = Clock(nothing, dt) -Clock() = Clock(nothing, nothing) - -sampletime(c) = isdefined(c, :dt) ? c.dt : nothing -Base.hash(c::Clock, seed::UInt) = hash(c.dt, seed ⊻ 0x953d7a9a18874b90) -function Base.:(==)(c1::Clock, c2::Clock) - ((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) && c1.dt == c2.dt -end - -is_concrete_time_domain(x) = x isa Union{AbstractClock, Continuous} - -""" - SolverStepClock <: AbstractClock - SolverStepClock() - SolverStepClock(t) - -A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). This clock **does generally not have equidistant tick intervals**, instead, the tick interval depends on the adaptive step-size selection of the continuous solver, as well as any continuous event handling. If adaptivity of the solver is turned off and there are no continuous events, the tick interval will be given by the fixed solver time step `dt`. - -Due to possibly non-equidistant tick intervals, this clock should typically not be used with discrete-time systems that assume a fixed sample time, such as PID controllers and digital filters. -""" -struct SolverStepClock <: AbstractClock - "Independent variable" - t::Union{Nothing, Symbolic} - "Period" - SolverStepClock(t::Union{Num, Symbolic}) = new(value(t)) -end -SolverStepClock() = SolverStepClock(nothing) - -Base.hash(c::SolverStepClock, seed::UInt) = seed ⊻ 0x953d7b9a18874b91 -function Base.:(==)(c1::SolverStepClock, c2::SolverStepClock) - ((c1.t === nothing || c2.t === nothing) || isequal(c1.t, c2.t)) -end - -struct IntegerSequence <: AbstractClock - t::Union{Nothing, Symbolic} - IntegerSequence(t::Union{Num, Symbolic}) = new(value(t)) -end +struct IntegerSequence end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index cb723e159f..facb151d77 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -85,8 +85,8 @@ $(TYPEDEF) Represents a sample operator. A discrete-time signal is created by sampling a continuous-time signal. # Constructors -`Sample(clock::TimeDomain = InferredDiscrete())` -`Sample([t], dt::Real)` +`Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete)` +`Sample(dt::Real)` `Sample(x::Num)`, with a single argument, is shorthand for `Sample()(x)`. @@ -100,16 +100,23 @@ julia> using Symbolics julia> t = ModelingToolkit.t_nounits -julia> Δ = Sample(t, 0.01) +julia> Δ = Sample(0.01) (::Sample) (generic function with 2 methods) ``` """ struct Sample <: Operator clock::Any - Sample(clock::TimeDomain = InferredDiscrete()) = new(clock) - Sample(t, dt::Real) = new(Clock(t, dt)) + Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete) = new(clock) +end + +function Sample(arg::Real) + arg = unwrap(arg) + if symbolic_type(arg) == NotSymbolic() + Sample(Clock(arg)) + else + Sample()(arg) + end end -Sample(x) = Sample()(x) (D::Sample)(x) = Term{symtype(x)}(D, Any[x]) (D::Sample)(x::Num) = Num(D(value(x))) SymbolicUtils.promote_symtype(::Sample, x) = x @@ -176,15 +183,18 @@ julia> x(k) # no shift x(t) julia> x(k+1) # shift -Shift(t, 1)(x(t)) +Shift(1)(x(t)) ``` """ struct ShiftIndex - clock::TimeDomain + clock::Union{InferredTimeDomain, TimeDomain, IntegerSequence} steps::Int - ShiftIndex(clock::TimeDomain = Inferred(), steps::Int = 0) = new(clock, steps) - ShiftIndex(t::Num, dt::Real, steps::Int = 0) = new(Clock(t, dt), steps) - ShiftIndex(t::Num, steps::Int = 0) = new(IntegerSequence(t), steps) + function ShiftIndex( + clock::Union{TimeDomain, InferredTimeDomain, IntegerSequence} = Inferred, steps::Int = 0) + new(clock, steps) + end + ShiftIndex(dt::Real, steps::Int = 0) = new(Clock(dt), steps) + ShiftIndex(::Num, steps::Int) = new(IntegerSequence(), steps) end function (xn::Num)(k::ShiftIndex) @@ -197,18 +207,13 @@ function (xn::Num)(k::ShiftIndex) args = Symbolics.arguments(vars[]) # args should be one element vector with the t in x(t) length(args) == 1 || error("Cannot shift an expression with multiple independent variables $x.") - t = args[] - if hasfield(typeof(clock), :t) - isequal(t, clock.t) || - error("Independent variable of $xn is not the same as that of the ShiftIndex $(k.t)") - end # d, _ = propagate_time_domain(xn) # if d != clock # this is only required if the variable has another clock # xn = Sample(t, clock)(xn) # end # QUESTION: should we return a variable with time domain set to k.clock? - xn = setmetadata(xn, TimeDomain, k.clock) + xn = setmetadata(xn, VariableTimeDomain, k.clock) if steps == 0 return xn # x(k) needs no shift operator if the step of k is 0 end @@ -221,37 +226,37 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i) """ input_timedomain(op::Operator) -Return the time-domain type (`Continuous()` or `Discrete()`) that `op` operates on. +Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` operates on. """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + InferredDiscrete end """ output_timedomain(op::Operator) -Return the time-domain type (`Continuous()` or `Discrete()`) that `op` results in. +Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` results in. """ function output_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + InferredDiscrete end -input_timedomain(::Sample, arg = nothing) = Continuous() +input_timedomain(::Sample, arg = nothing) = Continuous output_timedomain(s::Sample, arg = nothing) = s.clock function input_timedomain(h::Hold, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() # the Hold accepts any discrete + InferredDiscrete # the Hold accepts any discrete end -output_timedomain(::Hold, arg = nothing) = Continuous() +output_timedomain(::Hold, arg = nothing) = Continuous sampletime(op::Sample, arg = nothing) = sampletime(op.clock) sampletime(op::ShiftIndex, arg = nothing) = sampletime(op.clock) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 51fd0cc206..11292752cc 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -571,8 +571,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym) return obsfn end +function has_observed_with_lhs(sys, sym) + has_observed(sys) || return false + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + return any(isequal(sym), ic.observed_syms) + else + return any(isequal(sym), [eq.lhs for eq in observed(sys)]) + end +end + function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym) - if is_variable(sys, sym) + if is_variable(sys, sym) || is_independent_variable(sys, sym) push!(ts_idxs, ContinuousTimeseries()) elseif is_timeseries_parameter(sys, sym) push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx) @@ -585,17 +594,33 @@ for traitT in [ ] @eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym) allsyms = vars(sym; op = Symbolics.Operator) - foreach(allsyms) do s - _all_ts_idxs!(ts_idxs, sys, s) + for s in allsyms + s = unwrap(s) + if is_variable(sys, s) || is_independent_variable(sys, s) || + has_observed_with_lhs(sys, s) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, s) + push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + end end end end +function _all_ts_idxs!(ts_idxs, ::ScalarSymbolic, sys, sym::Symbol) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + return _all_ts_idxs!(ts_idxs, sys, ic.symbol_to_variable[sym]) + elseif is_variable(sys, sym) || is_independent_variable(sys, sym) || + any(isequal(sym), [getname(eq.lhs) for eq in observed(sys)]) + push!(ts_idxs, ContinuousTimeseries()) + elseif is_timeseries_parameter(sys, sym) + push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx) + end +end function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray) - foreach(sym) do s + for s in sym _all_ts_idxs!(ts_idxs, sys, s) end end -_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, sym) +_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, symbolic_type(sym), sys, sym) function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym) if !is_time_dependent(sys) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index c6a8464536..dfdef69034 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -8,8 +8,8 @@ end function ClockInference(ts::TransformationState) @unpack structure = ts @unpack graph = structure - eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)] - var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)] + eq_domain = TimeDomain[Continuous for _ in 1:nsrcs(graph)] + var_domain = TimeDomain[Continuous for _ in 1:ndsts(graph)] inferred = BitSet() for (i, v) in enumerate(get_fullvars(ts)) d = get_time_domain(v) @@ -151,7 +151,7 @@ function split_system(ci::ClockInference{S}) where {S} get!(clock_to_id, d) do cid = (cid_counter[] += 1) push!(id_to_clock, d) - if d isa Continuous + if d == Continuous continuous_id[] = cid end cid diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index c3d8fdf7b3..5f69266f7e 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -825,7 +825,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first if sys isa ODESystem && build_initializeprob && (((implicit_dae || !isempty(missingvars)) && - all(isequal(Continuous()), ci.var_domain) && + all(==(Continuous), ci.var_domain) && ModelingToolkit.get_tearing_state(sys) !== nothing) || !isempty(initialization_equations(sys))) && t !== nothing if eltype(u0map) <: Number @@ -1011,14 +1011,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) discrete_cbs = map(affects, clocks) do affect, clock - if clock isa Clock - PeriodicCallback(affect, clock.dt; + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - elseif clock isa SolverStepClock - DiscreteCallback(Returns(true), affect, + &SolverStepClock => DiscreteCallback(Returns(true), affect, initialize = (c, u, t, integrator) -> affect(integrator)) - else - error("$clock is not a supported clock type.") + _ => error("$clock is not a supported clock type.") end end if cbs === nothing @@ -1112,14 +1110,15 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], u0 = h(p, tspan[1]) cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv - if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + discrete_cbs = map(affects, clocks) do affect, clock + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - else - error("$clock is not a supported clock type.") + &SolverStepClock => DiscreteCallback(Returns(true), affect, + initialize = (c, u, t, integrator) -> affect(integrator)) + _ => error("$clock is not a supported clock type.") end end if cbs === nothing @@ -1174,14 +1173,15 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], u0 = h(p, tspan[1]) cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing - affects, clocks, svs = ModelingToolkit.generate_discrete_affect( + affects, clocks = ModelingToolkit.generate_discrete_affect( sys, dss...; eval_expression, eval_module) - discrete_cbs = map(affects, clocks, svs) do affect, clock, sv - if clock isa Clock - PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt; + discrete_cbs = map(affects, clocks) do affect, clock + @match clock begin + PeriodicClock(dt, _...) => PeriodicCallback(affect, dt; final_affect = true, initial_affect = true) - else - error("$clock is not a supported clock type.") + &SolverStepClock => DiscreteCallback(Returns(true), affect, + initialize = (c, u, t, integrator) -> affect(integrator)) + _ => error("$clock is not a supported clock type.") end end if cbs === nothing diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 13fb7adef2..f992cd4907 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -32,6 +32,7 @@ struct IndexCache constant_idx::ParamIndexMap dependent_idx::ParamIndexMap nonnumeric_idx::ParamIndexMap + observed_syms::Set{Union{Symbol, BasicSymbolic}} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_sizes::Vector{BufferTemplate} constant_buffer_sizes::Vector{BufferTemplate} @@ -48,16 +49,21 @@ function IndexCache(sys::AbstractSystem) let idx = 1 for sym in unks usym = unwrap(sym) + rsym = renamespace(sys, usym) sym_idx = if Symbolics.isarraysymbolic(sym) reshape(idx:(idx + length(sym) - 1), size(sym)) else idx end unk_idxs[usym] = sym_idx + unk_idxs[rsym] = sym_idx if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) name = getname(usym) + rname = getname(rsym) unk_idxs[name] = sym_idx + unk_idxs[rname] = sym_idx symbol_to_variable[name] = sym + symbol_to_variable[rname] = sym end idx += length(sym) end @@ -71,18 +77,41 @@ function IndexCache(sys::AbstractSystem) if idxs == idxs[begin]:idxs[end] idxs = reshape(idxs[begin]:idxs[end], size(idxs)) end + rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs + unk_idxs[rsym] = idxs if hasname(arrsym) name = getname(arrsym) + rname = getname(rsym) unk_idxs[name] = idxs + unk_idxs[rname] = idxs symbol_to_variable[name] = arrsym + symbol_to_variable[rname] = arrsym end end end + observed_syms = Set{Union{Symbol, BasicSymbolic}}() for eq in observed(sys) - if symbolic_type(eq.lhs) != NotSymbolic() && hasname(eq.lhs) - symbol_to_variable[getname(eq.lhs)] = eq.lhs + if symbolic_type(eq.lhs) != NotSymbolic() + sym = eq.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + push!(observed_syms, sym) + push!(observed_syms, ttsym) + push!(observed_syms, rsym) + push!(observed_syms, rttsym) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = eq.lhs + symbol_to_variable[getname(ttsym)] = eq.lhs + symbol_to_variable[getname(rsym)] = eq.lhs + symbol_to_variable[getname(rttsym)] = eq.lhs + push!(observed_syms, getname(sym)) + push!(observed_syms, getname(ttsym)) + push!(observed_syms, getname(rsym)) + push!(observed_syms, getname(rttsym)) + end end end @@ -109,26 +138,40 @@ function IndexCache(sys::AbstractSystem) for inp in inps inp = unwrap(inp) + ttinp = default_toterm(inp) + rinp = renamespace(sys, inp) + rttinp = renamespace(sys, ttinp) is_parameter(sys, inp) || error("Discrete subsystem $i input $inp is not a parameter") disc_clocks[inp] = i - disc_clocks[default_toterm(inp)] = i + disc_clocks[ttinp] = i + disc_clocks[rinp] = i + disc_clocks[rttinp] = i if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex) disc_clocks[getname(inp)] = i - disc_clocks[default_toterm(inp)] = i + disc_clocks[getname(ttinp)] = i + disc_clocks[getname(rinp)] = i + disc_clocks[getname(rttinp)] = i end insert_by_type!(disc_buffers[i], inp) end for sym in unknowns(disc_sys) sym = unwrap(sym) + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) is_parameter(sys, sym) || error("Discrete subsystem $i unknown $sym is not a parameter") disc_clocks[sym] = i - disc_clocks[default_toterm(sym)] = i + disc_clocks[ttsym] = i + disc_clocks[rsym] = i + disc_clocks[rttsym] = i if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i - disc_clocks[getname(default_toterm(sym))] = i + disc_clocks[getname(ttsym)] = i + disc_clocks[getname(rsym)] = i + disc_clocks[getname(rttsym)] = i end insert_by_type!(disc_buffers[i], sym) end @@ -138,21 +181,31 @@ function IndexCache(sys::AbstractSystem) # FIXME: This shouldn't be necessary eq.rhs === -0.0 && continue sym = eq.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) if iscall(sym) && operation(sym) == Shift(t, 1) sym = only(arguments(sym)) end disc_clocks[sym] = i - disc_clocks[sym] = i - disc_clocks[default_toterm(sym)] = i + disc_clocks[ttsym] = i + disc_clocks[rsym] = i + disc_clocks[rttsym] = i if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) disc_clocks[getname(sym)] = i - disc_clocks[getname(default_toterm(sym))] = i + disc_clocks[getname(ttsym)] = i + disc_clocks[getname(rsym)] = i + disc_clocks[getname(rttsym)] = i end end end for par in inputs[continuous_id] is_parameter(sys, par) || error("Discrete subsystem input is not a parameter") + par = unwrap(par) + ttpar = default_toterm(par) + rpar = renamespace(sys, par) + rttpar = renamespace(sys, ttpar) iscall(par) && operation(par) isa Hold || error("Continuous subsystem input is not a Hold") if haskey(disc_clocks, par) @@ -163,6 +216,9 @@ function IndexCache(sys::AbstractSystem) haskey(disc_clocks, sym) || error("Variable $par not part of a discrete subsystem") disc_clocks[par] = disc_clocks[sym] + disc_clocks[ttpar] = disc_clocks[sym] + disc_clocks[rpar] = disc_clocks[sym] + disc_clocks[rttpar] = disc_clocks[sym] insert_by_type!(disc_buffers[disc_clocks[sym]], par) end end @@ -172,13 +228,21 @@ function IndexCache(sys::AbstractSystem) for affect in affs if affect isa Equation is_parameter(sys, affect.lhs) || continue - - disc_clocks[affect.lhs] = user_affect_clock - disc_clocks[default_toterm(affect.lhs)] = user_affect_clock - if hasname(affect.lhs) && - (!iscall(affect.lhs) || operation(affect.lhs) !== getindex) - disc_clocks[getname(affect.lhs)] = user_affect_clock - disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock + sym = affect.lhs + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + + disc_clocks[sym] = user_affect_clock + disc_clocks[ttsym] = user_affect_clock + disc_clocks[rsym] = user_affect_clock + disc_clocks[rttsym] = user_affect_clock + if hasname(sym) && + (!iscall(sym) || operation(sym) !== getindex) + disc_clocks[getname(sym)] = user_affect_clock + disc_clocks[getname(ttsym)] = user_affect_clock + disc_clocks[getname(rsym)] = user_affect_clock + disc_clocks[getname(rttsym)] = user_affect_clock end buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) insert_by_type!(buffer, affect.lhs) @@ -188,11 +252,18 @@ function IndexCache(sys::AbstractSystem) is_parameter(sys, disc) || error("Expected discrete variable $disc in callback to be a parameter") disc = unwrap(disc) + ttdisc = default_toterm(disc) + rdisc = renamespace(sys, disc) + rttdisc = renamespace(sys, ttdisc) disc_clocks[disc] = user_affect_clock - disc_clocks[default_toterm(disc)] = user_affect_clock + disc_clocks[ttdisc] = user_affect_clock + disc_clocks[rdisc] = user_affect_clock + disc_clocks[rttdisc] = user_affect_clock if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex) disc_clocks[getname(disc)] = user_affect_clock - disc_clocks[getname(default_toterm(disc))] = user_affect_clock + disc_clocks[getname(ttdisc)] = user_affect_clock + disc_clocks[getname(rdisc)] = user_affect_clock + disc_clocks[getname(rttdisc)] = user_affect_clock end buffer = get!( disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) @@ -267,13 +338,22 @@ function IndexCache(sys::AbstractSystem) buffer_sizes = BufferTemplate[] for (i, (T, buf)) in enumerate(buffers) for (j, p) in enumerate(buf) + ttp = default_toterm(p) + rp = renamespace(sys, p) + rttp = renamespace(sys, ttp) idxs[p] = (i, j) - idxs[default_toterm(p)] = (i, j) + idxs[ttp] = (i, j) + idxs[rp] = (i, j) + idxs[rttp] = (i, j) if hasname(p) && (!iscall(p) || operation(p) !== getindex) idxs[getname(p)] = (i, j) + idxs[getname(ttp)] = (i, j) + idxs[getname(rp)] = (i, j) + idxs[getname(rttp)] = (i, j) symbol_to_variable[getname(p)] = p - idxs[getname(default_toterm(p))] = (i, j) - symbol_to_variable[getname(default_toterm(p))] = p + symbol_to_variable[getname(ttp)] = p + symbol_to_variable[getname(rp)] = p + symbol_to_variable[getname(rttp)] = p end end push!(buffer_sizes, BufferTemplate(T, length(buf))) @@ -293,6 +373,7 @@ function IndexCache(sys::AbstractSystem) const_idxs, dependent_idxs, nonnumeric_idxs, + observed_syms, disc_buffer_sizes, tunable_buffer_sizes, const_buffer_sizes, @@ -306,6 +387,10 @@ function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) return check_index_map(ic.unknown_idx, sym) !== nothing end +function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym::Symbol) + return check_index_map(ic.unknown_idx, sym) !== nothing +end + function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) return check_index_map(ic.unknown_idx, sym) end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index cd9123f0bf..43ccdb7e56 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -604,16 +604,15 @@ function SciMLBase.create_parameter_timeseries_collection( for (i, partition) in enumerate(ps.discrete) clock = id_to_clock[i] - if clock isa Clock - ts = tspan[1]:(clock.dt):tspan[2] - push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1))) - elseif clock isa SolverStepClock - push!(buffers, + @match clock begin + PeriodicClock(dt, _...) => begin + ts = tspan[1]:(dt):tspan[2] + push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1))) + end + &SolverStepClock => push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], eltype(tspan)[], (1, 1))) - elseif clock isa Continuous - continue - else - error("Unhandled clock $clock") + &Continuous => continue + _ => error("Unhandled clock $clock") end end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index ff26552c79..2cbf820d0d 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -8,6 +8,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, isparameter, isconstant, independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, + InferredTimeDomain, VariableType, getvariabletype, has_equations, ODESystem using ..BipartiteGraphs import ..BipartiteGraphs: invview, complete @@ -331,7 +332,7 @@ function TearingState(sys; quick_cancel = false, check = true) !isdifferential(var) && (it = input_timedomain(var)) !== nothing set_incidence = false var = only(arguments(var)) - var = setmetadata(var, TimeDomain, it) + var = setmetadata(var, VariableTimeDomain, it) @goto ANOTHER_VAR end end @@ -660,7 +661,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals @set! sys.defaults = merge(ModelingToolkit.defaults(sys), Dict(v => 0.0 for v in Iterators.flatten(inputs))) end - ps = [setmetadata(sym, TimeDomain, get(time_domains, sym, Continuous())) + ps = [setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous)) for sym in get_ps(sys)] @set! sys.ps = ps else diff --git a/test/clock.jl b/test/clock.jl index 69b7c30c50..5bf5e917aa 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -14,7 +14,7 @@ dt = 0.1 @parameters kp # u(n + 1) := f(u(n)) -eqs = [yd ~ Sample(t, dt)(y) +eqs = [yd ~ Sample(dt)(y) ud ~ kp * (r - yd) r ~ 1.0 @@ -70,35 +70,35 @@ sss, = ModelingToolkit._structural_simplify!( @test equations(sss) == [D(x) ~ u - x] sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) @test isempty(equations(sss)) -d = Clock(t, dt) +d = Clock(dt) k = ShiftIndex(d) -@test observed(sss) == [yd(k + 1) ~ Sample(t, dt)(y); r(k + 1) ~ 1.0; +@test observed(sss) == [yd(k + 1) ~ Sample(dt)(y); r(k + 1) ~ 1.0; ud(k + 1) ~ kp * (r(k + 1) - yd(k + 1))] -d = Clock(t, dt) +d = Clock(dt) # Note that TearingState reorders the equations -@test eqmap[1] == Continuous() +@test eqmap[1] == Continuous @test eqmap[2] == d @test eqmap[3] == d @test eqmap[4] == d -@test eqmap[5] == Continuous() -@test eqmap[6] == Continuous() +@test eqmap[5] == Continuous +@test eqmap[6] == Continuous @test varmap[yd] == d @test varmap[ud] == d @test varmap[r] == d -@test varmap[x] == Continuous() -@test varmap[y] == Continuous() -@test varmap[u] == Continuous() +@test varmap[x] == Continuous +@test varmap[y] == Continuous +@test varmap[u] == Continuous @info "Testing shift normalization" dt = 0.1 @variables x(t) y(t) u(t) yd(t) ud(t) @parameters kp -d = Clock(t, dt) +d = Clock(dt) k = ShiftIndex(d) -eqs = [yd ~ Sample(t, dt)(y) +eqs = [yd ~ Sample(dt)(y) ud ~ kp * yd + ud(k - 2) # plant (time continuous part) @@ -171,10 +171,10 @@ eqs = [yd ~ Sample(t, dt)(y) eqs = [ # controller (time discrete part `dt=0.1`) - yd1 ~ Sample(t, dt)(y) - ud1 ~ kp * (Sample(t, dt)(r) - yd1) - yd2 ~ Sample(t, dt2)(y) - ud2 ~ kp * (Sample(t, dt2)(r) - yd2) + yd1 ~ Sample(dt)(y) + ud1 ~ kp * (Sample(dt)(r) - yd1) + yd2 ~ Sample(dt2)(y) + ud2 ~ kp * (Sample(dt2)(r) - yd2) # plant (time continuous part) u ~ Hold(ud1) + Hold(ud2) @@ -183,8 +183,8 @@ eqs = [yd ~ Sample(t, dt)(y) @named sys = ODESystem(eqs, t) ci, varmap = infer_clocks(sys) - d = Clock(t, dt) - d2 = Clock(t, dt2) + d = Clock(dt) + d2 = Clock(dt2) #@test get_eq_domain(eqs[1]) == d #@test get_eq_domain(eqs[3]) == d2 @@ -192,15 +192,15 @@ eqs = [yd ~ Sample(t, dt)(y) @test varmap[ud1] == d @test varmap[yd2] == d2 @test varmap[ud2] == d2 - @test varmap[r] == Continuous() - @test varmap[x] == Continuous() - @test varmap[y] == Continuous() - @test varmap[u] == Continuous() + @test varmap[r] == Continuous + @test varmap[x] == Continuous + @test varmap[y] == Continuous + @test varmap[u] == Continuous @info "test composed systems" dt = 0.5 - d = Clock(t, dt) + d = Clock(dt) k = ShiftIndex(d) timevec = 0:0.1:4 @@ -240,16 +240,16 @@ eqs = [yd ~ Sample(t, dt)(y) ci, varmap = infer_clocks(cl) - @test varmap[f.x] == Clock(t, 0.5) - @test varmap[p.x] == Continuous() - @test varmap[p.y] == Continuous() - @test varmap[c.ud] == Clock(t, 0.5) - @test varmap[c.yd] == Clock(t, 0.5) - @test varmap[c.y] == Continuous() - @test varmap[f.y] == Clock(t, 0.5) - @test varmap[f.u] == Clock(t, 0.5) - @test varmap[p.u] == Continuous() - @test varmap[c.r] == Clock(t, 0.5) + @test varmap[f.x] == Clock(0.5) + @test varmap[p.x] == Continuous + @test varmap[p.y] == Continuous + @test varmap[c.ud] == Clock(0.5) + @test varmap[c.yd] == Clock(0.5) + @test varmap[c.y] == Continuous + @test varmap[f.y] == Clock(0.5) + @test varmap[f.u] == Clock(0.5) + @test varmap[p.u] == Continuous + @test varmap[c.r] == Clock(0.5) ## Multiple clock rates @info "Testing multi-rate hybrid system" @@ -260,10 +260,10 @@ eqs = [yd ~ Sample(t, dt)(y) eqs = [ # controller (time discrete part `dt=0.1`) - yd1 ~ Sample(t, dt)(y) + yd1 ~ Sample(dt)(y) ud1 ~ kp * (r - yd1) # controller (time discrete part `dt=0.2`) - yd2 ~ Sample(t, dt2)(y) + yd2 ~ Sample(dt2)(y) ud2 ~ kp * (r - yd2) # plant (time continuous part) @@ -273,8 +273,8 @@ eqs = [yd ~ Sample(t, dt)(y) @named cl = ODESystem(eqs, t) - d = Clock(t, dt) - d2 = Clock(t, dt2) + d = Clock(dt) + d2 = Clock(dt2) ci, varmap = infer_clocks(cl) @test varmap[yd1] == d @@ -331,8 +331,8 @@ eqs = [yd ~ Sample(t, dt)(y) using ModelingToolkitStandardLibrary.Blocks dt = 0.05 - d = Clock(t, dt) - k = ShiftIndex() + d = Clock(dt) + k = ShiftIndex(d) @mtkmodel DiscretePI begin @components begin @@ -362,7 +362,7 @@ eqs = [yd ~ Sample(t, dt)(y) output = RealOutput() end @equations begin - output.u ~ Sample(t, dt)(input.u) + output.u ~ Sample(dt)(input.u) end end @@ -474,7 +474,7 @@ eqs = [yd ~ Sample(t, dt)(y) ## Test continuous clock - c = ModelingToolkit.SolverStepClock(t) + c = ModelingToolkit.SolverStepClock k = ShiftIndex(c) @mtkmodel CounterSys begin diff --git a/test/parameter_dependencies.jl b/test/parameter_dependencies.jl index 242be8f1d7..fc03f53d74 100644 --- a/test/parameter_dependencies.jl +++ b/test/parameter_dependencies.jl @@ -157,10 +157,10 @@ end dt = 0.1 @variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t) @parameters kp kq - d = Clock(t, dt) + d = Clock(dt) k = ShiftIndex(d) - eqs = [yd ~ Sample(t, dt)(y) + eqs = [yd ~ Sample(dt)(y) ud ~ kp * (r - yd) + kq * z r ~ 1.0 u ~ Hold(ud) From 2c56f1c42d2e3fc0253f1259f4998a2cdc7abd49 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 24 Jul 2024 10:59:08 +0530 Subject: [PATCH 12/13] build: bump SciMLBase, RAT, SII compats --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c9c7a47811..b412c502b6 100644 --- a/Project.toml +++ b/Project.toml @@ -100,10 +100,10 @@ NonlinearSolve = "3.12" OrderedCollections = "1" OrdinaryDiffEq = "6.82.0" PrecompileTools = "1" -RecursiveArrayTools = "2.3, 3" +RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" -SciMLBase = "2.28.0" +SciMLBase = "2.46" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" @@ -111,7 +111,7 @@ SimpleNonlinearSolve = "0.1.0, 1" SparseArrays = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" -SymbolicIndexingInterface = "0.3.12" +SymbolicIndexingInterface = "0.3.26" SymbolicUtils = "2.1" Symbolics = "5.32" URIs = "1" From 3a073eca16829d68c3334c920ff4382e3d36350f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 24 Jul 2024 12:31:56 +0530 Subject: [PATCH 13/13] refactor: improve `Symbol` indexing --- src/systems/index_cache.jl | 124 ++++++++++++------------------------- 1 file changed, 40 insertions(+), 84 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index f992cd4907..899bba4aa5 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -21,18 +21,18 @@ end ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false) -const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}} +const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} const UnknownIndexMap = Dict{ - Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}} + BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} struct IndexCache unknown_idx::UnknownIndexMap - discrete_idx::Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int, Int}} + discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}} tunable_idx::ParamIndexMap constant_idx::ParamIndexMap dependent_idx::ParamIndexMap nonnumeric_idx::ParamIndexMap - observed_syms::Set{Union{Symbol, BasicSymbolic}} + observed_syms::Set{BasicSymbolic} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_sizes::Vector{BufferTemplate} constant_buffer_sizes::Vector{BufferTemplate} @@ -57,14 +57,6 @@ function IndexCache(sys::AbstractSystem) end unk_idxs[usym] = sym_idx unk_idxs[rsym] = sym_idx - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - name = getname(usym) - rname = getname(rsym) - unk_idxs[name] = sym_idx - unk_idxs[rname] = sym_idx - symbol_to_variable[name] = sym - symbol_to_variable[rname] = sym - end idx += length(sym) end for sym in unks @@ -80,14 +72,6 @@ function IndexCache(sys::AbstractSystem) rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs unk_idxs[rsym] = idxs - if hasname(arrsym) - name = getname(arrsym) - rname = getname(rsym) - unk_idxs[name] = idxs - unk_idxs[rname] = idxs - symbol_to_variable[name] = arrsym - symbol_to_variable[rname] = arrsym - end end end @@ -102,16 +86,6 @@ function IndexCache(sys::AbstractSystem) push!(observed_syms, ttsym) push!(observed_syms, rsym) push!(observed_syms, rttsym) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - symbol_to_variable[getname(sym)] = eq.lhs - symbol_to_variable[getname(ttsym)] = eq.lhs - symbol_to_variable[getname(rsym)] = eq.lhs - symbol_to_variable[getname(rttsym)] = eq.lhs - push!(observed_syms, getname(sym)) - push!(observed_syms, getname(ttsym)) - push!(observed_syms, getname(rsym)) - push!(observed_syms, getname(rttsym)) - end end end @@ -143,16 +117,12 @@ function IndexCache(sys::AbstractSystem) rttinp = renamespace(sys, ttinp) is_parameter(sys, inp) || error("Discrete subsystem $i input $inp is not a parameter") + disc_clocks[inp] = i disc_clocks[ttinp] = i disc_clocks[rinp] = i disc_clocks[rttinp] = i - if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex) - disc_clocks[getname(inp)] = i - disc_clocks[getname(ttinp)] = i - disc_clocks[getname(rinp)] = i - disc_clocks[getname(rttinp)] = i - end + insert_by_type!(disc_buffers[i], inp) end @@ -163,16 +133,12 @@ function IndexCache(sys::AbstractSystem) rttsym = renamespace(sys, ttsym) is_parameter(sys, sym) || error("Discrete subsystem $i unknown $sym is not a parameter") + disc_clocks[sym] = i disc_clocks[ttsym] = i disc_clocks[rsym] = i disc_clocks[rttsym] = i - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_clocks[getname(sym)] = i - disc_clocks[getname(ttsym)] = i - disc_clocks[getname(rsym)] = i - disc_clocks[getname(rttsym)] = i - end + insert_by_type!(disc_buffers[i], sym) end t = get_iv(sys) @@ -191,12 +157,6 @@ function IndexCache(sys::AbstractSystem) disc_clocks[ttsym] = i disc_clocks[rsym] = i disc_clocks[rttsym] = i - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_clocks[getname(sym)] = i - disc_clocks[getname(ttsym)] = i - disc_clocks[getname(rsym)] = i - disc_clocks[getname(rttsym)] = i - end end end @@ -237,13 +197,7 @@ function IndexCache(sys::AbstractSystem) disc_clocks[ttsym] = user_affect_clock disc_clocks[rsym] = user_affect_clock disc_clocks[rttsym] = user_affect_clock - if hasname(sym) && - (!iscall(sym) || operation(sym) !== getindex) - disc_clocks[getname(sym)] = user_affect_clock - disc_clocks[getname(ttsym)] = user_affect_clock - disc_clocks[getname(rsym)] = user_affect_clock - disc_clocks[getname(rttsym)] = user_affect_clock - end + buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) insert_by_type!(buffer, affect.lhs) else @@ -259,12 +213,7 @@ function IndexCache(sys::AbstractSystem) disc_clocks[ttdisc] = user_affect_clock disc_clocks[rdisc] = user_affect_clock disc_clocks[rttdisc] = user_affect_clock - if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex) - disc_clocks[getname(disc)] = user_affect_clock - disc_clocks[getname(ttdisc)] = user_affect_clock - disc_clocks[getname(rdisc)] = user_affect_clock - disc_clocks[getname(rttdisc)] = user_affect_clock - end + buffer = get!( disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}()) insert_by_type!(buffer, disc) @@ -316,10 +265,6 @@ function IndexCache(sys::AbstractSystem) for (j, sym) in enumerate(buffer[btype]) disc_idxs[sym] = (clockidx, i, j) disc_idxs[default_toterm(sym)] = (clockidx, i, j) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_idxs[getname(sym)] = (clockidx, i, j) - disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j) - end end end end @@ -327,10 +272,6 @@ function IndexCache(sys::AbstractSystem) haskey(disc_idxs, sym) && continue disc_idxs[sym] = (clockid, 0, 0) disc_idxs[default_toterm(sym)] = (clockid, 0, 0) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - disc_idxs[getname(sym)] = (clockid, 0, 0) - disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0) - end end function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}}) @@ -345,16 +286,6 @@ function IndexCache(sys::AbstractSystem) idxs[ttp] = (i, j) idxs[rp] = (i, j) idxs[rttp] = (i, j) - if hasname(p) && (!iscall(p) || operation(p) !== getindex) - idxs[getname(p)] = (i, j) - idxs[getname(ttp)] = (i, j) - idxs[getname(rp)] = (i, j) - idxs[getname(rttp)] = (i, j) - symbol_to_variable[getname(p)] = p - symbol_to_variable[getname(ttp)] = p - symbol_to_variable[getname(rp)] = p - symbol_to_variable[getname(rttp)] = p - end end push!(buffer_sizes, BufferTemplate(T, length(buf))) end @@ -366,6 +297,14 @@ function IndexCache(sys::AbstractSystem) dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers) nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers) + for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), + keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs), + observed_syms, independent_variable_symbols(sys))) + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = sym + end + end + return IndexCache( unk_idxs, disc_idxs, @@ -384,18 +323,26 @@ function IndexCache(sys::AbstractSystem) end function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) - return check_index_map(ic.unknown_idx, sym) !== nothing -end - -function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym::Symbol) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.unknown_idx, sym) !== nothing end function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + end return check_index_map(ic.unknown_idx, sym) end function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.tunable_idx, sym) !== nothing || check_index_map(ic.discrete_idx, sym) !== nothing || check_index_map(ic.constant_idx, sym) !== nothing || @@ -405,7 +352,8 @@ end function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) if sym isa Symbol - sym = ic.symbol_to_variable[sym] + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing end validate_size = Symbolics.isarraysymbolic(sym) && Symbolics.shape(sym) !== Symbolics.Unknown() @@ -425,10 +373,18 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) end function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return false + end return check_index_map(ic.discrete_idx, sym) !== nothing end function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym) + if sym isa Symbol + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + end idx = check_index_map(ic.discrete_idx, sym) idx === nothing && return nothing clockid, partitionid... = idx