Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add FMUComponent #3282

Draft
wants to merge 41 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
85cebf1
feat: add `FMUComponent` with support for v2 ME FMUs
AayushSabharwal Dec 20, 2024
2409b35
test: add tests for FMIComponent
AayushSabharwal Dec 20, 2024
5a49b03
fix: fix `getcalledparameter` namespacing issues
AayushSabharwal Dec 23, 2024
02e0ff0
fix: handle scalarized called parameter in `vars!`
AayushSabharwal Dec 23, 2024
11b79c7
fix: run array variables hack on equations added by CSE hack
AayushSabharwal Dec 23, 2024
7f4cd27
fix: only use `OffsetArray` in array hack for non-standard `firstindex`
AayushSabharwal Dec 23, 2024
86a0bcd
fix: handle usage of FMU in initialization
AayushSabharwal Dec 23, 2024
f50893d
refactor: modularize code and enable array hacks
AayushSabharwal Dec 23, 2024
a98eec5
test: add test for component FMU hooked up to MTK model, initialization
AayushSabharwal Dec 23, 2024
403b92d
test: add `SimpleAdder.fmu` for testing
AayushSabharwal Dec 23, 2024
3d7fc9c
feat: allow `ImperativeAffect` to accept callable structs
AayushSabharwal Dec 24, 2024
edcbaeb
feat: add `vars!` for callbacks and affects
AayushSabharwal Dec 24, 2024
ae35c77
fix: fix `add_fallbacks!`
AayushSabharwal Dec 24, 2024
aca5bdb
fix: fix removal of `missing` defaults for solved array parameters
AayushSabharwal Dec 24, 2024
97bc830
fix: avoid trying to scalarize `missing` default of array parameters
AayushSabharwal Dec 24, 2024
9800e5d
fix: fix observed timeseries detection
AayushSabharwal Dec 24, 2024
ac6ae98
fix: fix conversion of solved array parameters to variables in initia…
AayushSabharwal Dec 24, 2024
0cb7d9c
fix: make array hack also search callbacks
AayushSabharwal Dec 24, 2024
f06d600
feat: support building v2 Co-Simulation FMU components
AayushSabharwal Dec 24, 2024
3d68f77
test: test v2 Co-Simulation FMU components
AayushSabharwal Dec 24, 2024
7d2647e
feat: support v3 ME FMUs
AayushSabharwal Jan 2, 2025
a5f0985
feat: support v3 CS FMUs
AayushSabharwal Jan 2, 2025
90eead1
fix: use `NoInit` for FMI callback reinitialization
AayushSabharwal Jan 3, 2025
2f15702
fix: handle array of symbolics in `InitializationProblem` type promotion
AayushSabharwal Jan 3, 2025
531f78e
test: test v3 FMUs as subcomponents
AayushSabharwal Jan 3, 2025
078d429
test: update `SimpleAdder.fmu`
AayushSabharwal Jan 3, 2025
3bea3aa
refactor: remove redundant fields from structs
AayushSabharwal Jan 3, 2025
d1d3839
docs: add documentation for all functions in FMIExt
AayushSabharwal Jan 3, 2025
2337cbb
build: bump SymbolicIndexingInterface compat
AayushSabharwal Jan 7, 2025
15a467e
feat: mark inputs and outputs of FMU with appropriate metadata
AayushSabharwal Jan 7, 2025
6cf8d98
fix: HACK: handle incorrect result from `linear_expansion` when using…
AayushSabharwal Jan 8, 2025
52457be
fix: fix array hack when looking through unknowns
AayushSabharwal Jan 8, 2025
77174da
fix: fix several bugs causing incorrect results in FMU components
AayushSabharwal Jan 8, 2025
640cdb2
test: update `SimpleAdder.fmu`
AayushSabharwal Jan 8, 2025
0040f08
test: update `StateSpace.fmu`
AayushSabharwal Jan 8, 2025
860f3c3
test: add more comprehensive tests for FMIComponent
AayushSabharwal Jan 8, 2025
08e93c7
test: add instructions for reproducibly building `SimpleAdder.fmu`
AayushSabharwal Jan 8, 2025
7c15a45
test: add instructions for reproducibly building StateSpace.fmu
AayushSabharwal Jan 8, 2025
d3c9d36
build: bump SymbolicUtils compat
AayushSabharwal Jan 16, 2025
dc3e2dc
fix: fix type promotion in `InitializationProblem`
AayushSabharwal Jan 20, 2025
2e14e0b
test: fix CSE hack test
AayushSabharwal Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Expand All @@ -72,6 +73,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKFMIExt = "FMI"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"
Expand Down Expand Up @@ -106,6 +108,7 @@ FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappers = "1.1"
FunctionWrappersWrappers = "0.1"
FMI = "0.14"
Graphs = "1.5.2"
HomotopyContinuation = "2.11"
InfiniteOpt = "0.5"
Expand Down Expand Up @@ -142,9 +145,9 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.10"
Symbolics = "6.22.1"
SymbolicIndexingInterface = "0.3.37"
SymbolicUtils = "3.10.1"
Symbolics = "6.23"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand Down
930 changes: 930 additions & 0 deletions ext/MTKFMIExt.jl

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,6 @@ export HomotopyContinuationProblem
export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
get_looptransfer_function, get_sensitivity, get_comp_sensitivity, get_looptransfer,
open_loop
function FMIComponent end

