Skip to content

Commit 3503046

Browse files
Merge pull request #3717 from AayushSabharwal/as/linear-problem
feat: add `LinearProblem` codegen
2 parents e8a3aae + 8a299e0 commit 3503046

File tree

13 files changed

+481
-22
lines changed

13 files changed

+481
-22
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ LabelledArrays = "1.3"
129129
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16"
130130
Libdl = "1"
131131
LinearAlgebra = "1"
132+
LinearSolve = "3"
132133
Logging = "1"
133134
MLStyle = "0.4.17"
134135
ModelingToolkitStandardLibrary = "2.20"
@@ -148,7 +149,7 @@ RecursiveArrayTools = "3.26"
148149
Reexport = "0.2, 1"
149150
RuntimeGeneratedFunctions = "0.5.9"
150151
SCCNonlinearSolve = "1.0.0"
151-
SciMLBase = "2.91.1"
152+
SciMLBase = "2.100.0"
152153
SciMLPublic = "1.0.0"
153154
SciMLStructures = "1.7"
154155
Serialization = "1"
@@ -180,6 +181,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
180181
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
181182
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
182183
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
184+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
183185
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
184186
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
185187
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
@@ -205,4 +207,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
205207
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
206208

207209
[targets]
208-
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase"]
210+
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve"]

docs/src/API/codegen.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ ModelingToolkit.generate_constraint_hessian
2121
ModelingToolkit.generate_control_jacobian
2222
ModelingToolkit.build_explicit_observed_function
2323
ModelingToolkit.generate_control_function
24+
ModelingToolkit.generate_update_A
25+
ModelingToolkit.generate_update_b
2426
```
2527

2628
For functions such as jacobian calculation which require symbolic computation, there
@@ -42,6 +44,7 @@ ModelingToolkit.cost_hessian_sparsity
4244
ModelingToolkit.calculate_constraint_jacobian
4345
ModelingToolkit.calculate_constraint_hessian
4446
ModelingToolkit.calculate_control_jacobian
47+
ModelingToolkit.calculate_A_b
4548
```
4649

4750
All code generation eventually calls `build_function_wrapper`.

