
Solving Optimal Control Problems with Symbolic Universal Differential Equations

This tutorial is based on the SciMLSensitivity.jl tutorial. Instead of using a classical NN architecture, here we will combine the NN with a symbolic expression from DynamicExpressions.jl (the symbolic engine behind SymbolicRegression.jl and PySR).

Here we will solve a classic optimal control problem with a universal differential equation. Let

$$x^{\prime\prime} = u^3(t)$$

where we want to optimize our controller u(t) such that the following is minimized:

$$L(\theta) = \sum_i \left( \left(4 - x(t_i)\right)^2 + 2\,x^{\prime}(t_i)^2 + u(t_i)^2 \right)$$

where the $t_i$ run from 0 to 8 in steps of 0.01. To solve this, we rewrite the ODE in first-order form:

$$x^{\prime} = v, \qquad v^{\prime} = u^3(t)$$

and thus

$$L(\theta) = \sum_i \left( \left(4 - x(t_i)\right)^2 + 2\,v(t_i)^2 + u(t_i)^2 \right)$$

is our loss function on the first-order system. We thus choose a neural network form for $u$ and optimize the equation with respect to this loss. Note that we will first reduce the control cost (the last term) by a factor of 10 in order to bump the network out of a local minimum. The training code averages each term over the time points instead of summing it, which only rescales the loss by a constant factor.
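As a minimal sketch (not part of the original tutorial), the weighted loss can be written as a plain Julia function of sampled trajectories. `control_loss` and its keyword `λ` are hypothetical names: `λ = 0.1` mimics the warm-up stage and `λ = 1.0` the original loss. Variables carry a `_demo` suffix so they do not clobber the tutorial's own globals.

```julia
using Statistics: mean

# Hypothetical helper (not from the tutorial): the mean-based loss used in training,
# with a weight λ on the control term (λ = 0.1 for warm-up, λ = 1.0 for the original loss).
function control_loss(x::AbstractVector, v::AbstractVector, u::AbstractVector; λ=1.0)
    return mean(abs2, 4 .- x) + 2 * mean(abs2, v) + λ * mean(abs2, u)
end

ts_demo = 0.0:0.01:8.0        # same grid as the tutorial: 801 time points
n_demo = length(ts_demo)

# A trajectory already at the target (x = 4, v = 0) with zero control has zero loss
x_goal, v_goal = fill(4.0, n_demo), zeros(n_demo)
control_loss(x_goal, v_goal, zeros(n_demo))          # 0.0

# With constant control u = 1, only the control term contributes
control_loss(x_goal, v_goal, ones(n_demo); λ=0.1)    # 0.1
control_loss(x_goal, v_goal, ones(n_demo))           # 1.0
```

Lowering `λ` makes control effort cheap during warm-up; the second training stage restores `λ = 1`.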

Package Imports

julia
using Lux,
    Boltz,
    ComponentArrays,
    OrdinaryDiffEqVerner,
    Optimization,
    OptimizationOptimJL,
    OptimizationOptimisers,
    SciMLSensitivity,
    Statistics,
    Printf,
    Random
using DynamicExpressions, SymbolicRegression, MLJ, SymbolicUtils, Latexify
using CairoMakie

Helper Functions

julia
function plot_dynamics(sol, us, ts)
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel=L"t")
    ylims!(ax, (-6, 6))

    lines!(ax, ts, sol[1, :]; label=L"u_1(t)", linewidth=3)
    lines!(ax, ts, sol[2, :]; label=L"u_2(t)", linewidth=3)

    lines!(ax, ts, vec(us); label=L"u(t)", linewidth=3)

    axislegend(ax; position=:rb)

    return fig
end
plot_dynamics (generic function with 1 method)

Training a Neural Network based UDE

Let's set up the neural network. For the first part, we won't do any symbolic regression; we will simply train a plain neural network to solve the optimal control problem.

julia
rng = Xoshiro(0)
tspan = (0.0, 8.0)

mlp = Chain(Dense(1 => 4, gelu), Dense(4 => 4, gelu), Dense(4 => 1))

function construct_ude(mlp, solver; kwargs...)
    return @compact(; mlp, solver, kwargs...) do x_in, ps
        x, ts, ret_sol = x_in

        function dudt(du, u, p, t)
            u₁, u₂ = u
            du[1] = u₂
            du[2] = mlp([t], p)[1]^3
            return nothing
        end

        prob = ODEProblem{true}(dudt, x, extrema(ts), ps.mlp)

        sol = solve(
            prob,
            solver;
            saveat=ts,
            sensealg=QuadratureAdjoint(; autojacvec=ReverseDiffVJP(true)),
            kwargs...,
        )

        us = mlp(reshape(ts, 1, :), ps.mlp)
        ret_sol === Val(true) && @return sol, us
        @return Array(sol), us
    end
end

ude = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);
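To make the structure of `dudt` concrete, here is a simplified, self-contained sketch that integrates the same first-order system with a hand-written fixed-step RK4 scheme instead of `Vern9()`, and with a plain function `control(t)` standing in for the neural network. `make_rhs` and `rk4_solve` are hypothetical helpers, not library functions.

```julia
# In-place right-hand side of x' = v, v' = u(t)^3 (same signature as `dudt` above),
# with a plain function `control(t)` in place of the neural network.
function make_rhs(control)
    return function (du, u, p, t)
        du[1] = u[2]
        du[2] = control(t)^3
        return nothing
    end
end

# Hypothetical fixed-step RK4 integrator that saves the state at every point of `ts`
function rk4_solve(f!, u0::Vector{Float64}, ts)
    sol = zeros(length(u0), length(ts))
    sol[:, 1] .= u0
    u = copy(u0)
    k1, k2, k3, k4 = similar(u), similar(u), similar(u), similar(u)
    for i in 1:(length(ts) - 1)
        t, h = ts[i], ts[i + 1] - ts[i]
        f!(k1, u, nothing, t)
        f!(k2, u .+ (h / 2) .* k1, nothing, t + h / 2)
        f!(k3, u .+ (h / 2) .* k2, nothing, t + h / 2)
        f!(k4, u .+ h .* k3, nothing, t + h)
        u .+= (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        sol[:, i + 1] .= u
    end
    return sol
end

# Constant control u(t) = 1 gives x(t) = -4 + t^2 / 2 and v(t) = t
ts_rk4_demo = 0.0:0.01:8.0
sol_demo = rk4_solve(make_rhs(t -> 1.0), [-4.0, 0.0], ts_rk4_demo)
sol_demo[:, end]  # ≈ [28.0, 8.0]
```

RK4 reproduces this quadratic solution up to round-off; the tutorial itself relies on the adaptive `Vern9()` solver with tight tolerances.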

Here we use the same configuration for testing, but this shows that the training and testing models can be set up with different ODE solver configurations.

julia
ude_test = construct_ude(mlp, Vern9(); abstol=1e-10, reltol=1e-10);

function train_model_1(ude, rng, ts_)
    ps, st = Lux.setup(rng, ude)
    ps = ComponentArray{Float64}(ps)
    stateful_ude = StatefulLuxLayer{true}(ude, nothing, st)

    ts = collect(ts_)

    function loss_adjoint(θ)
        x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
        return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + 0.1 * mean(abs2, us)
    end

    callback = function (state, l)
        state.iter % 50 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
        return false
    end

    optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
    optprob = OptimizationProblem(optf, ps)
    res1 = solve(optprob, Optimisers.Adam(0.001); callback, maxiters=500)

    optprob = OptimizationProblem(optf, res1.u)
    res2 = solve(optprob, LBFGS(); callback, maxiters=100)

    return StatefulLuxLayer{true}(ude, res2.u, st)
end

trained_ude = train_model_1(ude, rng, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/Av7WJ/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:9
Iteration:     1	Loss:    40.5618
Iteration:    51	Loss:    29.4147
Iteration:   101	Loss:    28.2559
Iteration:   151	Loss:     27.217
Iteration:   201	Loss:    26.1657
Iteration:   251	Loss:    25.1631
Iteration:   301	Loss:    24.2914
Iteration:   351	Loss:    23.5965
Iteration:   401	Loss:    23.0763
Iteration:   451	Loss:    22.6983
Iteration:     1	Loss:     22.245
Iteration:    51	Loss:    12.0085
julia
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)

Now that the system is in a better behaved part of parameter space, we return to the original loss function to finish the optimization:

julia
function train_model_2(stateful_ude::StatefulLuxLayer, ts_)
    ts = collect(ts_)

    function loss_adjoint(θ)
        x, us = stateful_ude(([-4.0, 0.0], ts, Val(false)), θ)
        return mean(abs2, 4 .- x[1, :]) + 2 * mean(abs2, x[2, :]) + mean(abs2, us)
    end

    callback = function (state, l)
        state.iter % 10 == 1 && @printf "Iteration: %5d\tLoss: %10g\n" state.iter l
        return false
    end

    optf = OptimizationFunction((x, p) -> loss_adjoint(x), AutoZygote())
    optprob = OptimizationProblem(optf, stateful_ude.ps)
    res2 = solve(optprob, LBFGS(); callback, maxiters=100)

    return StatefulLuxLayer{true}(stateful_ude.model, res2.u, stateful_ude.st)
end

trained_ude = train_model_2(trained_ude, 0.0:0.01:8.0)
┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt ~/.julia/packages/LuxCore/Av7WJ/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:9
Iteration:     1	Loss:    12.7099
Iteration:    11	Loss:    12.6761
Iteration:    21	Loss:     12.664
Iteration:    31	Loss:    12.6503
Iteration:    41	Loss:    12.6331
Iteration:    51	Loss:    12.6149
Iteration:    61	Loss:    12.5907
Iteration:    71	Loss:    12.5801
Iteration:    81	Loss:    12.5585
Iteration:    91	Loss:    12.5348
julia
sol, us = ude_test(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), trained_ude.ps, trained_ude.st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)

Symbolic Regression

Ok so now we have a trained neural network that solves the optimal control problem. But can we replace Dense(4 => 4, gelu) with a symbolic expression? Let's try!

Data Generation for Symbolic Regression

First, we need to generate data for the symbolic regression.

julia
ts = reshape(collect(0.0:0.1:8.0), 1, :)

X_train = mlp[1](ts, trained_ude.ps.mlp.layer_1, trained_ude.st.mlp.layer_1)[1]
4×81 Matrix{Float64}:
 -0.120704   -0.107623   -0.0946398  -0.0821155  -0.0703241  -0.059458    -0.0496368   -0.0409181   -0.0333078    -0.0267716    -0.0212448    -0.0166426   -0.0128676    -0.00981729   -0.0073891   -0.00548503  -0.00401443  -0.00289594  -0.0020584   -0.00144109   -0.000993376  -0.000673966  -0.000449879  -0.000295336  -0.000190602  -0.000120879  -7.53033e-5   -4.6061e-5    -2.76522e-5   -1.62862e-5   -9.40639e-6   -5.32542e-6   -2.95411e-6   -1.60494e-6   -8.53613e-7   -4.44271e-7   -2.26168e-7   -1.12571e-7   -5.47571e-8   -2.60188e-8   -1.2072e-8    -5.46667e-9  -2.41508e-9  -1.04044e-9  -4.36901e-10  -1.78749e-10  -7.1221e-11  -2.76238e-11  -1.04251e-11  -3.82652e-12  -1.36542e-12  -4.73449e-13  -1.59453e-13  -5.21382e-14  -1.65443e-14  -5.09231e-15  -1.51973e-15  -4.39553e-16  -1.23155e-16  -3.3412e-17   -8.77331e-18  -2.22866e-18  -5.4746e-19   -1.29985e-19  -2.98178e-20  -6.60546e-21  -1.41249e-21  -2.91424e-22  -5.79873e-23  -1.11227e-23  -2.05575e-24  -3.65943e-25  -6.27114e-26  -1.03414e-26  -1.64026e-27  -2.50124e-28  -3.66533e-29  -5.15932e-30  -6.97266e-31  -9.04352e-32  -1.12516e-32
  0.284447    0.191474    0.108764    0.0369542  -0.0236105  -0.0729022   -0.111211    -0.139126    -0.157498     -0.167391     -0.170012     -0.166652    -0.158613     -0.147152     -0.133429    -0.118465    -0.103124    -0.0880986   -0.0739084   -0.0609151    -0.0493372    -0.0392727    -0.0307237    -0.0236197    -0.0178406    -0.0132361    -0.0096422    -0.00689425   -0.00483611   -0.00332654   -0.00224257   -0.00148086   -0.000957287  -0.00060544   -0.000374394  -0.000226225  -0.000133483  -7.68597e-5   -4.31589e-5   -2.36184e-5   -1.25877e-5   -6.52923e-6  -3.29382e-6  -1.61496e-6  -7.6904e-7    -3.55435e-7   -1.59328e-7  -6.92222e-8   -2.91282e-8   -1.1863e-8    -4.67283e-9   -1.77897e-9   -6.54109e-10  -2.32125e-10  -7.94465e-11  -2.6206e-11   -8.32519e-12  -2.54533e-12  -7.48418e-13  -2.11487e-13  -5.73924e-14  -1.49467e-14  -3.73292e-15  -8.93415e-16  -2.04762e-16  -4.49082e-17  -9.4183e-18   -1.88746e-18  -3.61187e-19  -6.59513e-20  -1.14826e-20  -1.90489e-21  -3.00888e-22  -4.52198e-23  -6.46148e-24  -8.77208e-25  -1.13065e-25  -1.3826e-26   -1.60286e-27  -1.7604e-28   -1.83035e-29
 -0.162793   -0.142606   -0.116171   -0.0887784  -0.0639403  -0.0434972   -0.0279656   -0.0169827   -0.00972666   -0.00524259   -0.002652     -0.00125513  -0.000553878  -0.000227077  -8.61651e-5  -3.01445e-5  -9.68487e-6  -2.84613e-6  -7.61976e-7  -1.85093e-7   -4.06283e-8   -8.02553e-9   -1.42081e-9   -2.245e-10    -3.15296e-11  -3.91952e-12  -4.29487e-13  -4.13102e-14  -3.47325e-15  -2.54196e-16  -1.61262e-17  -8.83088e-19  -4.15679e-20  -1.67481e-21  -5.75174e-23  -1.6766e-24   -4.13069e-26  -8.56539e-28  -1.48857e-29  -2.15902e-31  -2.60239e-33  -0.0         -0.0         -0.0         -0.0          -0.0          -0.0         -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0
 -0.0873662  -0.0613839  -0.0404561  -0.0250232  -0.0145118  -0.00787502  -0.00398753  -0.00187753  -0.000818874  -0.000329434  -0.000121709  -4.11053e-5  -1.26319e-5   -3.5154e-6    -8.81734e-7  -1.98363e-7  -3.98318e-8  -7.10437e-9  -1.11999e-9  -1.55295e-10  -1.88456e-11  -1.99166e-12  -1.82399e-13  -1.44034e-14  -9.7585e-16   -5.64424e-17  -2.77308e-18  -1.15154e-19  -4.02145e-21  -1.17515e-22  -2.85914e-24  -5.76268e-26  -9.57371e-28  -1.30442e-29  -1.45029e-31  -1.3092e-33   -0.0          -0.0          -0.0          -0.0          -0.0          -0.0         -0.0         -0.0         -0.0          -0.0          -0.0         -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0          -0.0

This is the training input data. Now we generate the targets:

julia
Y_train = mlp[2](X_train, trained_ude.ps.mlp.layer_2, trained_ude.st.mlp.layer_2)[1]
4×81 Matrix{Float64}:
 0.0258644  0.0276239   0.0251692   0.0209178   0.0164794   0.0126878   0.00979604   0.0077031   0.00614522   0.0048323   0.0035284   0.00208602   0.000446717  -0.00137783  -0.00333192  -0.00534041  -0.0073271  -0.00922609  -0.0109871  -0.0125765  -0.0139762  -0.0151813  -0.0161968  -0.0170352  -0.0177135  -0.0182515  -0.01867   -0.0189891  -0.0192275  -0.0194021  -0.0195272  -0.0196151  -0.0196754  -0.0197159  -0.0197425  -0.0197596  -0.0197702  -0.0197768  -0.0197806  -0.0197829  -0.0197841  -0.0197848  -0.0197852  -0.0197854  -0.0197855  -0.0197855  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856  -0.0197856
 1.15615    0.893431    0.66566     0.479893    0.336555    0.231173    0.157009     0.107194    0.075877     0.058558    0.0519547   0.0536973    0.0620088     0.0754495    0.0927485    0.112718     0.134233    0.156257     0.177881    0.19836     0.217143    0.233871    0.248367    0.260613    0.270707    0.278833    0.285228   0.290149    0.293852    0.296577    0.29854     0.299922    0.300873    0.301512    0.301933    0.302203    0.302372    0.302475    0.302536    0.302572    0.302592    0.302603    0.302609    0.302612    0.302613    0.302614    0.302614    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615    0.302615
 0.734348   0.908501    1.06398     1.19699     1.30671     1.39399     1.46054      1.50839     1.53967      1.55659     1.56133     1.5561       1.54306       1.52427      1.50164      1.47686      1.45139     1.42637      1.40271     1.38101     1.36166     1.34482     1.33051     1.31862     1.30894     1.30122     1.2952     1.29059     1.28714     1.28461     1.2828      1.28152     1.28064     1.28005     1.27966     1.27942     1.27926     1.27916     1.27911     1.27908     1.27906     1.27905     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904     1.27904
 0.510947   0.184332   -0.0184213  -0.121812   -0.162144   -0.169999   -0.164754    -0.156503   -0.149661    -0.145724   -0.144839   -0.146556    -0.150164     -0.154843    -0.159763    -0.164184    -0.167543   -0.169516    -0.170028   -0.16922    -0.167379   -0.164859   -0.162008   -0.159119   -0.156409   -0.154013   -0.151994  -0.150361   -0.149086   -0.148123   -0.147417   -0.146913   -0.146564   -0.146327   -0.146171   -0.14607    -0.146007   -0.145969   -0.145946   -0.145933   -0.145925   -0.145921   -0.145919   -0.145918   -0.145917   -0.145917   -0.145917   -0.145917   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916   -0.145916
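The slicing above can be mirrored in plain Julia. This sketch uses hypothetical `dense` and `gelu_tanh` helpers and random weights (the tutorial uses the trained weights in `trained_ude.ps.mlp`); it assumes the tanh approximation of GELU, which is what the printed hybrid model later reports as `gelu_tanh`. Layer 1 maps the 1×81 time row to the 4×81 SR inputs, and layer 2 maps those to the 4×81 SR targets.

```julia
using Random

# tanh approximation of GELU
gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2 / π) * (x + 0.044715 * x^3)))

# Hypothetical stand-in for a Lux `Dense` layer: σ.(W * x .+ b), columns are samples
dense(W, b, σ, x) = σ.(W * x .+ b)

rng_demo = Xoshiro(0)
W1, b1 = randn(rng_demo, 4, 1), randn(rng_demo, 4)   # like Dense(1 => 4, gelu)
W2, b2 = randn(rng_demo, 4, 4), randn(rng_demo, 4)   # like Dense(4 => 4, gelu)

ts_sr_demo = reshape(collect(0.0:0.1:8.0), 1, :)     # 1×81 row of time points
X_demo = dense(W1, b1, gelu_tanh, ts_sr_demo)        # 4×81 hidden features (SR inputs)
Y_demo = dense(W2, b2, gelu_tanh, X_demo)            # 4×81 layer-2 outputs (SR targets)
size(X_demo), size(Y_demo)  # ((4, 81), (4, 81))
```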

Fitting the Symbolic Expression

We will follow the example from the SymbolicRegression.jl docs to fit the symbolic expression.

julia
srmodel = MultitargetSRRegressor(;
    binary_operators=[+, -, *, /], niterations=100, save_to_file=false
);

Note that we transpose the data here: MLJ expects samples as rows and features as columns, whereas Lux and SymbolicRegression.jl store features as rows and samples as columns (as in the 4×81 matrices above).
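A minimal illustration of the layout difference, using random stand-in data rather than the real `X_train`: `permutedims` produces an 81×4 copy, while the `X_train'` used below is a lazy adjoint with the same entries for real-valued matrices.

```julia
X_lux = rand(4, 81)          # features × samples, like X_train
X_mlj = permutedims(X_lux)   # samples × features, a dense 81×4 copy
size(X_mlj)                  # (81, 4)
X_mlj[10, 3] == X_lux[3, 10] # true: sample 10, feature 3
```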

julia
mach = machine(srmodel, X_train', Y_train')
fit!(mach; verbosity=0)
r = report(mach)
best_eq = [
    r.equations[1][r.best_idx[1]],
    r.equations[2][r.best_idx[2]],
    r.equations[3][r.best_idx[3]],
    r.equations[4][r.best_idx[4]],
]
4-element Vector{DynamicExpressions.ExpressionModule.Expression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}, variable_names::Vector{String}}}}:
 (x3 + (((((x1 * ((x1 * -2.571332464832935) + -1.2506172908804345)) * ((x4 + x2) / (-0.4995489388459282 - x2))) - x2) * -0.2346860578013044) - -0.03877183753159419)) * -0.5106275965445597
 ((x2 + ((x1 + 1.8154095463794062) - (x2 * ((x2 + x1) + -0.6099459354506391)))) * (x2 - x3)) + (0.3026148961127518 - (x3 * -0.542910731775629))
 (x2 * -1.678225535156146) + ((x4 + 1.2790322763056396) + (((x1 * 2.58289672213348) + 0.14680398353376703) * (x1 + ((((x4 - x1) + (x4 * (x2 - -0.113273574956859))) * x2) / 0.3590934006394513))))
 (((x3 - ((x4 / 0.2667449510321438) - x2)) * ((x2 + x2) + (x4 + 0.334038458901784))) + -0.08278131152493046) * (1.7642104058680492 - (x3 - x1))

Let's look at the expressions that SymbolicRegression.jl found. In case you were wondering, these expressions are not hardcoded; they are rendered live from the output of the code above using Latexify.jl and the integration of SymbolicUtils.jl with DynamicExpressions.jl.

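$$
\begin{aligned}
y_1 &= \left(x_3 + \left(x_1 \left(-2.5713\, x_1 - 1.2506\right) \frac{x_4 + x_2}{-0.49955 - x_2} - x_2\right) \cdot (-0.23469) + 0.038772\right) \cdot (-0.51063) \\
y_2 &= \left(x_2 + x_1 + 1.8154 - x_2 \left(x_2 + x_1 - 0.60995\right)\right)\left(x_2 - x_3\right) + 0.30261 - x_3 \cdot (-0.54291) \\
y_3 &= x_2 \cdot (-1.6782) + x_4 + 1.279 + \left(2.5829\, x_1 + 0.1468\right)\left(x_1 + \frac{\left(x_4 - x_1 + x_4 \left(x_2 + 0.11327\right)\right) x_2}{0.35909}\right) \\
y_4 &= \left(\left(x_3 - \left(\frac{x_4}{0.26674} - x_2\right)\right)\left(x_2 + x_2 + x_4 + 0.33404\right) - 0.082781\right)\left(1.7642 - \left(x_3 - x_1\right)\right)
\end{aligned}
$$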
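As a quick sanity check of the printed expressions (our own check, not part of the tutorial), they can be transcribed verbatim into plain Julia and evaluated at the last column of `X_train` (t = 8), where the hidden features have decayed to essentially zero. The results should land close to the last column of `Y_train`. The function names are hypothetical.

```julia
# The four fitted expressions from `best_eq`, copied from the printed output
fitted_y1(x1, x2, x3, x4) = (x3 + (((((x1 * ((x1 * -2.571332464832935) + -1.2506172908804345)) * ((x4 + x2) / (-0.4995489388459282 - x2))) - x2) * -0.2346860578013044) - -0.03877183753159419)) * -0.5106275965445597
fitted_y2(x1, x2, x3, x4) = ((x2 + ((x1 + 1.8154095463794062) - (x2 * ((x2 + x1) + -0.6099459354506391)))) * (x2 - x3)) + (0.3026148961127518 - (x3 * -0.542910731775629))
fitted_y3(x1, x2, x3, x4) = (x2 * -1.678225535156146) + ((x4 + 1.2790322763056396) + (((x1 * 2.58289672213348) + 0.14680398353376703) * (x1 + ((((x4 - x1) + (x4 * (x2 - -0.113273574956859))) * x2) / 0.3590934006394513))))
fitted_y4(x1, x2, x3, x4) = (((x3 - ((x4 / 0.2667449510321438) - x2)) * ((x2 + x2) + (x4 + 0.334038458901784))) + -0.08278131152493046) * (1.7642104058680492 - (x3 - x1))

# Last column of X_train and Y_train (t = 8), as printed above
x_last = (-1.12516e-32, -1.83035e-29, -0.0, -0.0)
y_last = (-0.0197856, 0.302615, 1.27904, -0.145916)

preds = (fitted_y1(x_last...), fitted_y2(x_last...), fitted_y3(x_last...), fitted_y4(x_last...))
maximum(abs.(preds .- y_last))  # well below 1e-3 at this point
```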

Combining the Neural Network with the Symbolic Expression

Now that we have the symbolic expressions, we can combine them with the neural network to solve the optimal control problem, though the combined model may benefit from some fine-tuning.

julia
hybrid_mlp = Chain(
    Dense(1 => 4, gelu),
    Layers.DynamicExpressionsLayer(OperatorEnum(; binary_operators=[+, -, *, /]), best_eq),
    Dense(4 => 1),
)
Chain(
    layer_1 = Dense(1 => 4, gelu_tanh),  # 8 parameters
    layer_2 = DynamicExpressionsLayer(
        chain = Chain(
            layer_1 = Parallel(
                layer_1 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (x3 + (((((x1 * ((x1 * -2.571332464832935) + -1.2506172908804345)) * ((x4 + x2) / (-0.4995489388459282 - x2))) - x2) * -0.2346860578013044) - -0.03877183753159419)) * -0.5106275965445597; eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 6 parameters
                layer_2 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), ((x2 + ((x1 + 1.8154095463794062) - (x2 * ((x2 + x1) + -0.6099459354506391)))) * (x2 - x3)) + (0.3026148961127518 - (x3 * -0.542910731775629)); eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 4 parameters
                layer_3 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (x2 * -1.678225535156146) + ((x4 + 1.2790322763056396) + (((x1 * 2.58289672213348) + 0.14680398353376703) * (x1 + ((((x4 - x1) + (x4 * (x2 - -0.113273574956859))) * x2) / 0.3590934006394513)))); eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 6 parameters
                layer_4 = InternalDynamicExpressionWrapper(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{}}((+, -, *, /), ()), (((x3 - ((x4 / 0.2667449510321438) - x2)) * ((x2 + x2) + (x4 + 0.334038458901784))) + -0.08278131152493046) * (1.7642104058680492 - (x3 - x1)); eval_options=(turbo = Val{false}(), bumper = Val{false}())),  # 4 parameters
            ),
            layer_2 = WrappedFunction(stack1),
        ),
    ),
    layer_3 = Dense(4 => 1),            # 5 parameters
)         # Total: 33 parameters,
          #        plus 0 states.
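Conceptually, the hybrid model evaluates layer 1, feeds the four hidden features into the four expressions, stacks the results, and applies the final `Dense` layer. The sketch below is a hypothetical plain-Julia rendering of that forward pass, with `exprs` standing in for `best_eq` and a weight vector in place of the 4×1 matrix. The printed parameter counts for the symbolic layer (6, 4, 6, 4) appear to match the number of numeric constants in each expression.

```julia
# tanh approximation of GELU (shown as `gelu_tanh` in the printed model)
gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2 / π) * (x + 0.044715 * x^3)))

# Hypothetical forward pass: Dense(1 => 4, gelu) -> 4 expressions -> Dense(4 => 1)
function hybrid_forward(t, W1, b1, exprs, w3, b3)
    h = gelu_tanh.(W1 .* t .+ b1)   # layer_1: 4 hidden features from scalar t
    z = [f(h...) for f in exprs]    # layer_2: one expression per feature, stacked
    return sum(w3 .* z) + b3        # layer_3: Dense(4 => 1)
end

# Toy expressions returning constants, to check the wiring
exprs_demo = [(x...) -> 1.0, (x...) -> 2.0, (x...) -> 3.0, (x...) -> 4.0]
hybrid_forward(0.3, zeros(4), zeros(4), exprs_demo, ones(4), 0.5)  # 10.5

# Parameter count: Dense(1 => 4) + expression constants + Dense(4 => 1)
n_params = (1 * 4 + 4) + (6 + 4 + 6 + 4) + (4 * 1 + 1)  # 33
```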

There you have it! It is that easy to take the fitted symbolic expressions and combine them with a neural network. Let's see how the model performs before fine-tuning.

julia
hybrid_ude = construct_ude(hybrid_mlp, Vern9(); abstol=1e-10, reltol=1e-10);

We want to reuse the trained neural network parameters, so we copy them over to the new model; only the symbolic layer (`layer_2`) gets its own initial parameters.

julia
st = Lux.initialstates(rng, hybrid_ude)
ps = (;
    mlp=(;
        layer_1=trained_ude.ps.mlp.layer_1,
        layer_2=Lux.initialparameters(rng, hybrid_mlp[2]),
        layer_3=trained_ude.ps.mlp.layer_3,
    )
)
ps = ComponentArray(ps)

sol, us = hybrid_ude(([-4.0, 0.0], 0.0:0.01:8.0, Val(true)), ps, st)[1];
plot_dynamics(sol, us, 0.0:0.01:8.0)
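The explicit NamedTuple construction above could equivalently be written with Base's `merge`, which keeps layers 1 and 3 and overrides only layer 2. This is a sketch with toy parameter trees and made-up values (`trained_mlp_demo`, `symbolic_params_demo` are hypothetical stand-ins for `trained_ude.ps.mlp` and `Lux.initialparameters(rng, hybrid_mlp[2])`). In the tutorial, the result would still be wrapped in `ComponentArray`.

```julia
# Toy parameter trees standing in for the real ones (values are made up)
trained_mlp_demo = (; layer_1=(; weight=[1.0, 2.0], bias=[0.1, 0.2]),
                      layer_2=(; weight=[9.0, 9.0], bias=[9.0, 9.0]),
                      layer_3=(; weight=[3.0, 4.0], bias=[0.5]))
symbolic_params_demo = (; params=[-2.57, -1.25])

# Keep layers 1 and 3, swap in the symbolic layer's parameters for layer 2;
# `merge` preserves the key order of the first NamedTuple
ps_demo = (; mlp=merge(trained_mlp_demo, (; layer_2=symbolic_params_demo)))

ps_demo.mlp.layer_1 === trained_mlp_demo.layer_1   # true: reused as-is
keys(ps_demo.mlp)                                  # (:layer_1, :layer_2, :layer_3)
```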

That already performs well! The model could still be fine-tuned easily; we skip that step on CI, but you can do it by reusing the training code above.

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_NUM_THREADS = 1
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.