end # module
7 changes: 6 additions & 1 deletion src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ end

function getcalledparameter(x)
x = unwrap(x)
return getmetadata(x, CallWithParent)
# `parent` is a `CallWithMetadata` with the correct metadata,
# but no namespacing. `operation(x)` has the correct namespacing,
# but is not a `CallWithMetadata` and doesn't have any metadata.
# This approach combines both.
parent = getmetadata(x, CallWithParent)
return CallWithMetadata(operation(x), metadata(parent))
end

"""
Expand Down
19 changes: 13 additions & 6 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
@set! sys.unknowns = unknowns

obs, subeqs, deps = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
sys, obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs
Expand Down Expand Up @@ -627,7 +627,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
not) we first count the number of times the scalarized form of each observed variable
occurs in observed equations (and unknowns if it's split).
"""
function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true)
function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true)
# HACK 1
# mapping of rhs to temporary CSE variable
# `f(...) => tmpvar` in above example
Expand Down Expand Up @@ -696,6 +696,7 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
tempvar; T = Symbolics.symtype(rhs_arr)))
tempvar = setmetadata(
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
vars!(all_vars, rhs_arr)
tempeq = tempvar ~ rhs_arr
rhs_to_tempvar[rhs_arr] = tempvar
push!(obs, tempeq)
Expand All @@ -718,12 +719,16 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
Symbolics.shape(sym) != Symbolics.Unknown() || continue
arg1 = arguments(sym)[1]
cnt = get(arr_obs_occurrences, arg1, 0)
cnt == 0 && continue
arr_obs_occurrences[arg1] = cnt + 1
end
for eq in neweqs
vars!(all_vars, eq.rhs)
end

# also count unscalarized variables used in callbacks
for ev in Iterators.flatten((continuous_events(sys), discrete_events(sys)))
vars!(all_vars, ev)
end
obs_arr_eqs = Equation[]
for (arrvar, cnt) in arr_obs_occurrences
cnt == length(arrvar) || continue
Expand All @@ -737,7 +742,9 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
# try to `create_array(OffsetArray{...}, ...)` which errors.
# `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
# of `scal`.
push!(obs_arr_eqs, arrvar ~ change_origin(Origin(firstind), scal))
rhs = scal
rhs = change_origin(firstind, rhs)
push!(obs_arr_eqs, arrvar ~ rhs)
end
append!(obs, obs_arr_eqs)
append!(subeqs, obs_arr_eqs)
Expand All @@ -764,10 +771,10 @@ getindex_wrapper(x, i) = x[i...]

# PART OF HACK 2
function change_origin(origin, arr)
return origin(arr)
return Origin(origin)(arr)
end

@register_array_symbolic change_origin(origin::Origin, arr::AbstractArray) begin
@register_array_symbolic change_origin(origin::Any, arr::AbstractArray) begin
size = size(arr)
eltype = eltype(arr)
ndims = ndims(arr)
Expand Down
7 changes: 7 additions & 0 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,18 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
all_int_vars = true
coeffs === nothing || empty!(coeffs)
empty!(to_rm)

vars_buffer = Set()
for j in 𝑠neighbors(graph, ieq)
var = fullvars[j]
isirreducible(var) && (all_int_vars = false; continue)
a, b, islinear = linear_expansion(term, var)
a, b = unwrap(a), unwrap(b)
vars!(vars_buffer, b)
if islinear && isequal(a, 0) && var in vars_buffer
islinear = false
end
empty!(vars_buffer)
islinear || (all_int_vars = false; continue)
a = ModelingToolkit.fold_constants(a)
b = ModelingToolkit.fold_constants(b)
Expand Down
45 changes: 45 additions & 0 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ function has_functional_affect(cb)
(affects(cb) isa FunctionalAffect || affects(cb) isa ImperativeAffect)
end

function vars!(vars, aff::FunctionalAffect; op = Differential)
for var in Iterators.flatten((unknowns(aff), parameters(aff), discretes(aff)))
vars!(vars, var)
end
return vars
end

#################################### continuous events #####################################

const NULL_AFFECT = Equation[]
Expand Down Expand Up @@ -333,6 +340,22 @@ function continuous_events(sys::AbstractSystem)
filter(!isempty, cbs)
end

function vars!(vars, cb::SymbolicContinuousCallback; op = Differential)
for eq in equations(cb)
vars!(vars, eq; op)
end
for aff in (affects(cb), affect_negs(cb), initialize_affects(cb), finalize_affects(cb))
if aff isa Vector{Equation}
for eq in aff
vars!(vars, eq; op)
end
elseif aff !== nothing
vars!(vars, aff; op)
end
end
return vars
end

#################################### discrete events #####################################

struct SymbolicDiscreteCallback
Expand Down Expand Up @@ -469,6 +492,28 @@ function discrete_events(sys::AbstractSystem)
cbs
end

function vars!(vars, cb::SymbolicDiscreteCallback; op = Differential)
if symbolic_type(cb.condition) == NotSymbolic
if cb.condition isa AbstractArray
for eq in cb.condition
vars!(vars, eq; op)
end
end
else
vars!(vars, cb.condition; op)
end
for aff in (cb.affects, cb.initialize, cb.finalize)
if aff isa Vector{Equation}
for eq in aff
vars!(vars, eq; op)
end
elseif aff !== nothing
vars!(vars, aff; op)
end
end
return vars
end

################################# compilation functions ####################################

# handles ensuring that affect! functions work with integrator arguments
Expand Down
14 changes: 8 additions & 6 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1391,14 +1391,16 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
fullmap = merge(u0map, parammap)
u0T = Union{}
for sym in unknowns(isys)
haskey(fullmap, sym) || continue
symbolic_type(fullmap[sym]) == NotSymbolic() || continue
u0T = promote_type(u0T, typeof(fullmap[sym]))
val = fixpoint_sub(sym, fullmap)
symbolic_type(val) == NotSymbolic() || continue
u0T = promote_type(u0T, typeof(val))
end
for eq in observed(isys)
haskey(fullmap, eq.lhs) || continue
symbolic_type(fullmap[eq.lhs]) == NotSymbolic() || continue
u0T = promote_type(u0T, typeof(fullmap[eq.lhs]))
# ignore HACK-ed observed equations
symbolic_type(eq.lhs) == ArraySymbolic() && continue
val = fixpoint_sub(eq.lhs, fullmap)
symbolic_type(val) == NotSymbolic() || continue
u0T = promote_type(u0T, typeof(val))
end
if u0T != Union{}
u0T = eltype(u0T)
Expand Down
25 changes: 21 additions & 4 deletions src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ in the returned tuple, in which case the associated field will not be updated.
skip_checks::Bool
end

function ImperativeAffect(f::Function;
function ImperativeAffect(f;
observed::NamedTuple = NamedTuple{()}(()),
modified::NamedTuple = NamedTuple{()}(()),
ctx = nothing,
Expand All @@ -48,18 +48,18 @@ function ImperativeAffect(f::Function;
collect(values(modified)), collect(keys(modified)),
ctx, skip_checks)
end
function ImperativeAffect(f::Function, modified::NamedTuple;
function ImperativeAffect(f, modified::NamedTuple;
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false)
ImperativeAffect(
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
end
function ImperativeAffect(
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false)
f, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false)
ImperativeAffect(
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
end
function ImperativeAffect(
f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false)
f, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false)
ImperativeAffect(
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
end
Expand Down Expand Up @@ -216,3 +216,20 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
end

scalarize_affects(affects::ImperativeAffect) = affects

function vars!(vars, aff::ImperativeAffect; op = Differential)
for var in Iterators.flatten((observed(aff), modified(aff)))
if symbolic_type(var) == NotSymbolic()
if var isa AbstractArray
for v in var
v = unwrap(v)
vars!(vars, v)
end
end
else
var = unwrap(var)
vars!(vars, var)
end
end
return vars
end
3 changes: 3 additions & 0 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ function IndexCache(sys::AbstractSystem)
for v in vs
if (idx = get(disc_idxs, v, nothing)) !== nothing
push!(timeseries, idx.clock_idx)
elseif iscall(v) && operation(v) === getindex &&
(idx = get(disc_idxs, arguments(v)[1], nothing)) !== nothing
push!(timeseries, idx.clock_idx)
elseif haskey(observed_syms_to_timeseries, v)
union!(timeseries, observed_syms_to_timeseries[v])
elseif haskey(dependent_pars_to_timeseries, v)
Expand Down
9 changes: 9 additions & 0 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ function generate_initializesystem(sys::AbstractSystem;
append!(eqs_ics, trueobs)
end

# even if `p => tovar(p)` is in `paramsubs`, `isparameter(p[1]) === true` after substitution
# so add scalarized versions as well
for k in collect(keys(paramsubs))
symbolic_type(k) == ArraySymbolic() || continue
for i in eachindex(k)
paramsubs[k[i]] = paramsubs[k][i]
end
end

eqs_ics = Symbolics.substitute.(eqs_ics, (paramsubs,))
if is_time_dependent(sys)
vars = [vars; collect(values(paramsubs))]
Expand Down
21 changes: 19 additions & 2 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ function symbols_to_symbolics!(sys::AbstractSystem, varmap::AbstractDict)
end
end

"""
$(TYPEDSIGNATURES)

Utility function to get the value `val` corresponding to key `var` in `varmap`, and
return `getindex(val, idx)` if it exists or `nothing` otherwise.
"""
function get_and_getindex(varmap, var, idx)
val = get(varmap, var, nothing)
val === nothing && return nothing
return val[idx]
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -115,8 +127,9 @@ function add_fallbacks!(
val = map(eachindex(var)) do idx
# @something is lazy and saves from writing a massive if-elseif-else
@something(get(varmap, var[idx], nothing),
get(varmap, ttvar[idx], nothing), get(fallbacks, var, nothing)[idx],
get(fallbacks, ttvar, nothing)[idx], get(fallbacks, var[idx], nothing),
get(varmap, ttvar[idx], nothing), get_and_getindex(fallbacks, var, idx),
get_and_getindex(fallbacks, ttvar, idx), get(
fallbacks, var[idx], nothing),
get(fallbacks, ttvar[idx], nothing), Some(nothing))
end
# only push the missing entries
Expand Down Expand Up @@ -578,6 +591,10 @@ function maybe_build_initialization_problem(
p = unwrap(p)
stype = symtype(p)
op[p] = get_temporary_value(p)
if iscall(p) && operation(p) === getindex
arrp = arguments(p)[1]
op[arrp] = collect(arrp)
end
end

if is_time_dependent(sys)
Expand Down
1 change: 1 addition & 0 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ function structural_simplify(
ks = collect(keys(defs)) # take copy to avoid mutating defs while iterating.
for k in ks
if Symbolics.isarraysymbolic(k) && Symbolics.shape(k) !== Symbolics.Unknown()
defs[k] === missing && continue
for i in eachindex(k)
defs[k[i]] = defs[k][i]
end
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,9 @@ function vars!(vars, eq::Equation; op = Differential)
end
function vars!(vars, O; op = Differential)
if isvariable(O)
if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O)))
O = first(arguments(O))
end
if iscalledparameter(O)
f = getcalledparameter(O)
push!(vars, f)
Expand Down
2 changes: 2 additions & 0 deletions test/extensions/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
FMIZoo = "724179cf-c260-40a9-bd27-cccc6fe2f195"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
Expand Down
Loading
Loading