Solving Optimal Control Problems with Symbolic Universal Differential Equations
This tutorial is based on a SciMLSensitivity.jl tutorial. Instead of using a classical NN architecture, here we combine the NN with a symbolic expression from DynamicExpressions.jl (the symbolic engine behind SymbolicRegression.jl and PySR).
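Here we will solve a classic optimal control problem with a universal differential equation. Let
$$x''(t) = u^3(t)$$
where we want to optimize our controller $u(t)$ such that the following loss is minimized:
$$\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N} \left[ \big(4 - x(t_i)\big)^2 + 2\,x'(t_i)^2 + u(t_i)^2 \right]$$
where the $t_i$ are the $N$ time points spaced $0.01$ apart on $[0, 8]$. Rewriting the ODE in first-order form,
$$x' = v, \qquad v' = u^3(t),$$
and thus
$$\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N} \left[ \big(4 - x(t_i)\big)^2 + 2\,v(t_i)^2 + u(t_i)^2 \right]$$
is our loss function on the first-order system. We thus choose a neural network form for $u(t)$ and optimize this loss with respect to the network parameters $\theta$.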
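To make the loss concrete, here is a minimal, package-light sketch that integrates the first-order system x' = v, v' = u(t)^3 with a fixed-step fourth-order Runge-Kutta scheme and evaluates the loss used in the code (the mean over time points of (4 - x)^2 + 2v^2 + u^2). The helpers rk4_trajectory and control_loss and the zero control u_guess are hypothetical illustrations; the tutorial itself uses Vern9 and a neural-network control.

```julia
using Statistics

# Right-hand side of the first-order system: x' = v, v' = u(t)^3.
rhs(state, t, control) = [state[2], control(t)^3]

# Fixed-step RK4 integration (hypothetical helper; the tutorial uses Vern9).
function rk4_trajectory(control, x0, ts)
    states = zeros(2, length(ts))
    states[:, 1] = x0
    for k in 1:(length(ts) - 1)
        h = ts[k + 1] - ts[k]
        s, t = states[:, k], ts[k]
        k1 = rhs(s, t, control)
        k2 = rhs(s .+ h / 2 .* k1, t + h / 2, control)
        k3 = rhs(s .+ h / 2 .* k2, t + h / 2, control)
        k4 = rhs(s .+ h .* k3, t + h, control)
        states[:, k + 1] = s .+ h / 6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return states
end

# Mean-squared loss with the same weights as the second training stage.
function control_loss(states, us)
    x, v = states[1, :], states[2, :]
    return mean(abs2, 4 .- x) + 2 * mean(abs2, v) + mean(abs2, us)
end

ts = collect(0.0:0.01:8.0)
u_guess(t) = 0.0  # hypothetical "do nothing" control
states = rk4_trajectory(u_guess, [-4.0, 0.0], ts)
loss = control_loss(states, u_guess.(ts))
```

With the zero control the state never leaves x = -4, so every term except the first vanishes and the loss is (4 - (-4))^2 = 64. A trained controller has to trade that position error against the velocity and control penalties.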
Package Imports
using Lux,
Boltz,
ComponentArrays,
OrdinaryDiffEqVerner,
Optimization,
OptimizationOptimJL,
OptimizationOptimisers,
SciMLSensitivity,
Statistics,
Printf,
Random
using DynamicExpressions, SymbolicRegression, MLJ, SymbolicUtils, Latexify
using CairoMakie
Helper Functions
function plot_dynamics(sol, us, ts)
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel=L"t")
ylims!(ax, (-6, 6))
lines!(ax, ts, sol[1, :]; label=L"u_1(t)", linewidth=3)
lines!(ax, ts, sol[2, :]; label=L"u_2(t)", linewidth=3)
lines!(ax, ts, vec(us); label=L"u(t)", linewidth=3)
axislegend(ax; position=:rb)
return fig
end
plot_dynamics (generic function with 1 method)
Training a Neural Network based UDE
Let's set up the neural network. For the first part, we won't do any symbolic regression; we will simply train a neural network to solve the optimal control problem.
rng = Xoshiro(0)
tspan = (0.0, 8.0)
mlp = Chain(Dense(1 => 4, gelu), Dense(4 => 4, gelu), Dense(4 => 1))
function construct_ude(mlp, solver; kwargs...)
return @compact(; mlp, solver, kwargs...) do x_in, ps
x, ts, ret_sol = x_in
function dudt(du, u, p, t)
u₁, u₂ = u
du[1] = u₂
du[2] = mlp([t], p)[1]^3
return nothing
end
prob = ODEProblem{true}(dudt, x, extrema(ts), ps.mlp)
sol = solve(
prob,
solver;
saveat=ts,
sensealg=QuadratureAdjoint(; autojacvec=ReverseDiffVJP(true)),
kwargs...,
)
us = mlp(reshape(ts, 1, :), ps.mlp)
ret_sol === Val(true) && @return sol, us
@return Array(sol), us
end
end
ude = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);
Here we use the same configuration for testing, but this shows that the training and testing models can be set up with different ODE solver configurations.
ude_test = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);
function train_model_1(ude, rng, ts_)
ps, st = Lux.setup(rng, ude)
ps = ComponentArray{Float64}(ps)
stateful_ude = StatefulLuxLayer{true}(ude, nothing, st)
ts = collect(ts_)
function loss_adjoint(θ)
x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + 0.1 * mean(abs2, us)
end
callback = function (state, l)
state.iter % 50 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
return false
end
optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
optprob = OptimizationProblem(optf, ps)
res1 = solve(optprob, Optimisers.Adam(0.001); callback, maxiters=500)
optprob = OptimizationProblem(optf, res1.u)
res2 = solve(optprob, LBFGS(); callback, maxiters=100)
return StatefulLuxLayer{true}(ude, res2.u, st)
end
trained_ude = train_model_1(ude, rng, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/paKmr/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:11
Iteration: 1 Loss: 40.5618
Iteration: 51 Loss: 29.4147
Iteration: 101 Loss: 28.2559
Iteration: 151 Loss: 27.217
Iteration: 201 Loss: 26.1657
Iteration: 251 Loss: 25.1631
Iteration: 301 Loss: 24.2914
Iteration: 351 Loss: 23.5965
Iteration: 401 Loss: 23.0763
Iteration: 451 Loss: 22.6983
Iteration: 1 Loss: 22.245
Iteration: 51 Loss: 12.0085
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)
Now that the system is in a better-behaved part of parameter space, we return to the original loss function, raising the weight on the control term from 0.1 back to 1, to finish the optimization:
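The two training stages differ only in the weight on the mean squared control: 0.1 in train_model_1 and 1 in train_model_2. A smaller weight penalizes large controls less, which presumably makes it easier to first steer x toward 4. A minimal sketch, with a hypothetical keyword λ for that weight and made-up trajectory values:

```julia
using Statistics

# Loss with a tunable weight λ on the control term (hypothetical helper):
# λ = 0.1 mirrors train_model_1, λ = 1.0 mirrors train_model_2.
function weighted_loss(x, v, us; λ=1.0)
    return mean(abs2, 4 .- x) + 2 * mean(abs2, v) + λ * mean(abs2, us)
end

# Made-up samples of position, velocity, and control.
x = [-4.0, 0.0, 4.0]
v = [0.0, 1.0, 0.0]
us = [2.0, -1.0, 0.5]

stage1 = weighted_loss(x, v, us; λ=0.1)
stage2 = weighted_loss(x, v, us; λ=1.0)

# The stages differ only by 0.9 times the mean squared control.
gap = stage2 - stage1
```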
function train_model_2(stateful_ude::StatefulLuxLayer, ts_)
ts = collect(ts_)
function loss_adjoint(θ)
x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + mean(abs2, us)
end
callback = function (state, l)
state.iter % 10 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
return false
end
optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
optprob = OptimizationProblem(optf, stateful_ude.ps)
res2 = solve(optprob, LBFGS(); callback, maxiters=100)
return StatefulLuxLayer{true}(stateful_ude.model, res2.u, stateful_ude.st)
end
trained_ude = train_model_2(trained_ude, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/paKmr/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:11
Iteration: 1 Loss: 12.7099
Iteration: 11 Loss: 12.6761
Iteration: 21 Loss: 12.664
Iteration: 31 Loss: 12.6503
Iteration: 41 Loss: 12.6331
Iteration: 51 Loss: 12.6149
Iteration: 61 Loss: 12.5907
Iteration: 71 Loss: 12.5801
Iteration: 81 Loss: 12.5585
Iteration: 91 Loss: 12.5348
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)
Symbolic Regression
OK, so now we have a trained neural network that solves the optimal control problem. But can we replace the middle layer, Dense(4 => 4, gelu), with a symbolic expression? Let's try!
Data Generation for Symbolic Regression
First, we need to generate data for the symbolic regression.
ts = reshape(collect(0.0:0.1:8.0), 1, :)
X_train = mlp[1](ts, trained_ude.ps.mlp.layer_1, trained_ude.st.mlp.layer_1)[1]
4×81 Matrix{Float64}:
-0.120704 -0.107623 -0.0946398 -0.0821155 -0.0703241 -0.059458 -0.0496368 -0.0409181 -0.0333078 -0.0267716 -0.0212448 -0.0166426 -0.0128676 -0.00981729 -0.0073891 -0.00548503 -0.00401443 -0.00289594 -0.0020584 -0.00144109 -0.000993376 -0.000673966 -0.000449879 -0.000295336 -0.000190602 -0.000120879 -7.53033e-5 -4.6061e-5 -2.76522e-5 -1.62862e-5 -9.40639e-6 -5.32542e-6 -2.95411e-6 -1.60494e-6 -8.53613e-7 -4.44271e-7 -2.26168e-7 -1.12571e-7 -5.47571e-8 -2.60188e-8 -1.2072e-8 -5.46667e-9 -2.41508e-9 -1.04044e-9 -4.36901e-10 -1.78749e-10 -7.1221e-11 -2.76238e-11 -1.04251e-11 -3.82652e-12 -1.36542e-12 -4.73449e-13 -1.59453e-13 -5.21382e-14 -1.65443e-14 -5.09231e-15 -1.51973e-15 -4.39553e-16 -1.23155e-16 -3.3412e-17 -8.77331e-18 -2.22866e-18 -5.4746e-19 -1.29985e-19 -2.98178e-20 -6.60546e-21 -1.41249e-21 -2.91424e-22 -5.79873e-23 -1.11227e-23 -2.05575e-24 -3.65943e-25 -6.27114e-26 -1.03414e-26 -1.64026e-27 -2.50124e-28 -3.66533e-29 -5.15932e-30 -6.97266e-31 -9.04352e-32 -1.12516e-32
0.284447 0.191474 0.108764 0.0369542 -0.0236105 -0.0729022 -0.111211 -0.139126 -0.157498 -0.167391 -0.170012 -0.166652 -0.158613 -0.147152 -0.133429 -0.118465 -0.103124 -0.0880986 -0.0739084 -0.0609151 -0.0493372 -0.0392727 -0.0307237 -0.0236197 -0.0178406 -0.0132361 -0.0096422 -0.00689425 -0.00483611 -0.00332654 -0.00224257 -0.00148086 -0.000957287 -0.00060544 -0.000374394 -0.000226225 -0.000133483 -7.68597e-5 -4.31589e-5 -2.36184e-5 -1.25877e-5 -6.52923e-6 -3.29382e-6 -1.61496e-6 -7.6904e-7 -3.55435e-7 -1.59328e-7 -6.92222e-8 -2.91282e-8 -1.1863e-8 -4.67283e-9 -1.77897e-9 -6.54109e-10 -2.32125e-10 -7.94465e-11 -2.6206e-11 -8.32519e-12 -2.54533e-12 -7.48418e-13 -2.11487e-13 -5.73924e-14 -1.49467e-14 -3.73292e-15 -8.93415e-16 -2.04762e-16 -4.49082e-17 -9.4183e-18 -1.88746e-18 -3.61187e-19 -6.59513e-20 -1.14826e-20 -1.90489e-21 -3.00888e-22 -4.52198e-23 -6.46148e-24 -8.77208e-25 -1.13065e-25 -1.3826e-26 -1.60286e-27 -1.7604e-28 -1.83035e-29
-0.162793 -0.142606 -0.116171 -0.0887784 -0.0639403 -0.0434972 -0.0279656 -0.0169827 -0.00972666 -0.00524259 -0.002652 -0.00125513 -0.000553878 -0.000227077 -8.61651e-5 -3.01445e-5 -9.68487e-6 -2.84613e-6 -7.61976e-7 -1.85093e-7 -4.06283e-8 -8.02553e-9 -1.42081e-9 -2.245e-10 -3.15296e-11 -3.91952e-12 -4.29487e-13 -4.13102e-14 -3.47325e-15 -2.54196e-16 -1.61262e-17 -8.83088e-19 -4.15679e-20 -1.67481e-21 -5.75174e-23 -1.6766e-24 -4.13069e-26 -8.56539e-28 -1.48857e-29 -2.15902e-31 -2.60239e-33 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0
-0.0873662 -0.0613839 -0.0404561 -0.0250232 -0.0145118 -0.00787502 -0.00398753 -0.00187753 -0.000818874 -0.000329434 -0.000121709 -4.11053e-5 -1.26319e-5 -3.5154e-6 -8.81734e-7 -1.98363e-7 -3.98318e-8 -7.10437e-9 -1.11999e-9 -1.55295e-10 -1.88456e-11 -1.99166e-12 -1.82399e-13 -1.44034e-14 -9.7585e-16 -5.64424e-17 -2.77308e-18 -1.15154e-19 -4.02145e-21 -1.17515e-22 -2.85914e-24 -5.76268e-26 -9.57371e-28 -1.30442e-29 -1.45029e-31 -1.3092e-33 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0
This is the training input data: the output of the first layer of the trained network. Now we generate the targets, the output of the second layer:
Y_train = mlp[2](X_train, trained_ude.ps.mlp.layer_2, trained_ude.st.mlp.layer_2)[1]
4×81 Matrix{Float64}:
0.0258644 0.0276239 0.0251692 0.0209178 0.0164794 0.0126878 0.00979604 0.0077031 0.00614522 0.0048323 0.0035284 0.00208602 0.000446717 -0.00137783 -0.00333192 -0.00534041 -0.0073271 -0.00922609 -0.0109871 -0.0125765 -0.0139762 -0.0151813 -0.0161968 -0.0170352 -0.0177135 -0.0182515 -0.01867 -0.0189891 -0.0192275 -0.0194021 -0.0195272 -0.0196151 -0.0196754 -0.0197159 -0.0197425 -0.0197596 -0.0197702 -0.0197768 -0.0197806 -0.0197829 -0.0197841 -0.0197848 -0.0197852 -0.0197854 -0.0197855 -0.0197855 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856 -0.0197856
1.15615 0.893431 0.66566 0.479893 0.336555 0.231173 0.157009 0.107194 0.075877 0.058558 0.0519547 0.0536973 0.0620088 0.0754495 0.0927485 0.112718 0.134233 0.156257 0.177881 0.19836 0.217143 0.233871 0.248367 0.260613 0.270707 0.278833 0.285228 0.290149 0.293852 0.296577 0.29854 0.299922 0.300873 0.301512 0.301933 0.302203 0.302372 0.302475 0.302536 0.302572 0.302592 0.302603 0.302609 0.302612 0.302613 0.302614 0.302614 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615 0.302615
0.734348 0.908501 1.06398 1.19699 1.30671 1.39399 1.46054 1.50839 1.53967 1.55659 1.56133 1.5561 1.54306 1.52427 1.50164 1.47686 1.45139 1.42637 1.40271 1.38101 1.36166 1.34482 1.33051 1.31862 1.30894 1.30122 1.2952 1.29059 1.28714 1.28461 1.2828 1.28152 1.28064 1.28005 1.27966 1.27942 1.27926 1.27916 1.27911 1.27908 1.27906 1.27905 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904 1.27904
0.510947 0.184332 -0.0184213 -0.121812 -0.162144 -0.169999 -0.164754 -0.156503 -0.149661 -0.145724 -0.144839 -0.146556 -0.150164 -0.154843 -0.159763 -0.164184 -0.167543 -0.169516 -0.170028 -0.16922 -0.167379 -0.164859 -0.162008 -0.159119 -0.156409 -0.154013 -0.151994 -0.150361 -0.149086 -0.148123 -0.147417 -0.146913 -0.146564 -0.146327 -0.146171 -0.14607 -0.146007 -0.145969 -0.145946 -0.145933 -0.145925 -0.145921 -0.145919 -0.145918 -0.145917 -0.145917 -0.145917 -0.145917 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916 -0.145916
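In other words, the regression data are the activations entering and leaving the middle layer of the trained MLP. The following package-free sketch shows that split, assuming hypothetical random weights in place of the trained ones and tanh as a stand-in for gelu:

```julia
using Random

rng = Xoshiro(0)
# Hypothetical stand-ins for the trained Dense(1 => 4) and Dense(4 => 4) weights.
W1, b1 = randn(rng, 4, 1), randn(rng, 4)
W2, b2 = randn(rng, 4, 4), randn(rng, 4)
act = tanh  # stand-in for gelu, to keep the sketch dependency-free

ts = reshape(collect(0.0:0.1:8.0), 1, :)  # 1×81 row of time points
X_train = act.(W1 * ts .+ b1)             # inputs: output of the first layer (4×81)
Y_train = act.(W2 * X_train .+ b2)        # targets: output of the second layer (4×81)
```

Each column is one sample, so symbolic regression only has to learn a map from four hidden features to four hidden features, not the whole controller.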
Fitting the Symbolic Expression
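We will follow the example from the SymbolicRegression.jl docs to fit the symbolic expressions. MultitargetSRRegressor fits a separate expression for each target column, so we get four expressions, one per output of the middle layer.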
srmodel = MultitargetSRRegressor(;
binary_operators=[+, -, *, /], niterations=100, save_to_file=false
);
One important note: we transpose the data because MLJ expects one observation per row (samples × features), in contrast to Lux and SymbolicRegression, which expect one observation per column (features × samples).
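The transpose only swaps the layout; the values are untouched. A small sketch with random data of the same 4×81 shape (X_lux and X_mlj are illustrative names). Note that X_train' in the code is a lazy adjoint, which for a real matrix holds the same values as permutedims:

```julia
# Lux / SymbolicRegression layout: features × samples (4 features, 81 samples).
X_lux = rand(4, 81)

# MLJ layout: samples × features, i.e. one observation per row.
X_mlj = permutedims(X_lux)

# Feature 2 of sample 3 lives at [2, 3] in one layout and [3, 2] in the other.
sample3_feature2 = (X_lux[2, 3], X_mlj[3, 2])
```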
mach = machine(srmodel, X_train', Y_train')
fit!(mach; verbosity=0)
r = report(mach)
best_eq = [
r.equations[1][r.best_idx[1]],
r.equations[2][r.best_idx[2]],
r.equations[3][r.best_idx[3]],
r.equations[4][r.best_idx[4]],
]
4-element Vector{DynamicExpressions.ExpressionModule.Expression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}, variable_names::Vector{String}}}}:
(((x3 * -2.4092916789032435) - ((x3 + 0.08960958046338138) + x1)) - ((x2 + 0.08078663584368008) + ((x3 + (x3 + 0.09334729295009023)) * (((-0.521541699885746 - x3) + x2) * x3)))) * 0.11612160025301593
(((x1 - x3) * -0.3337324220036217) - (x3 - x4)) + ((-0.17839728272567787 - ((x2 * -0.9878397915588929) * (-1.0832320383529381 - x2))) / (-0.5895663018948398 - (x3 - x4)))
(((x1 - x2) + (x1 * (x3 * -1.5462025967599249))) + ((1.2790627940040789 - (x3 * (x2 * -0.6178417897906934))) - x1)) - ((x1 * -0.08707973184496466) - (x2 * -0.6718295912326885))
((x2 * (-3.8008604342042847 - ((((((x2 + x2) + x2) + (x3 / x2)) * (x1 + (x2 + x4))) - x2) * -4.958375209351411))) * (-0.16838310264227488 - x2)) + -0.1459382012628447
Let's look at the expressions that SymbolicRegression.jl found. In case you were wondering, these expressions are not hardcoded: they are generated live from the output of the code above using Latexify.jl and the integration of SymbolicUtils.jl with DynamicExpressions.jl.
Combining the Neural Network with the Symbolic Expression
Now that we have the symbolic expressions, we can combine them with the neural network to solve the optimal control problem, though we may still need to perform some fine-tuning.
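Conceptually, the hybrid model evaluates each of the four fitted formulas on the four hidden features and stacks the results, which is what the Parallel plus stack1 structure in the printed model below suggests. A package-free sketch of that forward pass, where the weights and the four formulas are hypothetical stand-ins for the trained layers and best_eq, and tanh stands in for gelu:

```julia
# Hypothetical stand-ins: Dense(1 => 4), four scalar formulas of the hidden
# features (playing the role of best_eq), and Dense(4 => 1).
W1, b1 = [0.5, -1.0, 2.0, 0.1], [0.0, 0.2, -0.3, 0.0]
formulas = (
    (x1, x2, x3, x4) -> x1 * x2 + 0.1,
    (x1, x2, x3, x4) -> x3 - x4,
    (x1, x2, x3, x4) -> x1 / (1 + x2 * x2),
    (x1, x2, x3, x4) -> -x4,
)
W3, b3 = [1.0 -1.0 0.5 2.0], [0.05]

function hybrid_forward(t)
    h = tanh.(W1 .* t .+ b1)         # first layer (tanh stands in for gelu)
    z = [f(h...) for f in formulas]  # symbolic middle layer, one formula per output
    return (W3 * z .+ b3)[1]         # final linear layer, a scalar control value
end
```

Only the middle layer changes; the first and last layers keep their dense form, which is why their trained weights can be reused.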
hybrid_mlp = Chain(
Dense(1 => 4, gelu),
Layers.DynamicExpressionsLayer(OperatorEnum(; binary_operators=[+, -, *, /]), best_eq),
Dense(4 => 1),
)
Chain(
layer_1 = Dense(1 => 4, gelu_tanh), # 8 parameters
layer_2 = DynamicExpressionsLayer(
chain = Chain(
layer_1 = Parallel(
layer_1 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (((x3 * -2.4092916789032435) - ((x3 + 0.08960958046338138) + x1)) - ((x2 + 0.08078663584368008) + ((x3 + (x3 + 0.09334729295009023)) * (((-0.521541699885746 - x3) + x2) * x3)))) * 0.11612160025301593; eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 6 parameters
layer_2 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (((x1 - x3) * -0.3337324220036217) - (x3 - x4)) + ((-0.17839728272567787 - ((x2 * -0.9878397915588929) * (-1.0832320383529381 - x2))) / (-0.5895663018948398 - (x3 - x4))); eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 5 parameters
layer_3 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (((x1 - x2) + (x1 * (x3 * -1.5462025967599249))) + ((1.2790627940040789 - (x3 * (x2 * -0.6178417897906934))) - x1)) - ((x1 * -0.08707973184496466) - (x2 * -0.6718295912326885)); eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 5 parameters
layer_4 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), ((x2 * (-3.8008604342042847 - ((((((x2 + x2) + x2) + (x3 / x2)) * (x1 + (x2 + x4))) - x2) * -4.958375209351411))) * (-0.16838310264227488 - x2)) + -0.1459382012628447; eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 4 parameters
),
layer_2 = WrappedFunction(stack1),
),
),
layer_3 = Dense(4 => 1), # 5 parameters
) # Total: 33 parameters,
# plus 0 states.
There you have it! It is that easy to take the fitted symbolic expressions and combine them with a neural network. Let's see how it performs before fine-tuning.
hybrid_ude = construct_ude(hybrid_mlp, Vern9(); abstol=1e-10, reltol=1e-10);
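We want to reuse the trained neural network parameters, so we copy the weights of the first and last Dense layers over to the new model. The parameters of the symbolic layer come from Lux.initialparameters and correspond to the numeric constants of the fitted expressions (compare the per-expression parameter counts printed above).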
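The copy amounts to rebuilding a nested NamedTuple in which only the middle entry is replaced. A minimal sketch using Base.merge, where trained and new_layer_2 are hypothetical stand-ins for trained_ude.ps.mlp and the symbolic layer's parameters:

```julia
# Hypothetical trained parameters of a three-layer model, as nested NamedTuples.
trained = (;
    layer_1=(; weight=fill(0.5, 4, 1), bias=zeros(4)),
    layer_2=(; weight=fill(0.1, 4, 4), bias=zeros(4)),
    layer_3=(; weight=fill(0.2, 1, 4), bias=zeros(1)),
)

# Hypothetical parameters for the replacement middle layer.
new_layer_2 = (; params=[1.0, 2.0, 3.0])

# merge keeps the field order of `trained` and swaps in the new layer_2,
# so layer_1 and layer_3 are reused without copying.
hybrid = merge(trained, (; layer_2=new_layer_2))
```

The tutorial writes the NamedTuple out explicitly instead, which makes the source of each layer's parameters obvious; either way the result is then wrapped in a ComponentArray.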
st = Lux.initialstates(rng, hybrid_ude)
ps = (;
mlp=(;
layer_1=trained_ude.ps.mlp.layer_1,
layer_2=Lux.initialparameters(rng, hybrid_mlp[2]),
layer_3=trained_ude.ps.mlp.layer_3,
)
)
ps = ComponentArray(ps)
sol, us = hybrid_ude(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), ps, st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)
That performs well even before fine-tuning! Still, we could fine-tune this model easily: we skip that step on CI, but you can do it by reusing the same training code as above.
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
JULIA_NUM_THREADS = 1
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.