docs/src/API/problems.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ SciMLBase.DiscreteProblem
2929
SciMLBase.ImplicitDiscreteProblem
3030
```
3131

32-
## Nonlinear systems
32+
## Linear and Nonlinear systems
3333

3434
```@docs
3535
SciMLBase.NonlinearFunction
@@ -41,6 +41,7 @@ SciMLBase.IntervalNonlinearFunction
4141
SciMLBase.IntervalNonlinearProblem
4242
ModelingToolkit.HomotopyContinuationProblem
4343
SciMLBase.HomotopyNonlinearFunction
44+
SciMLBase.LinearProblem
4445
```
4546

4647
## Optimization and optimal control

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ include("problems/jumpproblem.jl")
188188
include("problems/initializationproblem.jl")
189189
include("problems/sccnonlinearproblem.jl")
190190
include("problems/bvproblem.jl")
191+
include("problems/linearproblem.jl")
191192

192193
include("modelingtoolkitize/common.jl")
193194
include("modelingtoolkitize/odeproblem.jl")

src/problems/compatibility.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,12 @@ function check_no_equations(sys::System, T)
169169
"""))
170170
end
171171
end
172+
173+
function check_affine(sys::System, T)
174+
if !isaffine(sys)
175+
throw(SystemCompatibilityError("""
176+
A non-affine system cannot be used to construct a `$T`. Consider a
177+
`NonlinearProblem` instead.
178+
"""))
179+
end
180+
end

src/problems/docs.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,32 @@ $PROBLEM_INTERNALS_HEADER
391391
392392
$PROBLEM_INTERNAL_KWARGS
393393
""" SciMLBase.IntervalNonlinearProblem
394+
395+
@doc """
396+
SciMLBase.LinearProblem(sys::System, op; kwargs...)
397+
SciMLBase.LinearProblem{iip}(sys::System, op; kwargs...)
398+
399+
Build a `LinearProblem` given a system `sys` and operating point `op`. `iip` is a boolean
400+
indicating whether the problem should be in-place. The operating point should be an
401+
iterable collection of key-value pairs mapping variables/parameters in the system to the
402+
(initial) values they should take in `LinearProblem`. Any values not provided will
403+
fallback to the corresponding default (if present).
404+
405+
Note that since `u0` is optional for `LinearProblem`, values of unknowns do not need to be
406+
specified in `op` to create a `LinearProblem`. In such a case, `prob.u0` will be `nothing`
407+
and attempting to symbolically index the problem with an unknown, observable, or expression
408+
depending on unknowns/observables will error.
409+
410+
Updating the parameters automatically updates the `A` and `b` arrays.
411+
412+
# Keyword arguments
413+
414+
$PROBLEM_KWARGS
415+
$(prob_fun_common_kwargs(LinearProblem, false))
416+
417+
All other keyword arguments are forwarded to the $func constructor.
418+
419+
$PROBLEM_INTERNALS_HEADER
420+
421+
$PROBLEM_INTERNAL_KWARGS
422+
""" SciMLBase.LinearProblem

src/problems/linearproblem.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
function SciMLBase.LinearProblem(sys::System, op; kwargs...)
2+
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
3+
end
4+
5+
function SciMLBase.LinearProblem(sys::System, op::StaticArray; kwargs...)
6+
SciMLBase.LinearProblem{false}(sys, op; kwargs...)
7+
end
8+
9+
function SciMLBase.LinearProblem{iip}(
10+
sys::System, op; check_length = true, expression = Val{false},
11+
check_compatibility = true, sparse = false, eval_expression = false,
12+
eval_module = @__MODULE__, checkbounds = false, cse = true,
13+
u0_constructor = identity, u0_eltype = nothing, kwargs...) where {iip}
14+
check_complete(sys, LinearProblem)
15+
check_compatibility && check_compatible_system(LinearProblem, sys)
16+
17+
_, u0, p = process_SciMLProblem(
18+
EmptySciMLFunction{iip}, sys, op; check_length, expression,
19+
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
20+
kwargs...)
21+
22+
if any(x -> symbolic_type(x) != NotSymbolic(), u0)
23+
u0 = nothing
24+
end
25+
26+
u0Type = typeof(op)
27+
floatT = if u0 === nothing
28+
calculate_float_type(op, u0Type)
29+
else
30+
eltype(u0)
31+
end
32+
u0_eltype = something(u0_eltype, floatT)
33+
34+
u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)
35+
36+
A, b = calculate_A_b(sys; sparse)
37+
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
38+
eval_module, checkbounds, cse, kwargs...)
39+
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
40+
eval_module, checkbounds, cse, kwargs...)
41+
observedfun = ObservedFunctionCache(
42+
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
43+
cse)
44+
45+
if expression == Val{true}
46+
symbolic_interface = quote
47+
update_A = $update_A
48+
update_b = $update_b
49+
sys = $sys
50+
observedfun = $observedfun
51+
$(SciMLBase.SymbolicLinearInterface)(
52+
update_A, update_b, sys, observedfun, nothing)
53+
end
54+
get_A = build_explicit_observed_function(
55+
sys, A; param_only = true, eval_expression, eval_module)
56+
if sparse
57+
get_A = SparseArrays.sparse get_A
58+
end
59+
get_b = build_explicit_observed_function(
60+
sys, b; param_only = true, eval_expression, eval_module)
61+
A = u0_constructor(get_A(p))
62+
b = u0_constructor(get_b(p))
63+
else
64+
symbolic_interface = SciMLBase.SymbolicLinearInterface(
65+
update_A, update_b, sys, observedfun, nothing)
66+
A = u0_constructor(update_A(p))
67+
b = u0_constructor(update_b(p))
68+
end
69+
70+
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
71+
args = (; A, b, p)
72+
73+
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
74+
end
75+
76+
# For remake
77+
function SciMLBase.get_new_A_b(
78+
sys::AbstractSystem, f::SciMLBase.SymbolicLinearInterface, p, A, b; kw...)
79+
if ArrayInterface.ismutable(A)
80+
f.update_A!(A, p)
81+
f.update_b!(b, p)
82+
else
83+
# The generated function has both IIP and OOP variants
84+
A = StaticArraysCore.similar_type(A)(f.update_A!(p))
85+
b = StaticArraysCore.similar_type(b)(f.update_b!(p))
86+
end
87+
return A, b
88+
end
89+
90+
function check_compatible_system(T::Type{LinearProblem}, sys::System)
91+
check_time_independent(sys, T)
92+
check_affine(sys, T)
93+
check_not_dde(sys)
94+
check_no_cost(sys, T)
95+
check_no_constraints(sys, T)
96+
check_no_jumps(sys, T)
97+
check_no_noise(sys, T)
98+
end

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,13 +1780,13 @@ function preface(sys::AbstractSystem)
17801780
end
17811781

17821782
function islinear(sys::AbstractSystem)
1783-
rhs = [eq.rhs for eq in equations(sys)]
1783+
rhs = [eq.rhs for eq in full_equations(sys)]
17841784

17851785
all(islinear(r, unknowns(sys)) for r in rhs)
17861786
end
17871787

17881788
function isaffine(sys::AbstractSystem)
1789-
rhs = [eq.rhs for eq in equations(sys)]
1789+
rhs = [eq.rhs for eq in full_equations(sys)]
17901790

17911791
all(isaffine(r, unknowns(sys)) for r in rhs)
17921792
end

src/systems/codegen.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,3 +1130,90 @@ function build_explicit_observed_function(sys, ts;
11301130
return f
11311131
end
11321132
end
1133+
1134+
"""
1135+
$(TYPEDSIGNATURES)
1136+
1137+
Return matrix `A` and vector `b` such that the system `sys` can be represented as
1138+
`A * x = b` where `x` is `unknowns(sys)`. Errors if the system is not affine.
1139+
1140+
# Keyword arguments
1141+
1142+
- `sparse`: return a sparse `A`.
1143+
"""
1144+
function calculate_A_b(sys::System; sparse = false)
1145+
rhss = [eq.rhs for eq in full_equations(sys)]
1146+
dvs = unknowns(sys)
1147+
1148+
A = Matrix{Any}(undef, length(rhss), length(dvs))
1149+
b = Vector{Any}(undef, length(rhss))
1150+
for (i, rhs) in enumerate(rhss)
1151+
# mtkcompile makes this `0 ~ rhs` which typically ends up giving
1152+
# unknowns negative coefficients. If given the equations `A * x ~ b`
1153+
# it will simplify to `0 ~ b - A * x`. Thus this negation usually leads
1154+
# to more comprehensible user API.
1155+
resid = -rhs
1156+
for (j, var) in enumerate(dvs)
1157+
p, q, islinear = Symbolics.linear_expansion(resid, var)
1158+
if !islinear
1159+
throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var."))
1160+
end
1161+
A[i, j] = p
1162+
resid = q
1163+
end
1164+
# negate beucause `resid` is the residual on the LHS
1165+
b[i] = -resid
1166+
end
1167+
1168+
@assert all(Base.Fix1(isassigned, A), eachindex(A))
1169+
@assert all(Base.Fix1(isassigned, A), eachindex(b))
1170+
1171+
if sparse
1172+
A = SparseArrays.sparse(A)
1173+
end
1174+
return A, b
1175+
end
1176+
1177+
"""
1178+
$(TYPEDSIGNATURES)
1179+
1180+
Given a system `sys` and the `A` from [`calculate_A_b`](@ref) generate the function that
1181+
updates `A` given the parameter object.
1182+
1183+
# Keyword arguments
1184+
1185+
$GENERATE_X_KWARGS
1186+
1187+
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
1188+
"""
1189+
function generate_update_A(sys::System, A::AbstractMatrix; expression = Val{true},
1190+
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
1191+
ps = reorder_parameters(sys)
1192+
1193+
res = build_function_wrapper(sys, A, ps...; p_start = 1, expression = Val{true},
1194+
similarto = typeof(A), kwargs...)
1195+
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
1196+
eval_expression, eval_module)
1197+
end
1198+
1199+
"""
1200+
$(TYPEDSIGNATURES)
1201+
1202+
Given a system `sys` and the `b` from [`calculate_A_b`](@ref) generate the function that
1203+
updates `b` given the parameter object.
1204+
1205+
# Keyword arguments
1206+
1207+
$GENERATE_X_KWARGS
1208+
1209+
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
1210+
"""
1211+
function generate_update_b(sys::System, b::AbstractVector; expression = Val{true},
1212+
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
1213+
ps = reorder_parameters(sys)
1214+
1215+
res = build_function_wrapper(sys, b, ps...; p_start = 1, expression = Val{true},
1216+
similarto = typeof(b), kwargs...)
1217+
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
1218+
eval_expression, eval_module)
1219+
end

src/systems/nonlinear/initializesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,7 @@ function SciMLBase.late_binding_update_u0_p(
727727
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
728728
supports_initialization(sys) || return newu0, newp
729729
prob isa IntervalNonlinearProblem && return newu0, newp
730+
prob isa LinearProblem && return newu0, newp
730731

731732
initdata = prob.f.initialization_data
732733
meta = initdata === nothing ? nothing : initdata.metadata

0 commit comments

Comments
 (0)