
Getting Started

Prerequisites

Here we assume that you are familiar with Lux.jl. If not, please take a look at the Lux.jl tutorials.

Boltz.jl is just like Lux.jl but comes with more "batteries included". Let's start by defining an MLP model.

julia
using Lux, Boltz, Random
Precompiling Lux...
    996.2 ms  ✓ Functors
   1255.6 ms  ✓ LuxCore
   1415.7 ms  ✓ MLDataDevices
   1039.7 ms  ✓ LuxCore → LuxCoreFunctorsExt
   1069.9 ms  ✓ LuxCore → LuxCoreEnzymeCoreExt
   1133.8 ms  ✓ LuxCore → LuxCoreChainRulesCoreExt
   1860.0 ms  ✓ Optimisers
   1034.8 ms  ✓ LuxCore → LuxCoreSetfieldExt
   1096.7 ms  ✓ MLDataDevices → MLDataDevicesChainRulesCoreExt
    975.9 ms  ✓ MLDataDevices → MLDataDevicesFillArraysExt
   2222.8 ms  ✓ MLDataDevices → MLDataDevicesGPUArraysExt
   1443.4 ms  ✓ MLDataDevices → MLDataDevicesSparseArraysExt
   3348.9 ms  ✓ MLDataDevices → MLDataDevicesOneHotArraysExt
    983.0 ms  ✓ LuxCore → LuxCoreMLDataDevicesExt
   1402.5 ms  ✓ MLDataDevices → MLDataDevicesChainRulesExt
   2694.0 ms  ✓ MLDataDevices → MLDataDevicesZygoteExt
   3355.4 ms  ✓ MLDataDevices → MLDataDevicesMLUtilsExt
   5882.9 ms  ✓ LuxLib
   9079.3 ms  ✓ Lux
   2818.5 ms  ✓ Lux → LuxMLUtilsExt
   3232.8 ms  ✓ Lux → LuxZygoteExt
  21 dependencies successfully precompiled in 26 seconds. 190 already precompiled.
Precompiling Boltz...
   6425.5 ms  ✓ Flux
   3514.6 ms  ✓ Lux → LuxFluxExt
   5240.1 ms  ✓ Boltz
   2734.8 ms  ✓ Boltz → BoltzZygoteExt
  4 dependencies successfully precompiled in 18 seconds. 212 already precompiled.

Multi-Layer Perceptron

If we were to do this in Lux.jl, we would write the following:

julia
model = Chain(
    Dense(784, 256, relu),
    Dense(256, 10)
)
Chain(
    layer_1 = Dense(784 => 256, relu),  # 200_960 parameters
    layer_2 = Dense(256 => 10),         # 2_570 parameters
)         # Total: 203_530 parameters,
          #        plus 0 states.

But in Boltz.jl we can do this:

julia
model = Layers.MLP(784, (256, 10), relu)
MLP(
    chain = Chain(
        block1 = DenseNormActDropoutBlock(
            block = Chain(
                dense = Dense(784 => 256, relu),  # 200_960 parameters
            ),
        ),
        block2 = DenseNormActDropoutBlock(
            block = Chain(
                dense = Dense(256 => 10),  # 2_570 parameters
            ),
        ),
    ),
)         # Total: 203_530 parameters,
          #        plus 0 states.

Layers.MLP is just a convenience wrapper around Lux.Chain that constructs a multi-layer perceptron from the given input dimension, layer sizes, and activation function.
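
Like any other Lux model, the MLP is used by setting up its parameters and states and then calling it on an input. Below is a minimal sketch, assuming a random batch of 32 samples with 784 features each:

julia
using Lux, Boltz, Random

rng = Random.Xoshiro(0)
model = Layers.MLP(784, (256, 10), relu)

# Initialize parameters and states for the model
ps, st = Lux.setup(rng, model)

# Forward pass with a random batch (784 features × 32 samples)
x = rand(rng, Float32, 784, 32)
y, st_new = model(x, ps, st)
size(y)  # (10, 32)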

How about VGG?

Let's take a look at the Vision module. We can construct a VGG model with the following code:

julia
Vision.VGG(13)
VGG(
    layer = Chain(
        feature_extractor = VGGFeatureExtractor(
            model = Chain(
                layer_1 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 3 => 64, relu, pad=1),  # 1_792 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 64, relu, pad=1),  # 36_928 parameters
                        ),
                    ),
                ),
                layer_2 = MaxPool((2, 2)),
                layer_3 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 128, relu, pad=1),  # 73_856 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 128, relu, pad=1),  # 147_584 parameters
                        ),
                    ),
                ),
                layer_4 = MaxPool((2, 2)),
                layer_5 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 256, relu, pad=1),  # 295_168 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 256, relu, pad=1),  # 590_080 parameters
                        ),
                    ),
                ),
                layer_6 = MaxPool((2, 2)),
                layer_7 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 512, relu, pad=1),  # 1_180_160 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_8 = MaxPool((2, 2)),
                layer_9 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_10 = MaxPool((2, 2)),
            ),
        ),
        classifier = VGGClassifier(
            model = Chain(
                layer_1 = Lux.FlattenLayer{Nothing}(nothing),
                layer_2 = Dense(25088 => 4096, relu),  # 102_764_544 parameters
                layer_3 = Dropout(0.5),
                layer_4 = Dense(4096 => 4096, relu),  # 16_781_312 parameters
                layer_5 = Dropout(0.5),
                layer_6 = Dense(4096 => 1000),  # 4_097_000 parameters
            ),
        ),
    ),
)         # Total: 133_047_848 parameters,
          #        plus 4 states.
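
Because the classifier contains Dropout layers, the states should be switched to test mode before running inference. Below is a minimal sketch, assuming a single random 224×224 RGB image in WHCN layout:

julia
using Lux, Boltz, Random

rng = Random.Xoshiro(0)
vgg = Vision.VGG(13)
ps, st = Lux.setup(rng, vgg)

# Disable Dropout (and other training-time behaviour) for inference
st_test = Lux.testmode(st)

# Forward pass with a single random 224×224 RGB image
x = rand(rng, Float32, 224, 224, 3, 1)
y, _ = vgg(x, ps, st_test)
size(y)  # (1000, 1)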

We can also load pretrained ImageNet weights by passing pretrained=true.

Load JLD2

You need to load JLD2 before you can load the pretrained weights.

julia
using JLD2

Vision.VGG(13; pretrained=true)
VGG(
    layer = Chain(
        feature_extractor = VGGFeatureExtractor(
            model = Chain(
                layer_1 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 3 => 64, relu, pad=1),  # 1_792 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 64, relu, pad=1),  # 36_928 parameters
                        ),
                    ),
                ),
                layer_2 = MaxPool((2, 2)),
                layer_3 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 64 => 128, relu, pad=1),  # 73_856 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 128, relu, pad=1),  # 147_584 parameters
                        ),
                    ),
                ),
                layer_4 = MaxPool((2, 2)),
                layer_5 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 128 => 256, relu, pad=1),  # 295_168 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 256, relu, pad=1),  # 590_080 parameters
                        ),
                    ),
                ),
                layer_6 = MaxPool((2, 2)),
                layer_7 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 256 => 512, relu, pad=1),  # 1_180_160 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_8 = MaxPool((2, 2)),
                layer_9 = ConvNormActivation(
                    model = Chain(
                        block1 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                        block2 = ConvNormActivationBlock(
                            block = Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
                        ),
                    ),
                ),
                layer_10 = MaxPool((2, 2)),
            ),
        ),
        classifier = VGGClassifier(
            model = Chain(
                layer_1 = Lux.FlattenLayer{Nothing}(nothing),
                layer_2 = Dense(25088 => 4096, relu),  # 102_764_544 parameters
                layer_3 = Dropout(0.5),
                layer_4 = Dense(4096 => 4096, relu),  # 16_781_312 parameters
                layer_5 = Dropout(0.5),
                layer_6 = Dense(4096 => 1000),  # 4_097_000 parameters
            ),
        ),
    ),
)         # Total: 133_047_848 parameters,
          #        plus 4 states.

Loading Models from Metalhead (Flux.jl)

We can also load models from Metalhead (Flux.jl); just remember to load Metalhead first.

julia
using Metalhead

Vision.ResNet(18)
MetalheadWrapperLayer(
    layer = Chain(
        layer_1 = Chain(
            layer_1 = Chain(
                layer_1 = Conv((7, 7), 3 => 64, pad=3, stride=2, use_bias=false),  # 9_408 parameters
                layer_2 = BatchNorm(64, relu, affine=true, track_stats=true),  # 128 parameters, plus 129
                layer_3 = MaxPool((3, 3), pad=1, stride=2),
            ),
            layer_2 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_2 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_5 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_2 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 64 => 64, pad=1, use_bias=false),  # 36_864 parameters
                        layer_5 = BatchNorm(64, affine=true, track_stats=true),  # 128 parameters, plus 129
                    ),
                ),
            ),
            layer_3 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Chain(
                        layer_1 = Conv((1, 1), 64 => 128, stride=2, use_bias=false),  # 8_192 parameters
                        layer_2 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                    ),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 64 => 128, pad=1, stride=2, use_bias=false),  # 73_728 parameters
                        layer_2 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false),  # 147_456 parameters
                        layer_5 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 128 => 128, pad=1, use_bias=false),  # 147_456 parameters
                        layer_2 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 128 => 128, pad=1, use_bias=false),  # 147_456 parameters
                        layer_5 = BatchNorm(128, affine=true, track_stats=true),  # 256 parameters, plus 257
                    ),
                ),
            ),
            layer_4 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Chain(
                        layer_1 = Conv((1, 1), 128 => 256, stride=2, use_bias=false),  # 32_768 parameters
                        layer_2 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                    ),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 128 => 256, pad=1, stride=2, use_bias=false),  # 294_912 parameters
                        layer_2 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false),  # 589_824 parameters
                        layer_5 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 256 => 256, pad=1, use_bias=false),  # 589_824 parameters
                        layer_2 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 256 => 256, pad=1, use_bias=false),  # 589_824 parameters
                        layer_5 = BatchNorm(256, affine=true, track_stats=true),  # 512 parameters, plus 513
                    ),
                ),
            ),
            layer_5 = Chain(
                layer_1 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Chain(
                        layer_1 = Conv((1, 1), 256 => 512, stride=2, use_bias=false),  # 131_072 parameters
                        layer_2 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                    ),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 256 => 512, pad=1, stride=2, use_bias=false),  # 1_179_648 parameters
                        layer_2 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false),  # 2_359_296 parameters
                        layer_5 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                    ),
                ),
                layer_2 = Parallel(
                    connection = addact(NNlib.relu, ...),
                    layer_1 = Lux.NoOpLayer(),
                    layer_2 = Chain(
                        layer_1 = Conv((3, 3), 512 => 512, pad=1, use_bias=false),  # 2_359_296 parameters
                        layer_2 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                        layer_3 = WrappedFunction(relu),
                        layer_4 = Conv((3, 3), 512 => 512, pad=1, use_bias=false),  # 2_359_296 parameters
                        layer_5 = BatchNorm(512, affine=true, track_stats=true),  # 1_024 parameters, plus 1_025
                    ),
                ),
            ),
        ),
        layer_2 = Chain(
            layer_1 = AdaptiveMeanPool((1, 1)),
            layer_2 = WrappedFunction(flatten),
            layer_3 = Dense(512 => 1000),  # 513_000 parameters
        ),
    ),
)         # Total: 11_689_512 parameters,
          #        plus 9_620 states.
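
The returned MetalheadWrapperLayer behaves like any other Lux layer, so the usual setup and forward-pass workflow applies unchanged. Below is a minimal sketch, again assuming a single random 224×224 RGB image:

julia
using Lux, Boltz, Metalhead, Random

rng = Random.Xoshiro(0)
resnet = Vision.ResNet(18)

ps, st = Lux.setup(rng, resnet)
st_test = Lux.testmode(st)  # freeze BatchNorm statistics for inference

x = rand(rng, Float32, 224, 224, 3, 1)
y, _ = resnet(x, ps, st_test)
size(y)  # (1000, 1)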

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_NUM_THREADS = 1
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.