
Solving Optimal Control Problems with Symbolic Universal Differential Equations

This tutorial is based on the SciMLSensitivity.jl tutorial. Instead of using a classical NN architecture, we will combine the NN with a symbolic expression from DynamicExpressions.jl (the symbolic engine behind SymbolicRegression.jl and PySR).

Here we will solve a classic optimal control problem with a universal differential equation. Let

$$x'' = u^3(t)$$

where we want to optimize our controller u(t) such that the following is minimized:

$$L(\theta) = \sum_i \left( \left(4 - x(t_i)\right)^2 + 2\,x'(t_i)^2 + u(t_i)^2 \right)$$

where the $t_i$ are sampled on $(0, 8)$ at intervals of 0.01. To do this, we rewrite the ODE in first-order form:

$$\begin{aligned} x' &= v \\ v' &= u^3(t) \end{aligned}$$

and thus

$$L(\theta) = \sum_i \left( \left(4 - x(t_i)\right)^2 + 2\,v(t_i)^2 + u(t_i)^2 \right)$$

is our loss function on the first-order system. We thus choose a neural network form for $u$ and optimize the equation with respect to this loss. Note that we will first reduce the control cost (the last term) by 10x in order to bump the network out of a local minimum.
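Concretely, with the control cost down-weighted by 10x (the `0.1` factor that appears in the first training stage below), the relaxed loss is

$$L_{\text{relaxed}}(\theta) = \sum_i \left( \left(4 - x(t_i)\right)^2 + 2\,v(t_i)^2 + 0.1\,u(t_i)^2 \right)$$

This looks like: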

Package Imports

julia
using Lux, Boltz, ComponentArrays, OrdinaryDiffEqVerner, Optimization, OptimizationOptimJL,
      OptimizationOptimisers, SciMLSensitivity, Statistics, Printf, Random
using DynamicExpressions, SymbolicRegression, MLJ, SymbolicUtils, Latexify
using CairoMakie

Helper Functions

julia
function plot_dynamics(sol, us, ts)
    # Plot the two states u₁(t) = x(t), u₂(t) = v(t) and the control signal u(t)
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel=L"t")
    ylims!(ax, (-6, 6))

    lines!(ax, ts, sol[1, :]; label=L"u_1(t)", linewidth=3)
    lines!(ax, ts, sol[2, :]; label=L"u_2(t)", linewidth=3)

    lines!(ax, ts, vec(us); label=L"u(t)", linewidth=3)

    axislegend(ax; position=:rb)

    return fig
end
plot_dynamics (generic function with 1 method)

Training a Neural Network based UDE

Let's set up the neural network. For the first part, we won't do any symbolic regression; we will simply train a neural network to solve the optimal control problem.

julia
rng = Xoshiro(0)
tspan = (0.0, 8.0)

mlp = Chain(Dense(1 => 4, gelu), Dense(4 => 4, gelu), Dense(4 => 1))

function construct_ude(mlp, solver; kwargs...)
    return @compact(; mlp, solver, kwargs...) do x_in, ps
        x, ts, ret_sol = x_in

        # RHS of the first-order system: x' = v, v' = u(t)^3, with the controller u given by the NN
        function dudt(du, u, p, t)
            u₁, u₂ = u
            du[1] = u₂
            du[2] = mlp([t], p)[1]^3
            return
        end

        prob = ODEProblem{true}(dudt, x, extrema(ts), ps.mlp)

        # Differentiate through the ODE solve using an adjoint method with reverse-mode VJPs
        sol = solve(prob, solver; saveat=ts,
            sensealg=QuadratureAdjoint(; autojacvec=ReverseDiffVJP(true)), kwargs...)

        us = mlp(reshape(ts, 1, :), ps.mlp)
        ret_sol === Val(true) && @return sol, us
        @return Array(sol), us
    end
end

ude = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);

Here we are going to use the same configuration for testing, but this shows that the training and testing models can be set up with different ODE solver configurations.

julia
ude_test = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);
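# As a hedged sketch (not used below): the test copy could just as well use a
# looser tolerance or a different Verner method, e.g.
#   ude_loose = construct_ude(mlp, Vern7(); abstol=1e-6, reltol=1e-6)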

function train_model_1(ude, rng, ts_)
    ps, st = Lux.setup(rng, ude)
    ps = ComponentArray{Float64}(ps)
    stateful_ude = StatefulLuxLayer{true}(ude, nothing, st)

    ts = collect(ts_)

    function loss_adjoint(θ)
        x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
        # Control cost (last term) is down-weighted by 10x in this first stage
        return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + 0.1 * mean(abs2, us)
    end

    callback = function (state, l)
        state.iter % 50 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
        return false
    end

    optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
    optprob = OptimizationProblem(optf, ps)
    res1 = solve(optprob, Optimisers.Adam(0.001); callback, maxiters=500)

    optprob = OptimizationProblem(optf, res1.u)
    res2 = solve(optprob, LBFGS(); callback, maxiters=100)

    return StatefulLuxLayer{true}(ude, res2.u, st)
end

trained_ude = train_model_1(ude, rng, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/SN4dl/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:10
Iteration:     1	Loss:    40.5618
Iteration:    51	Loss:    29.4147
Iteration:   101	Loss:    28.2559
Iteration:   151	Loss:     27.217
Iteration:   201	Loss:    26.1657
Iteration:   251	Loss:    25.1631
Iteration:   301	Loss:    24.2914
Iteration:   351	Loss:    23.5965
Iteration:   401	Loss:    23.0763
Iteration:   451	Loss:    22.6983
Iteration:     1	Loss:    22.2401
Iteration:    51	Loss:     11.981
julia
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)

Now that the system is in a better-behaved part of parameter space, we return to the original loss function to finish the optimization:

julia
function train_model_2(stateful_ude::StatefulLuxLayer, ts_)
    ts = collect(ts_)

    function loss_adjoint(θ)
        x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
        # Original loss: the control cost is restored to its full weight
        return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + mean(abs2, us)
    end

    callback = function (state, l)
        state.iter % 10 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
        return false
    end

    optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
    optprob = OptimizationProblem(optf, stateful_ude.ps)
    res2 = solve(optprob, LBFGS(); callback, maxiters=100)

    return StatefulLuxLayer{true}(stateful_ude.model, res2.u, stateful_ude.st)
end

trained_ude = train_model_2(trained_ude, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/SN4dl/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:10
Iteration:     1	Loss:    12.7005
Iteration:    11	Loss:    12.6854
Iteration:    21	Loss:    12.6642
Iteration:    31	Loss:    12.6517
Iteration:    41	Loss:    12.6409
Iteration:    51	Loss:    12.6279
Iteration:    61	Loss:     12.618
Iteration:    71	Loss:    12.6067
Iteration:    81	Loss:    12.5826
Iteration:    91	Loss:    12.5693
julia
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)

Symbolic Regression

OK, so now we have a trained neural network that solves the optimal control problem. But can we replace Dense(4 => 4, gelu) with a symbolic expression? Let's try!

Data Generation for Symbolic Regression

First, we need to generate data for the symbolic regression.

julia
ts = reshape(collect(0.0:0.1:8.0), 1, :)

X_train = mlp[1](ts, trained_ude.ps.mlp.layer_1, trained_ude.st.mlp.layer_1)[1]
4×81 Matrix{Float64}:
 -0.126364   -0.113737   -0.101013  -0.088566   …  -1.68618e-29  -2.46661e-30
  0.320806    0.223676    0.136649   0.0604756  …  -1.208e-28    -1.22521e-29
 -0.16266    -0.142434   -0.116028  -0.0886899  …  -0.0          -0.0
 -0.0971689  -0.0708357  -0.048652  -0.0315161  …  -0.0          -0.0

This is the training input data. Now we generate the targets:

julia
Y_train = mlp[2](X_train, trained_ude.ps.mlp.layer_2, trained_ude.st.mlp.layer_2)[1]
4×81 Matrix{Float64}:
 0.0265776  0.0287949  0.0268332  0.0230741  …  -0.01568    -0.01568
 0.584846   0.385835   0.226072   0.106037   …  -0.0238743  -0.0238743
 0.234934   0.37719    0.517749   0.647565   …   0.788618    0.788618
 0.798978   0.466864   0.217185   0.0480198  …  -0.0379556  -0.0379556
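
As a quick, hedged sanity check (a sketch that just recomposes the layers we split above), applying the final dense layer to these targets should reproduce the full controller output on the training grid:

julia
# The full MLP is layer_3 ∘ layer_2 ∘ layer_1, so layer_3(Y_train) == mlp(ts)
us_full = mlp(ts, trained_ude.ps.mlp, trained_ude.st.mlp)[1]
us_recomposed = mlp[3](Y_train, trained_ude.ps.mlp.layer_3, trained_ude.st.mlp.layer_3)[1]
maximum(abs, us_full .- us_recomposed)  # should be ≈ 0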

Fitting the Symbolic Expression

We will follow the example from the SymbolicRegression.jl docs to fit the symbolic expression.

julia
srmodel = MultitargetSRRegressor(;
    binary_operators=[+, -, *, /], niterations=100, save_to_file=false);

One important note here: we transpose the data, because that is how MLJ expects it to be structured (in contrast to how Lux or SymbolicRegression expect the data).
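
As a quick shape check (a hedged sketch; the sizes follow from the data generated above):

julia
size(X_train)   # (4, 81): features × samples, the Lux/SymbolicRegression convention
size(X_train')  # (81, 4): samples × features, the MLJ convention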

julia
mach = machine(srmodel, X_train', Y_train')
fit!(mach; verbosity=0)
r = report(mach)
best_eq = [r.equations[i][r.best_idx[i]] for i in 1:4]
4-element Vector{DynamicExpressions.EquationModule.Node{Float64}}:
 (((((-0.12457569105519632 - x₂) * 0.2343670663440141) - x₃) * 1.066016984808304) / ((x₂ + 1.6008155566991769) / 0.5707867321223798)) * 1.4129843423818336
 ((((-0.7298122345092871 - x₂) - (x₄ / -0.8502922456672515)) * (x₃ - x₂)) - 0.022571730918171416) * 1.0774496467907178
 x₂ + (((((0.7886489859741831 - x₄) - (x₂ * 1.5712426299551814)) + x₃) - x₂) + (x₁ * -0.1159778414043933))
 ((((x₂ * 3.9107757437862456) - -1.4055876452792275) * x₂) + -0.03813582748162943) + (x₁ * 0.11074465263055867)

Let's see the expressions that SymbolicRegression.jl found. In case you were wondering, these expressions are not hardcoded: they are updated live from the output of the code above, using Latexify.jl and the integration of SymbolicUtils.jl with DynamicExpressions.jl.
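
A hedged sketch of how such a rendering can be produced (assuming node_to_symbolic accepts the MLJ regressor, as provided by the SymbolicRegression/SymbolicUtils integration):

julia
# Convert DynamicExpressions nodes to SymbolicUtils expressions, then latexify them
symbolic_eqs = map(Base.Fix2(node_to_symbolic, srmodel), best_eq)
latexify.(symbolic_eqs)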

$$\begin{aligned}
y_1 &= \frac{\left((-0.12458 - x_2)\, 0.23437 - x_3\right) 1.066}{\left(x_2 + 1.6008\right)/0.57079}\; 1.413 \\
y_2 &= \left(\left(-0.72981 - x_2 + \frac{x_4}{0.85029}\right)(x_3 - x_2) - 0.022572\right) 1.0774 \\
y_3 &= x_2 + \left(0.78865 - x_4 - 1.5712\, x_2 + x_3 - x_2 - 0.11598\, x_1\right) \\
y_4 &= \left(3.9108\, x_2 + 1.4056\right) x_2 - 0.038136 + 0.11074\, x_1
\end{aligned}$$

Combining the Neural Network with the Symbolic Expression

Now that we have the symbolic expression, we can combine it with the neural network to solve the optimal control problem. But we do need to perform some fine-tuning.

julia
hybrid_mlp = Chain(Dense(1 => 4, gelu),
    Layers.DynamicExpressionsLayer(OperatorEnum(; binary_operators=[+, -, *, /]), best_eq),
    Dense(4 => 1))
Chain(
    layer_1 = Dense(1 => 4, gelu),      # 8 parameters
    layer_2 = DynamicExpressionsLayer(
        chain = Chain(
            layer_1 = Parallel(
                layer_1 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (((((-0.12457569105519632 - x₂) * 0.2343670663440141) - x₃) * 1.066016984808304) / ((x₂ + 1.6008155566991769) / 0.5707867321223798)) * 1.4129843423818336; eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 6 parameters
                layer_2 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), ((((-0.7298122345092871 - x₂) - (x₄ / -0.8502922456672515)) * (x₃ - x₂)) - 0.022571730918171416) * 1.0774496467907178; eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 4 parameters
                layer_3 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), x₂ + (((((0.7886489859741831 - x₄) - (x₂ * 1.5712426299551814)) + x₃) - x₂) + (x₁ * -0.1159778414043933)); eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 3 parameters
                layer_4 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), ((((x₂ * 3.9107757437862456) - -1.4055876452792275) * x₂) + -0.03813582748162943) + (x₁ * 0.11074465263055867); eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 4 parameters
            ),
            layer_2 = WrappedFunction(stack1),
        ),
    ),
    layer_3 = Dense(4 => 1),            # 5 parameters
)         # Total: 30 parameters,
          #        plus 0 states.

There you have it! It is that easy to take the fitted symbolic expression and combine it with a neural network. Let's see how it performs before fine-tuning.

julia
hybrid_ude = construct_ude(hybrid_mlp, Vern9(); abstol=1e-10, reltol=1e-10);

We want to reuse the trained neural network parameters, so we will copy them over to the new model:

julia
st = Lux.initialstates(rng, hybrid_ude)
# Reuse the trained Dense layers (1 and 3); only the symbolic layer gets fresh parameters
ps = (;
    mlp=(; layer_1=trained_ude.ps.mlp.layer_1,
        layer_2=Lux.initialparameters(rng, hybrid_mlp[2]),
        layer_3=trained_ude.ps.mlp.layer_3))
ps = ComponentArray(ps)

sol, us = hybrid_ude(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), ps, st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)

Now that does perform well! But we could fine-tune this model very easily. We will skip that part on CI, but you can do it with the same training code as above, as sketched below.
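
A hedged sketch of that fine-tuning step (not run here; it simply reuses train_model_2 from above):

julia
# Wrap the hybrid UDE in a stateful layer and reuse the earlier LBFGS training loop
hybrid_stateful = StatefulLuxLayer{true}(hybrid_ude, ps, st)
finetuned_ude = train_model_2(hybrid_stateful, 0.0:0.01:8.0)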

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_NUM_THREADS = 1
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.