Solving Optimal Control Problems with Symbolic Universal Differential Equations
This tutorial is based on a SciMLSensitivity.jl tutorial. Instead of using a classical NN architecture, here we will combine the NN with a symbolic expression from DynamicExpressions.jl (the symbolic engine behind SymbolicRegression.jl and PySR).
Here we will solve a classic optimal control problem with a universal differential equation. Let

$$x''(t) = u^3(t)$$

where we want to optimize our controller $u(t)$ such that the following is minimized:

$$\mathcal{L}(\theta) = \sum_i \left( \lVert 4 - x(t_i) \rVert^2 + 2 \lVert x'(t_i) \rVert^2 + \lVert u(t_i) \rVert^2 \right)$$

where $i$ is measured on $(0, 8)$ at $0.01$ intervals. To do this, we rewrite the ODE in first-order form:

$$x'(t) = v(t), \qquad v'(t) = u^3(t)$$

and thus

$$\mathcal{L}(\theta) = \sum_i \left( \lVert 4 - x(t_i) \rVert^2 + 2 \lVert v(t_i) \rVert^2 + \lVert u(t_i) \rVert^2 \right)$$

is our loss function on the first-order system. We thus choose a neural network form for $u(t)$.
Package Imports
using Lux, Boltz, ComponentArrays, OrdinaryDiffEqVerner, Optimization, OptimizationOptimJL,
      OptimizationOptimisers, SciMLSensitivity, Statistics, Printf, Random
using DynamicExpressions, SymbolicRegression, MLJ, SymbolicUtils, Latexify
using CairoMakie
Helper Functions
function plot_dynamics(sol, us, ts)
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel=L"t")
    ylims!(ax, (-6, 6))
    lines!(ax, ts, sol[1, :]; label=L"u_1(t)", linewidth=3)
    lines!(ax, ts, sol[2, :]; label=L"u_2(t)", linewidth=3)
    lines!(ax, ts, vec(us); label=L"u(t)", linewidth=3)
    axislegend(ax; position=:rb)
    return fig
end
plot_dynamics (generic function with 1 method)
Training a Neural Network based UDE
Let's set up the neural network. For the first part, we won't do any symbolic regression; we will simply train a neural network to solve the optimal control problem.
rng = Xoshiro(0)
tspan = (0.0, 8.0)
mlp = Chain(Dense(1 => 4, gelu), Dense(4 => 4, gelu), Dense(4 => 1))
function construct_ude(mlp, solver; kwargs...)
    return @compact(; mlp, solver, kwargs...) do x_in, ps
        x, ts, ret_sol = x_in

        function dudt(du, u, p, t)
            u₁, u₂ = u
            du[1] = u₂
            du[2] = mlp([t], p)[1]^3
            return
        end

        prob = ODEProblem{true}(dudt, x, extrema(ts), ps.mlp)
        sol = solve(prob, solver; saveat=ts,
            sensealg=QuadratureAdjoint(; autojacvec=ReverseDiffVJP(true)), kwargs...)
        us = mlp(reshape(ts, 1, :), ps.mlp)
        ret_sol === Val(true) && @return sol, us
        @return Array(sol), us
    end
end
ude = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);
Here we are going to use the same configuration for testing, but this is to show that the training and testing models can be set up with different ODE solver configurations.
ude_test = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);
function train_model_1(ude, rng, ts_)
    ps, st = Lux.setup(rng, ude)
    ps = ComponentArray{Float64}(ps)
    stateful_ude = StatefulLuxLayer{true}(ude, nothing, st)

    ts = collect(ts_)

    function loss_adjoint(θ)
        x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
        return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + 0.1 * mean(abs2, us)
    end

    callback = function (state, l)
        state.iter % 50 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
        return false
    end

    optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
    optprob = OptimizationProblem(optf, ps)
    res1 = solve(optprob, Optimisers.Adam(0.001); callback, maxiters=500)

    optprob = OptimizationProblem(optf, res1.u)
    res2 = solve(optprob, LBFGS(); callback, maxiters=100)

    return StatefulLuxLayer{true}(ude, res2.u, st)
end
trained_ude = train_model_1(ude, rng, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/SN4dl/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:10
Iteration: 1 Loss: 40.5618
Iteration: 51 Loss: 29.4147
Iteration: 101 Loss: 28.2559
Iteration: 151 Loss: 27.217
Iteration: 201 Loss: 26.1657
Iteration: 251 Loss: 25.1631
Iteration: 301 Loss: 24.2914
Iteration: 351 Loss: 23.5965
Iteration: 401 Loss: 23.0763
Iteration: 451 Loss: 22.6983
Iteration: 1 Loss: 22.2401
Iteration: 51 Loss: 11.981
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)
Now that the system is in a better behaved part of parameter space, we return to the original loss function to finish the optimization:
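To make the two stages explicit (writing $x$ for the first state and $v = x'$ for the second, matching the `loss_adjoint` definitions in the code): the first stage minimized a relaxed loss with the control penalty down-weighted by a factor of $0.1$, and this stage restores the full weight:

$$
\mathcal{L}_1(\theta) = \frac{1}{N}\sum_{i=1}^{N} \left(4 - x(t_i)\right)^2
  + \frac{2}{N}\sum_{i=1}^{N} v(t_i)^2
  + \frac{0.1}{N}\sum_{i=1}^{N} u(t_i)^2
$$

$$
\mathcal{L}_2(\theta) = \frac{1}{N}\sum_{i=1}^{N} \left(4 - x(t_i)\right)^2
  + \frac{2}{N}\sum_{i=1}^{N} v(t_i)^2
  + \frac{1}{N}\sum_{i=1}^{N} u(t_i)^2
$$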
function train_model_2(stateful_ude::StatefulLuxLayer, ts_)
    ts = collect(ts_)

    function loss_adjoint(θ)
        x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
        return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + mean(abs2, us)
    end

    callback = function (state, l)
        state.iter % 10 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
        return false
    end

    optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
    optprob = OptimizationProblem(optf, stateful_ude.ps)
    res2 = solve(optprob, LBFGS(); callback, maxiters=100)

    return StatefulLuxLayer{true}(stateful_ude.model, res2.u, stateful_ude.st)
end
trained_ude = train_model_2(trained_ude, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/SN4dl/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:10
Iteration: 1 Loss: 12.7005
Iteration: 11 Loss: 12.6854
Iteration: 21 Loss: 12.6642
Iteration: 31 Loss: 12.6517
Iteration: 41 Loss: 12.6409
Iteration: 51 Loss: 12.6279
Iteration: 61 Loss: 12.618
Iteration: 71 Loss: 12.6067
Iteration: 81 Loss: 12.5826
Iteration: 91 Loss: 12.5693
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)
Symbolic Regression
Ok, so now we have a trained neural network that solves the optimal control problem. But can we replace Dense(4 => 4, gelu) with a symbolic expression? Let's try!
Data Generation for Symbolic Regression
First, we need to generate data for the symbolic regression.
ts = reshape(collect(0.0:0.1:8.0), 1, :)
X_train = mlp[1](ts, trained_ude.ps.mlp.layer_1, trained_ude.st.mlp.layer_1)[1]
4×81 Matrix{Float64}:
 -0.126364   -0.113737   -0.101013   …  -1.68618e-29  -2.46661e-30
  0.320806    0.223676    0.136649   …  -1.208e-28    -1.22521e-29
 -0.16266    -0.142434   -0.116028   …  -0.0          -0.0
 -0.0971689  -0.0708357  -0.048652   …  -0.0          -0.0
This is the training input data. Now we generate the targets
Y_train = mlp[2](X_train, trained_ude.ps.mlp.layer_2, trained_ude.st.mlp.layer_2)[1]
4×81 Matrix{Float64}:
 0.0265776  0.0287949  0.0268332  …  -0.01568    -0.01568
 0.584846   0.385835   0.226072   …  -0.0238743  -0.0238743
 0.234934   0.37719    0.517749   …   0.788618    0.788618
 0.798978   0.466864   0.217185   …  -0.0379556  -0.0379556
Fitting the Symbolic Expression
We will follow the example from SymbolicRegression.jl docs to fit the symbolic expression.
srmodel = MultitargetSRRegressor(;
    binary_operators=[+, -, *, /], niterations=100, save_to_file=false);
One important note here is to transpose the data, because that is how MLJ expects the data to be structured (in contrast to how Lux and SymbolicRegression expect it).
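For example, a minimal sketch of the layout change using random data with the same shapes as `X_train` above:

```julia
X = randn(4, 81)  # Lux / SymbolicRegression convention: features × samples
X_mlj = X'        # MLJ convention: samples (rows) × features (columns)

size(X), size(X_mlj)  # ((4, 81), (81, 4))
```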
mach = machine(srmodel, X_train', Y_train')
fit!(mach; verbosity=0)
r = report(mach)
best_eq = [r.equations[1][r.best_idx[1]], r.equations[2][r.best_idx[2]],
    r.equations[3][r.best_idx[3]], r.equations[4][r.best_idx[4]]]
4-element Vector{DynamicExpressions.EquationModule.Node{Float64}}:
(((((-0.12457569105519632 - x₂) * 0.2343670663440141) - x₃) * 1.066016984808304) / ((x₂ + 1.6008155566991769) / 0.5707867321223798)) * 1.4129843423818336
((((-0.7298122345092871 - x₂) - (x₄ / -0.8502922456672515)) * (x₃ - x₂)) - 0.022571730918171416) * 1.0774496467907178
x₂ + (((((0.7886489859741831 - x₄) - (x₂ * 1.5712426299551814)) + x₃) - x₂) + (x₁ * -0.1159778414043933))
((((x₂ * 3.9107757437862456) - -1.4055876452792275) * x₂) + -0.03813582748162943) + (x₁ * 0.11074465263055867)
Let's see the expressions that SymbolicRegression.jl found. In case you were wondering, these expressions are not hardcoded; they are generated live from the output of the code above using Latexify.jl and the integration of SymbolicUtils.jl with DynamicExpressions.jl.
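As a quick sanity check, the first fitted equation can be transcribed by hand into a plain Julia function (coefficients copied verbatim from the output above). Since every feature in X_train decays to zero for large t, evaluating at the origin should roughly reproduce the converged value at the end of the first row of Y_train:

```julia
# First fitted expression, transcribed manually from the SymbolicRegression output above.
eq1(x1, x2, x3, x4) = (((((-0.12457569105519632 - x2) * 0.2343670663440141) - x3) *
                        1.066016984808304) /
                       ((x2 + 1.6008155566991769) / 0.5707867321223798)) *
                      1.4129843423818336

eq1(0.0, 0.0, 0.0, 0.0)  # ≈ -0.01568, matching the tail of the first target row
```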
Combining the Neural Network with the Symbolic Expression
Now that we have the symbolic expression, we can combine it with the neural network to solve the optimal control problem, but we do need to perform some fine-tuning.
hybrid_mlp = Chain(Dense(1 => 4, gelu),
    Layers.DynamicExpressionsLayer(OperatorEnum(; binary_operators=[+, -, *, /]), best_eq),
    Dense(4 => 1))
Chain(
layer_1 = Dense(1 => 4, gelu), # 8 parameters
layer_2 = DynamicExpressionsLayer(
chain = Chain(
layer_1 = Parallel(
layer_1 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (((((-0.12457569105519632 - x₂) * 0.2343670663440141) - x₃) * 1.066016984808304) / ((x₂ + 1.6008155566991769) / 0.5707867321223798)) * 1.4129843423818336; eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 6 parameters
layer_2 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), ((((-0.7298122345092871 - x₂) - (x₄ / -0.8502922456672515)) * (x₃ - x₂)) - 0.022571730918171416) * 1.0774496467907178; eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 4 parameters
layer_3 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), x₂ + (((((0.7886489859741831 - x₄) - (x₂ * 1.5712426299551814)) + x₃) - x₂) + (x₁ * -0.1159778414043933)); eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 3 parameters
layer_4 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), ((((x₂ * 3.9107757437862456) - -1.4055876452792275) * x₂) + -0.03813582748162943) + (x₁ * 0.11074465263055867); eval_options=(turbo = Val{false}(), bumper = Val{false}())), # 4 parameters
),
layer_2 = WrappedFunction(stack1),
),
),
layer_3 = Dense(4 => 1), # 5 parameters
) # Total: 30 parameters,
# plus 0 states.
There you have it! It is that easy to take the fitted symbolic expression and combine it with a neural network. Let's see how it performs before fine-tuning.
hybrid_ude = construct_ude(hybrid_mlp, Vern9(); abstol=1e-10, reltol=1e-10);
We want to reuse the trained neural network parameters, so we will copy them over to the new model:
st = Lux.initialstates(rng, hybrid_ude)
ps = (;
    mlp=(; layer_1=trained_ude.ps.mlp.layer_1,
        layer_2=Lux.initialparameters(rng, hybrid_mlp[2]),
        layer_3=trained_ude.ps.mlp.layer_3))
ps = ComponentArray(ps)
sol, us = hybrid_ude(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), ps, st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)
Now that does perform well! But we could fine-tune this model very easily. We will skip that part on CI, but you can do it by using the same training code as above.
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 4 × AMD EPYC 7763 64-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
JULIA_NUM_THREADS = 1
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.